source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@8650

/* ======================================================================== *\
! $Name: not supported by cvs2svn $:$Id: MRanForest.cc,v 1.27 2007-07-24 13:35:39 tbretz Exp $
! --------------------------------------------------------------------------
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
!   Author(s): Thomas Bretz <mailto:tbretz@astro.uni-wuerzburg.de>
!
!   Copyright: MAGIC Software Development, 2000-2006
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//
// MRanForest
//
// ParameterContainer for the forest structure
//
// A random forest can be grown by calling GrowForest repeatedly.
// Beforehand, SetupGrow must be called in order to initialize the arrays
// and do some preprocessing.
// Each call to GrowForest() provides the training data for a single tree
// (bootstrap aggregating) and grows it.
//
// Essentially two random elements serve to provide a "random" forest,
// namely bootstrap aggregating (which is done in GrowForest()) and random
// split selection (which is subject to MRanTree::GrowTree()).
//
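// A minimal usage sketch (illustration only; mat and plist are placeholder
// names for a filled MHMatrix -- with the training target in its last
// column -- and the parameter list in use):
//
//     MRanForest forest;
//     forest.SetNumTrees(100);
//     if (forest.SetupGrow(mat, plist))
//         while (forest.GrowForest());  // returns kFALSE after the last tree
//
/////////////////////////////////////////////////////////////////////////////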
#include "MRanForest.h"

#include <float.h>   // for FLT_MIN (used in GrowForest)

#include <TRandom.h>

#include "MHMatrix.h"
#include "MRanTree.h"
#include "MData.h"
#include "MDataArray.h"
#include "MParList.h"

#include "MArrayI.h"
#include "MArrayF.h"

#include "MLog.h"
#include "MLogManip.h"

ClassImp(MRanForest);

using namespace std;

// --------------------------------------------------------------------------
//
// Default constructor.
//
MRanForest::MRanForest(const char *name, const char *title)
    : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
      fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
{
    fName  = name  ? name  : "MRanForest";
    fTitle = title ? title : "Storage container for Random Forest";

    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);
}

MRanForest::MRanForest(const MRanForest &rf)
{
    // Copy constructor
    fName  = rf.fName;
    fTitle = rf.fTitle;

    fClassify = rf.fClassify;
    fNumTrees = rf.fNumTrees;
    fNumTry   = rf.fNumTry;
    fNdSize   = rf.fNdSize;
    fTreeNo   = rf.fTreeNo;
    fRanTree  = NULL;

    fRules=new MDataArray();
    fRules->Reset();

    MDataArray *newrules=rf.fRules;

    for(Int_t i=0;i<newrules->GetNumEntries();i++)
    {
        MData &data=(*newrules)[i];
        fRules->AddEntry(data.GetRule());
    }

    // trees
    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);

    TObjArray *newforest=rf.fForest;
    for(Int_t i=0;i<newforest->GetEntries();i++)
    {
        MRanTree *rantree=(MRanTree*)newforest->At(i);

        MRanTree *newtree=new MRanTree(*rantree);
        fForest->Add(newtree);
    }

    fHadTrue  = rf.fHadTrue;
    fHadEst   = rf.fHadEst;
    fDataSort = rf.fDataSort;
    fDataRang = rf.fDataRang;
    fClassPop = rf.fClassPop;
    fWeight   = rf.fWeight;
    fTreeHad  = rf.fTreeHad;

    fNTimesOutBag = rf.fNTimesOutBag;
}

// --------------------------------------------------------------------------
// Destructor.
MRanForest::~MRanForest()
{
    delete fForest;
    if (fMatrix)
        delete fMatrix;
    if (fRules)
        delete fRules;
}

void MRanForest::Print(Option_t *o) const
{
    *fLog << inf << GetDescriptor() << ": " << endl;
    MRanTree *t = GetTree(0);
    if (t)
    {
        *fLog << "Setting up RF for training on target:" << endl;
        *fLog << " " << t->GetTitle() << endl;
    }
    if (fRules)
    {
        *fLog << "Following rules are used as input to RF:" << endl;
        for (Int_t i=0;i<fRules->GetNumEntries();i++)
            *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
    }
    *fLog << "Random forest parameters:" << endl;
    if (t)
    {
        *fLog << " - " << (t->IsClassify()?"classification":"regression") << " tree" << endl;
        *fLog << " - Number of tries: " << t->GetNumTry() << endl;
        *fLog << " - Node size: " << t->GetNdSize() << endl;
    }
    *fLog << " - Number of trees: " << fNumTrees << endl;
    *fLog << " - User value: " << fUserVal << endl;
    *fLog << endl;
}

void MRanForest::SetNumTrees(Int_t n)
{
    // at least 1 tree
    fNumTrees=TMath::Max(n,1);
}

void MRanForest::SetNumTry(Int_t n)
{
    fNumTry=TMath::Max(n,0);
}

void MRanForest::SetNdSize(Int_t n)
{
    fNdSize=TMath::Max(n,1);
}

void MRanForest::SetWeights(const TArrayF &weights)
{
    fWeight=weights;
}

void MRanForest::SetGrid(const TArrayD &grid)
{
    const int n=grid.GetSize();

    for(int i=0;i<n-1;i++)
        if(grid[i]>=grid[i+1])
        {
            *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
            return;
        }

    fGrid=grid;

    //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
    //for(int i=0;i<n;i++)
    //    *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
}

MRanTree *MRanForest::GetTree(Int_t i) const
{
    return static_cast<MRanTree*>(fForest->UncheckedAt(i));
}

Int_t MRanForest::GetNumDim() const
{
    return fMatrix ? fMatrix->GetNcols() : 0;
}

Int_t MRanForest::GetNumData() const
{
    return fMatrix ? fMatrix->GetNrows() : 0;
}

Int_t MRanForest::GetNclass() const
{
    int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());

    return int(fClass[maxidx])+1;
}

void MRanForest::PrepareClasses()
{
    const int numdata=fHadTrue.GetSize();

    if(fGrid.GetSize()>0)
    {
        // classes given by grid
        const int ngrid=fGrid.GetSize();

        for(int j=0;j<numdata;j++)
        {
            // Array is supposed to be sorted prior to this call.
            // If a match is found, the function returns the position of
            // the element; if no match is found, it returns the position
            // of the nearest element smaller than the value.
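            // Worked example (illustration only): for fGrid = {10,20,30},
            // fHadTrue[j]=25 yields k=1 (nearest smaller element is 20)
            // and fHadTrue[j]=5 yields k=-1; the shift below then moves
            // the smallest occurring class to 0.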
            int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);

            fClass[j] = k;
        }

        int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
        int min = fClass[minidx];
        if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;

    }else{
        // classes directly given
        for (Int_t j=0;j<numdata;j++)
            fClass[j] = TMath::Nint(fHadTrue[j]);
    }
}

Double_t MRanForest::CalcHadroness()
{
    TVector event;
    *fRules >> event;

    return CalcHadroness(event);
}

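// --------------------------------------------------------------------------
//
// Returns the mean response of all trees for the given event, i.e. the
// sum of the single-tree responses MRanTree::TreeHad divided by the
// number of trees. The individual responses are stored in fTreeHad.
//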
Double_t MRanForest::CalcHadroness(const TVector &event)
{
    fTreeHad.Set(fNumTrees);

    Double_t hadroness=0;
    Int_t    ntree    =0;

    TIter Next(fForest);

    MRanTree *tree;
    while ((tree=(MRanTree*)Next()))
        hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));

    return hadroness/ntree;
}

Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
{
    fRanTree = rantree ? rantree : fRanTree;

    if (!fRanTree) return kFALSE;

    MRanTree *newtree=new MRanTree(*fRanTree);
    fForest->Add(newtree);

    return kTRUE;
}

Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    if (fMatrix)
        delete fMatrix;
    fMatrix = new TMatrix(mat->GetM());

    int dim     = fMatrix->GetNcols()-1;
    int numdata = fMatrix->GetNrows();

    fHadTrue.Set(numdata);
    fHadTrue.Reset();

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim);

    // remove last col
    fMatrix->ResizeTo(numdata,dim);

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset();

    if (fClassify)
        PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata);
    fHadEst.Reset();

    fNTimesOutBag.Set(numdata);
    fNTimesOutBag.Reset();

    fDataSort.Set(dim*numdata);
    fDataSort.Reset();

    fDataRang.Set(dim*numdata);
    fDataRang.Reset();

    Bool_t useweights = fWeight.GetSize()==numdata;
    if (!useweights)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    const MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if (allrules==NULL)
    {
        *fLog << err <<"Rules of matrix are NULL... exiting."<< endl;
        return kFALSE;
    }

    if (allrules->GetNumEntries()!=dim+1)
    {
        *fLog << err <<"Number of rules in matrix (" << allrules->GetNumEntries() << ") does not match dimension+1 (" << dim+1 << ")... exiting."<< endl;
        return kFALSE;
    }

    if (fRules)
        delete fRules;

    fRules = new MDataArray();
    fRules->Reset();

    const TString target_rule = (*allrules)[dim].GetRule();
    for (Int_t i=0;i<dim;i++)
        fRules->AddEntry((*allrules)[i].GetRule());

    *fLog << inf << endl;
    *fLog << "Setting up RF for training on target:" << endl;
    *fLog << " " << target_rule.Data() << endl;
    *fLog << "Following rules are used as input to RF:" << endl;
    for (Int_t i=0;i<dim;i++)
        *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
    *fLog << endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if (!CreateDataSort())
        return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }
    //fRanTree->SetName(target_rule); // Is not stored anyhow

    const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));

    *fLog << inf << endl;
    *fLog << "The following inputs are used for growing the trees:"<<endl;
    *fLog << " Forest type     : "<<(fClassify?"classification":"regression")<<endl;
    *fLog << " Number of Trees : "<<fNumTrees<<endl;
    *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
    *fLog << " Final Node size : "<<fNdSize<<endl;
    *fLog << " Using Grid      : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
    *fLog << " Using Weights   : "<<(useweights?"Yes":"No")<<endl;
    *fLog << " Number of Events: "<<numdata<<endl;
    *fLog << " Number of Params: "<<dim<<endl;

    if(fNumTry==0)
    {
        fNumTry=tryest;
        *fLog << inf << endl;
        *fLog << "Set no. of trials to the recommended value of round(";
        *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
    }

    fRanTree->SetNumTry(fNumTry);
    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    fTreeNo=0;

    return kTRUE;
}

Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    float minfloat=TMath::MinElement(fHadTrue.GetSize(),fHadTrue.GetArray());
    Bool_t calcResolution=(minfloat>FLT_MIN);

    if (fTreeNo==1)
    {
        *fLog << inf << endl << underline;

        if(calcResolution)
            *fLog << "TreeNum BagSize NumNodes TestSize Bias/% var/% res/% (from oob-data)" << endl;
        else
            *fLog << "TreeNum BagSize NumNodes TestSize Bias/au var/au rms/au (from oob-data)" << endl;
        //        12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass  = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    MArrayF classpopw(nclass);
    MArrayI jinbag(numdata); // Initialization includes filling with 0
    MArrayF winbag(numdata); // Initialization includes filling with 0

    float square=0;
    float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = gRandom->Integer(numdata);

        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        mean  +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        winbag[k]+=fWeight[k]; // Increase weight if chosen more than once
        jinbag[k]=1;
    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In the bagging procedure approx. 2/3 of all elements of the original
    // training sample end up in the in-bag data.
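    // The reason: the probability for a given event never to be drawn in
    // numdata tries is (1-1/numdata)^numdata, which approaches 1/e = 0.37
    // for large samples, so about 63% of all events make it into the bag.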
    const MArrayF hadtrue(fHadTrue.GetSize(), fHadTrue.GetArray());
    const MArrayI fclass(fClass.GetSize(), fClass.GetArray());
    const MArrayI datarang(fDataRang.GetSize(), fDataRang.GetArray());

    MArrayI datsortinbag(fDataSort.GetSize(), fDataSort.GetArray());

    ModifyDataSort(datsortinbag, jinbag);

    fRanTree->GrowTree(fMatrix,hadtrue,fclass,datsortinbag,datarang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    const Double_t ferr = EstimateError(jinbag, calcResolution);

    fRanTree->SetError(ferr);

    // adding tree to forest
    AddTree();

    return fTreeNo<fNumTrees;
}

//-------------------------------------------------------------------
// error-estimates from out-of-bag data (oob data):
//
// For a single tree the events not(!) contained in the bootstrap
// sample of this tree can be used to obtain estimates for the
// classification error of this tree.
// A given event is contained in the oob-data of about 1/3 of the
// trees (see the bagging comment in GrowForest). The classification
// error determined from oob-data is therefore overestimated, and can
// be taken as an upper limit.
//
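// In the code below, with val = est/true - 1 (resolution mode) or
// val = est - true otherwise, the printed quantities are
//    bias = <val>,   var = sqrt(<val^2> - <val>^2),   res/rms = sqrt(<val^2>),
// evaluated over all events with at least one oob estimate.
//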
Double_t MRanForest::EstimateError(const MArrayI &jinbag, Bool_t calcResolution)
{
    const Int_t numdata = GetNumData();

    Int_t ninbag = 0;
    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0)
        {
            ninbag++;
            continue;
        }

        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;
    }

    Int_t n=0;

    Double_t sum=0;
    Double_t sq =0;
    for (Int_t i=0; i<numdata; i++)
    {
        if (fNTimesOutBag[i]==0)
            continue;

        const Float_t hadest = fHadEst[i]/fNTimesOutBag[i];

        const Float_t val = calcResolution ?
            hadest/fHadTrue[i] - 1 : hadest - fHadTrue[i];

        sum += val;
        sq  += val*val;
        n++;
    }

    if (calcResolution)
    {
        sum *= 100;
        sq  *= 10000;
    }

    sum /= n;
    sq  /= n;

    const Double_t var  = TMath::Sqrt(sq-sum*sum);
    const Double_t ferr = TMath::Sqrt(sq);

    //-------------------------------------------------------------------
    // give running output
    *fLog << setw(4) << fTreeNo;
    *fLog << Form(" %8.1f", 100.*ninbag/numdata);
    *fLog << setw(9) << fRanTree->GetNumEndNodes();
    *fLog << Form(" %9.1f", 100.*n/numdata);
    *fLog << Form(" %7.2f", sum);
    *fLog << Form(" %7.2f", var);
    *fLog << Form(" %7.2f", ferr);
    *fLog << endl;

    return ferr;
}

Bool_t MRanForest::CreateDataSort()
{
    // fDataSort(m,n) is the event number of the value which is the n-th
    // lowest in matrix column m.
    // fDataRang(m,n) is the rank of fMatrix(n,m), i.e. rank r means:
    // fMatrix(n,m) is the (r+1)-th lowest of all values in column m.
    //
    // There may be more than 1 event with rank r (due to equal values).
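    //
    // Worked example (illustration only) for one column with the values
    //  v         = {0.3, 0.1, 0.3, 0.2}:
    //  fDataSort = {1, 3, 0, 2}  (event numbers ordered by ascending value)
    //  fDataRang = {2, 0, 2, 1}  (rank of each event's value, ties share rank)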

    const Int_t numdata = GetNumData();
    const Int_t dim     = GetNumDim();

    TArrayF v(numdata);
    TArrayI isort(numdata);


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            //isort[n]=n;

            if (!TMath::Finite(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has a non-finite value (e.g. NaN)."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the
        // event number of that v[n], which is the n-th from the
        // lowest (assume the original event numbers are 0,1,...).

        // control sorting
        /*
         for(int n=0;n<numdata-1;n++)
             if(v[isort[n]]>v[isort[n+1]])
             {
                 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                 *fLog << err <<" not at correct sorting position."<<endl;
                 return kFALSE;
             }
         */

        // fDataRang is similar to the isort index: starting from 0 it is
        // increased by one for each event with a greater value, but stays
        // the same for equal values. (So to speak, it counts how many
        // different values there are.)
        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            // FIXME: Copying isort[n] to fDataSort[mvar*numdata]
            // can be accelerated!
            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}

// Removes all indices which are not in the bag from datsortinbag.
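// E.g. (illustration only) for one column with subsort = {4,1,3,0,2} and
// jinbag = {0,0,1,1,1}: the in-bag events are 2, 3 and 4, so the first
// ninbag=3 entries become {4,3,2}, keeping their sorted order.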
void MRanForest::ModifyDataSort(MArrayI &datsortinbag, const MArrayI &jinbag)
{
    const Int_t numdim =GetNumDim();
    const Int_t numdata=GetNumData();

    Int_t ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        Int_t *subsort = &datsortinbag[m*numdata];

        Int_t k=0;
        for(Int_t n=0;n<ninbag;n++)
        {
            if(jinbag[subsort[k]]==1)
            {
                subsort[n] = subsort[k];
                k++;
            }else{
                for(Int_t j=k+1;j<numdata;j++)
                {
                    if(jinbag[subsort[j]]==1)
                    {
                        subsort[n] = subsort[j];
                        k = j+1;
                        break;
                    }
                }
            }
        }
    }
}

Bool_t MRanForest::AsciiWrite(ostream &out) const
{
    Int_t n=0;
    MRanTree *tree;
    TIter forest(fForest);

    while ((tree=(MRanTree*)forest.Next()))
    {
        tree->AsciiWrite(out);
        n++;
    }

    return n==fNumTrees;
}