source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 8004

Last change on this file since 8004 was 7749, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 18.7 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19!
20! Copyright: MAGIC Software Development, 2000-2005
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
31// A random forest can be grown by calling GrowForest.
32// In advance SetupGrow must be called in order to initialize arrays and
33// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
37// Essentially two random elements serve to provide a "random" forest,
38// namely bootstrap aggregating (which is done in GrowForest()) and random
39// split selection (which is subject to MRanTree::GrowTree())
40//
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TRandom.h>
45
46#include "MHMatrix.h"
47#include "MRanTree.h"
48#include "MData.h"
49#include "MDataArray.h"
50#include "MParList.h"
51
52#include "MArrayI.h"
53#include "MArrayF.h"
54
55#include "MLog.h"
56#include "MLogManip.h"
57
58ClassImp(MRanForest);
59
60using namespace std;
61
62// --------------------------------------------------------------------------
63//
64// Default constructor.
65//
66MRanForest::MRanForest(const char *name, const char *title)
67 : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
68 fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
69{
70 fName = name ? name : "MRanForest";
71 fTitle = title ? title : "Storage container for Random Forest";
72
73 fForest=new TObjArray();
74 fForest->SetOwner(kTRUE);
75}
76
77MRanForest::MRanForest(const MRanForest &rf)
78{
79 // Copy constructor
80 fName = rf.fName;
81 fTitle = rf.fTitle;
82
83 fClassify = rf.fClassify;
84 fNumTrees = rf.fNumTrees;
85 fNumTry = rf.fNumTry;
86 fNdSize = rf.fNdSize;
87 fTreeNo = rf.fTreeNo;
88 fRanTree = NULL;
89
90 fRules=new MDataArray();
91 fRules->Reset();
92
93 MDataArray *newrules=rf.fRules;
94
95 for(Int_t i=0;i<newrules->GetNumEntries();i++)
96 {
97 MData &data=(*newrules)[i];
98 fRules->AddEntry(data.GetRule());
99 }
100
101 // trees
102 fForest=new TObjArray();
103 fForest->SetOwner(kTRUE);
104
105 TObjArray *newforest=rf.fForest;
106 for(Int_t i=0;i<newforest->GetEntries();i++)
107 {
108 MRanTree *rantree=(MRanTree*)newforest->At(i);
109
110 MRanTree *newtree=new MRanTree(*rantree);
111 fForest->Add(newtree);
112 }
113
114 fHadTrue = rf.fHadTrue;
115 fHadEst = rf.fHadEst;
116 fDataSort = rf.fDataSort;
117 fDataRang = rf.fDataRang;
118 fClassPop = rf.fClassPop;
119 fWeight = rf.fWeight;
120 fTreeHad = rf.fTreeHad;
121
122 fNTimesOutBag = rf.fNTimesOutBag;
123
124}
125
126// --------------------------------------------------------------------------
127// Destructor.
128MRanForest::~MRanForest()
129{
130 delete fForest;
131 if (fMatrix)
132 delete fMatrix;
133 if (fRules)
134 delete fRules;
135}
136
137void MRanForest::Print(Option_t *o) const
138{
139 *fLog << inf << GetDescriptor() << ": " << endl;
140 MRanTree *t = GetTree(0);
141 if (t)
142 {
143 *fLog << "Setting up RF for training on target:" << endl;
144 *fLog << " " << t->GetTitle() << endl;
145 }
146 if (fRules)
147 {
148 *fLog << "Following rules are used as input to RF:" << endl;
149 for (Int_t i=0;i<fRules->GetNumEntries();i++)
150 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
151 }
152 *fLog << "Random forest parameters:" << endl;
153 if (t)
154 {
155 *fLog << " - " << (t->IsClassify()?"classification":"regression") << " tree" << endl;
156 *fLog << " - Number of trys: " << t->GetNumTry() << endl;
157 *fLog << " - Node size: " << t->GetNdSize() << endl;
158 }
159 *fLog << " - Number of trees: " << fNumTrees << endl;
160 *fLog << " - User value: " << fUserVal << endl;
161 *fLog << endl;
162}
163
164void MRanForest::SetNumTrees(Int_t n)
165{
166 //at least 1 tree
167 fNumTrees=TMath::Max(n,1);
168}
169
170void MRanForest::SetNumTry(Int_t n)
171{
172 fNumTry=TMath::Max(n,0);
173}
174
175void MRanForest::SetNdSize(Int_t n)
176{
177 fNdSize=TMath::Max(n,1);
178}
179
180void MRanForest::SetWeights(const TArrayF &weights)
181{
182 fWeight=weights;
183}
184
185void MRanForest::SetGrid(const TArrayD &grid)
186{
187 const int n=grid.GetSize();
188
189 for(int i=0;i<n-1;i++)
190 if(grid[i]>=grid[i+1])
191 {
192 *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
193 return;
194 }
195
196 fGrid=grid;
197
198 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
199 //for(int i=0;i<n;i++)
200 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
201}
202
203MRanTree *MRanForest::GetTree(Int_t i) const
204{
205 return static_cast<MRanTree*>(fForest->UncheckedAt(i));
206}
207
208Int_t MRanForest::GetNumDim() const
209{
210 return fMatrix ? fMatrix->GetNcols() : 0;
211}
212
213Int_t MRanForest::GetNumData() const
214{
215 return fMatrix ? fMatrix->GetNrows() : 0;
216}
217
218Int_t MRanForest::GetNclass() const
219{
220 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
221
222 return int(fClass[maxidx])+1;
223}
224
225void MRanForest::PrepareClasses()
226{
227 const int numdata=fHadTrue.GetSize();
228
229 if(fGrid.GetSize()>0)
230 {
231 // classes given by grid
232 const int ngrid=fGrid.GetSize();
233
234 for(int j=0;j<numdata;j++)
235 {
236 // Array is supposed to be sorted prior to this call.
237 // If match is found, function returns position of element.
238 // If no match found, function gives nearest element smaller
239 // than value.
240 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);
241
242 fClass[j] = k;
243 }
244
245 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
246 int min = fClass[minidx];
247 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
248
249 }else{
250 // classes directly given
251 for (Int_t j=0;j<numdata;j++)
252 fClass[j] = TMath::Nint(fHadTrue[j]);
253 }
254}
255
256Double_t MRanForest::CalcHadroness()
257{
258 TVector event;
259 *fRules >> event;
260
261 return CalcHadroness(event);
262}
263
264Double_t MRanForest::CalcHadroness(const TVector &event)
265{
266 fTreeHad.Set(fNumTrees);
267
268 Double_t hadroness=0;
269 Int_t ntree =0;
270
271 TIter Next(fForest);
272
273 MRanTree *tree;
274 while ((tree=(MRanTree*)Next()))
275 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
276
277 return hadroness/ntree;
278}
279
280Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
281{
282 fRanTree = rantree ? rantree : fRanTree;
283
284 if (!fRanTree) return kFALSE;
285
286 MRanTree *newtree=new MRanTree(*fRanTree);
287 fForest->Add(newtree);
288
289 return kTRUE;
290}
291
292Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
293{
294 //-------------------------------------------------------------------
295 // access matrix, copy last column (target) preliminarily
296 // into fHadTrue
297 if (fMatrix)
298 delete fMatrix;
299 fMatrix = new TMatrix(mat->GetM());
300
301 int dim = fMatrix->GetNcols()-1;
302 int numdata = fMatrix->GetNrows();
303
304 fHadTrue.Set(numdata);
305 fHadTrue.Reset();
306
307 for (Int_t j=0;j<numdata;j++)
308 fHadTrue[j] = (*fMatrix)(j,dim);
309
310 // remove last col
311 fMatrix->ResizeTo(numdata,dim);
312
313 //-------------------------------------------------------------------
314 // setup labels for classification/regression
315 fClass.Set(numdata);
316 fClass.Reset();
317
318 if (fClassify)
319 PrepareClasses();
320
321 //-------------------------------------------------------------------
322 // allocating and initializing arrays
323 fHadEst.Set(numdata);
324 fHadEst.Reset();
325
326 fNTimesOutBag.Set(numdata);
327 fNTimesOutBag.Reset();
328
329 fDataSort.Set(dim*numdata);
330 fDataSort.Reset();
331
332 fDataRang.Set(dim*numdata);
333 fDataRang.Reset();
334
335 Bool_t useweights = fWeight.GetSize()==numdata;
336 if (!useweights)
337 {
338 fWeight.Set(numdata);
339 fWeight.Reset(1.);
340 *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
341 }
342
343 //-------------------------------------------------------------------
344 // setup rules to be used for classification/regression
345 const MDataArray *allrules=(MDataArray*)mat->GetColumns();
346 if (allrules==NULL)
347 {
348 *fLog << err <<"Rules of matrix == null, exiting"<< endl;
349 return kFALSE;
350 }
351
352 if (allrules->GetNumEntries()!=dim+1)
353 {
354 *fLog << err <<"Rules of matrix " << allrules->GetNumEntries() << " mismatch dimension+1 " << dim+1 << "...exiting."<< endl;
355 return kFALSE;
356 }
357
358 if (fRules)
359 delete fRules;
360
361 fRules = new MDataArray();
362 fRules->Reset();
363
364 const TString target_rule = (*allrules)[dim].GetRule();
365 for (Int_t i=0;i<dim;i++)
366 fRules->AddEntry((*allrules)[i].GetRule());
367
368 *fLog << inf << endl;
369 *fLog << "Setting up RF for training on target:" << endl;
370 *fLog << " " << target_rule.Data() << endl;
371 *fLog << "Following rules are used as input to RF:" << endl;
372 for (Int_t i=0;i<dim;i++)
373 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
374 *fLog << endl;
375
376 //-------------------------------------------------------------------
377 // prepare (sort) data for fast optimization algorithm
378 if (!CreateDataSort())
379 return kFALSE;
380
381 //-------------------------------------------------------------------
382 // access and init tree container
383 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
384 if(!fRanTree)
385 {
386 *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
387 return kFALSE;
388 }
389 //fRanTree->SetName(target_rule); // Is not stored anyhow
390
391 const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));
392
393 *fLog << inf << endl;
394 *fLog << "Following input for the tree growing are used:"<<endl;
395 *fLog << " Forest type : "<<(fClassify?"classification":"regression")<<endl;
396 *fLog << " Number of Trees : "<<fNumTrees<<endl;
397 *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
398 *fLog << " Final Node size : "<<fNdSize<<endl;
399 *fLog << " Using Grid : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
400 *fLog << " Using Weights : "<<(useweights?"Yes":"No")<<endl;
401 *fLog << " Number of Events: "<<numdata<<endl;
402 *fLog << " Number of Params: "<<dim<<endl;
403
404 if(fNumTry==0)
405 {
406 fNumTry=tryest;
407 *fLog << inf << endl;
408 *fLog << "Set no. of trials to the recommended value of round(";
409 *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
410 }
411
412 fRanTree->SetNumTry(fNumTry);
413 fRanTree->SetClassify(fClassify);
414 fRanTree->SetNdSize(fNdSize);
415
416 fTreeNo=0;
417
418 return kTRUE;
419}
420
421Bool_t MRanForest::GrowForest()
422{
423 if(!gRandom)
424 {
425 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
426 return kFALSE;
427 }
428
429 fTreeNo++;
430
431 //-------------------------------------------------------------------
432 // initialize running output
433
434 float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
435 Bool_t calcResolution=(minfloat>0.001);
436
437 if (fTreeNo==1)
438 {
439 *fLog << inf << endl << underline;
440
441 if(calcResolution)
442 *fLog << "TreeNum BagSize NumNodes TestSize res/% (from oob-data -> overest. of error)" << endl;
443 else
444 *fLog << "TreeNum BagSize NumNodes TestSize rms/% (from oob-data -> overest. of error)" << endl;
445 // 12345678901234567890123456789012345678901234567890
446 }
447
448 const Int_t numdata = GetNumData();
449 const Int_t nclass = GetNclass();
450
451 //-------------------------------------------------------------------
452 // bootstrap aggregating (bagging) -> sampling with replacement:
453
454 MArrayF classpopw(nclass);
455 MArrayI jinbag(numdata); // Initialization includes filling with 0
456 MArrayF winbag(numdata); // Initialization includes filling with 0
457
458 float square=0;
459 float mean=0;
460
461 for (Int_t n=0; n<numdata; n++)
462 {
463 // The integer k is randomly (uniformly) chosen from the set
464 // {0,1,...,numdata-1}, which is the set of the index numbers of
465 // all events in the training sample
466
467 const Int_t k = gRandom->Integer(numdata);
468
469 if(fClassify)
470 classpopw[fClass[k]]+=fWeight[k];
471 else
472 classpopw[0]+=fWeight[k];
473
474 mean +=fHadTrue[k]*fWeight[k];
475 square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];
476
477 winbag[k]+=fWeight[k]; // Increase weight if chosen more than once
478 jinbag[k]=1;
479 }
480
481 //-------------------------------------------------------------------
482 // modifying sorted-data array for in-bag data:
483
484 // In bagging procedure ca. 2/3 of all elements in the original
485 // training sample are used to build the in-bag data
486 const MArrayF hadtrue(fHadTrue.GetSize(), fHadTrue.GetArray());
487 const MArrayI fclass(fClass.GetSize(), fClass.GetArray());
488 const MArrayI datarang(fDataRang.GetSize(), fDataRang.GetArray());
489
490 MArrayI datsortinbag(fDataSort.GetSize(), fDataSort.GetArray());
491
492 ModifyDataSort(datsortinbag, jinbag);
493
494 fRanTree->GrowTree(fMatrix,hadtrue,fclass,datsortinbag,datarang,classpopw,mean,square,
495 jinbag,winbag,nclass);
496
497 //-------------------------------------------------------------------
498 // error-estimates from out-of-bag data (oob data):
499 //
500 // For a single tree the events not(!) contained in the bootstrap sample of
501 // this tree can be used to obtain estimates for the classification error of
502 // this tree.
503 // If you take a certain event, it is contained in the oob-data of 1/3 of
504 // the trees (see comment to ModifyData). This means that the classification error
505 // determined from oob-data is underestimated, but can still be taken as upper limit.
506
507 Int_t ninbag = 0;
508 for (Int_t ievt=0;ievt<numdata;ievt++)
509 {
510 if (jinbag[ievt]>0)
511 {
512 ninbag++;
513 continue;
514 }
515
516 fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
517 fNTimesOutBag[ievt]++;
518 }
519
520 Int_t n=0;
521 Float_t ferr=0;
522
523 for (Int_t ievt=0;ievt<numdata;ievt++)
524 {
525 if(fNTimesOutBag[ievt]!=0)
526 {
527 float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
528 if(calcResolution) val/=fHadTrue[ievt];
529
530 ferr += val*val;
531 n++;
532 }
533 }
534 ferr = TMath::Sqrt(ferr/n);
535
536 //-------------------------------------------------------------------
537 // give running output
538 *fLog << setw(4) << fTreeNo;
539 *fLog << Form(" %8.1f", 100.*ninbag/numdata);
540 *fLog << setw(9) << fRanTree->GetNumEndNodes();
541 *fLog << Form(" %9.1f", 100.*n/numdata);
542 *fLog << Form("%18.2f", ferr*100.);
543 *fLog << endl;
544
545 fRanTree->SetError(ferr);
546
547 // adding tree to forest
548 AddTree();
549
550 return fTreeNo<fNumTrees;
551}
552
553Bool_t MRanForest::CreateDataSort()
554{
555 // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
556 // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
557 // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
558 //
559 // There may be more then 1 event with rang r (due to bagging).
560
561 const Int_t numdata = GetNumData();
562 const Int_t dim = GetNumDim();
563
564 TArrayF v(numdata);
565 TArrayI isort(numdata);
566
567
568 for (Int_t mvar=0;mvar<dim;mvar++)
569 {
570
571 for(Int_t n=0;n<numdata;n++)
572 {
573 v[n]=(*fMatrix)(n,mvar);
574 //isort[n]=n;
575
576 if(TMath::IsNaN(v[n]))
577 {
578 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
579 *fLog << err <<" has the value NaN."<<endl;
580 return kFALSE;
581 }
582 }
583
584 TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);
585
586 // this sorts the v[n] in ascending order. isort[n] is the event number
587 // of that v[n], which is the n-th from the lowest (assume the original
588 // event numbers are 0,1,...).
589
590 // control sorting
591 /*
592 for(int n=0;n<numdata-1;n++)
593 if(v[isort[n]]>v[isort[n+1]])
594 {
595 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
596 *fLog << err <<" not at correct sorting position."<<endl;
597 return kFALSE;
598 }
599 */
600
601 // DataRang is similar to the isort index starting from 0 it is
602 // increased by one for each event which is greater, but stays
603 // the same for the same value. (So to say it counts how many
604 // different values we have)
605 for(Int_t n=0;n<numdata-1;n++)
606 {
607 const Int_t n1=isort[n];
608 const Int_t n2=isort[n+1];
609
610 // FIXME: Copying isort[n] to fDataSort[mvar*numdata] can be accelerated!
611 fDataSort[mvar*numdata+n]=n1;
612 if(n==0) fDataRang[mvar*numdata+n1]=0;
613 if(v[n1]<v[n2])
614 {
615 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
616 }else{
617 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
618 }
619 }
620 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
621 }
622 return kTRUE;
623}
624
625// Reoves all indices which are not in the bag from the datsortinbag
626void MRanForest::ModifyDataSort(MArrayI &datsortinbag, const MArrayI &jinbag)
627{
628 const Int_t numdim=GetNumDim();
629 const Int_t numdata=GetNumData();
630
631 Int_t ninbag=0;
632 for (Int_t n=0;n<numdata;n++)
633 if(jinbag[n]==1) ninbag++;
634
635 for(Int_t m=0;m<numdim;m++)
636 {
637 Int_t *subsort = &datsortinbag[m*numdata];
638
639 Int_t k=0;
640 for(Int_t n=0;n<ninbag;n++)
641 {
642 if(jinbag[subsort[k]]==1)
643 {
644 subsort[n] = subsort[k];
645 k++;
646 }else{
647 for(Int_t j=k+1;j<numdata;j++)
648 {
649 if(jinbag[subsort[j]]==1)
650 {
651 subsort[n] = subsort[j];
652 k = j+1;
653 break;
654 }
655 }
656 }
657 }
658 }
659}
660
661Bool_t MRanForest::AsciiWrite(ostream &out) const
662{
663 Int_t n=0;
664 MRanTree *tree;
665 TIter forest(fForest);
666
667 while ((tree=(MRanTree*)forest.Next()))
668 {
669 tree->AsciiWrite(out);
670 n++;
671 }
672
673 return n==fNumTrees;
674}
Note: See TracBrowser for help on using the repository browser.