source: branches/Mars_McMismatchStudy/mranforest/MRanForest.cc@ 19921

Last change on this file since 19921 was 10166, checked in by tbretz, 14 years ago
Removed the old obsolete cvs header line.
File size: 19.2 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19! Author(s): Thomas Bretz <mailto:tbretz@astro.uni-wuerzburg.de>
20!
21! Copyright: MAGIC Software Development, 2000-2006
22!
23!
24\* ======================================================================== */
25
26/////////////////////////////////////////////////////////////////////////////
27//
28// MRanForest
29//
30// ParameterContainer for Forest structure
31//
32// A random forest can be grown by calling GrowForest.
33// In advance SetupGrow must be called in order to initialize arrays and
34// do some preprocessing.
35// GrowForest() provides the training data for a single tree (bootstrap
36// aggregate procedure)
37//
38// Essentially two random elements serve to provide a "random" forest,
39// namely bootstrap aggregating (which is done in GrowForest()) and random
40// split selection (which is subject to MRanTree::GrowTree())
41//
42/////////////////////////////////////////////////////////////////////////////
43#include "MRanForest.h"
44
45#include <TMath.h>
46#include <TRandom.h>
47
48#include "MHMatrix.h"
49#include "MRanTree.h"
50#include "MData.h"
51#include "MDataArray.h"
52#include "MParList.h"
53
54#include "MArrayI.h"
55#include "MArrayF.h"
56
57#include "MLog.h"
58#include "MLogManip.h"
59
60ClassImp(MRanForest);
61
62using namespace std;
63
64// --------------------------------------------------------------------------
65//
66// Default constructor.
67//
68MRanForest::MRanForest(const char *name, const char *title)
69 : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
70 fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
71{
72 fName = name ? name : "MRanForest";
73 fTitle = title ? title : "Storage container for Random Forest";
74
75 fForest=new TObjArray();
76 fForest->SetOwner(kTRUE);
77}
78
79MRanForest::MRanForest(const MRanForest &rf)
80{
81 // Copy constructor
82 fName = rf.fName;
83 fTitle = rf.fTitle;
84
85 fClassify = rf.fClassify;
86 fNumTrees = rf.fNumTrees;
87 fNumTry = rf.fNumTry;
88 fNdSize = rf.fNdSize;
89 fTreeNo = rf.fTreeNo;
90 fRanTree = NULL;
91
92 fRules=new MDataArray();
93 fRules->Reset();
94
95 MDataArray *newrules=rf.fRules;
96
97 for(Int_t i=0;i<newrules->GetNumEntries();i++)
98 {
99 MData &data=(*newrules)[i];
100 fRules->AddEntry(data.GetRule());
101 }
102
103 // trees
104 fForest=new TObjArray();
105 fForest->SetOwner(kTRUE);
106
107 TObjArray *newforest=rf.fForest;
108 for(Int_t i=0;i<newforest->GetEntries();i++)
109 {
110 MRanTree *rantree=(MRanTree*)newforest->At(i);
111
112 MRanTree *newtree=new MRanTree(*rantree);
113 fForest->Add(newtree);
114 }
115
116 fHadTrue = rf.fHadTrue;
117 fHadEst = rf.fHadEst;
118 fDataSort = rf.fDataSort;
119 fDataRang = rf.fDataRang;
120 fClassPop = rf.fClassPop;
121 fWeight = rf.fWeight;
122 fTreeHad = rf.fTreeHad;
123
124 fNTimesOutBag = rf.fNTimesOutBag;
125}
126
127// --------------------------------------------------------------------------
128// Destructor.
129MRanForest::~MRanForest()
130{
131 delete fForest;
132 if (fMatrix)
133 delete fMatrix;
134 if (fRules)
135 delete fRules;
136}
137
138void MRanForest::Print(Option_t *o) const
139{
140 *fLog << inf << GetDescriptor() << ": " << endl;
141 MRanTree *t = GetTree(0);
142 if (t)
143 {
144 *fLog << "Setting up RF for training on target:" << endl;
145 *fLog << " " << t->GetTitle() << endl;
146 }
147 if (fRules)
148 {
149 *fLog << "Following rules are used as input to RF:" << endl;
150 for (Int_t i=0;i<fRules->GetNumEntries();i++)
151 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
152 }
153 *fLog << "Random forest parameters:" << endl;
154 if (t)
155 {
156 *fLog << " - " << (t->IsClassify()?"classification":"regression") << " tree" << endl;
157 *fLog << " - Number of trys: " << t->GetNumTry() << endl;
158 *fLog << " - Node size: " << t->GetNdSize() << endl;
159 }
160 *fLog << " - Number of trees: " << fNumTrees << endl;
161 *fLog << " - User value: " << fUserVal << endl;
162 *fLog << endl;
163}
164
165void MRanForest::SetNumTrees(Int_t n)
166{
167 //at least 1 tree
168 fNumTrees=TMath::Max(n,1);
169}
170
171void MRanForest::SetNumTry(Int_t n)
172{
173 fNumTry=TMath::Max(n,0);
174}
175
176void MRanForest::SetNdSize(Int_t n)
177{
178 fNdSize=TMath::Max(n,1);
179}
180
181void MRanForest::SetWeights(const TArrayF &weights)
182{
183 fWeight=weights;
184}
185
186void MRanForest::SetGrid(const TArrayD &grid)
187{
188 const int n=grid.GetSize();
189
190 for(int i=0;i<n-1;i++)
191 if(grid[i]>=grid[i+1])
192 {
193 *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
194 return;
195 }
196
197 fGrid=grid;
198
199 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
200 //for(int i=0;i<n;i++)
201 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
202}
203
204MRanTree *MRanForest::GetTree(Int_t i) const
205{
206 return static_cast<MRanTree*>(fForest->UncheckedAt(i));
207}
208
209Int_t MRanForest::GetNumDim() const
210{
211 return fMatrix ? fMatrix->GetNcols() : 0;
212}
213
214Int_t MRanForest::GetNumData() const
215{
216 return fMatrix ? fMatrix->GetNrows() : 0;
217}
218
219Int_t MRanForest::GetNclass() const
220{
221 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
222
223 return int(fClass[maxidx])+1;
224}
225
226void MRanForest::PrepareClasses()
227{
228 const int numdata=fHadTrue.GetSize();
229
230 if(fGrid.GetSize()>0)
231 {
232 // classes given by grid
233 const int ngrid=fGrid.GetSize();
234
235 for(int j=0;j<numdata;j++)
236 {
237 // Array is supposed to be sorted prior to this call.
238 // If match is found, function returns position of element.
239 // If no match found, function gives nearest element smaller
240 // than value.
241 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), (Double_t)fHadTrue[j]);
242
243 fClass[j] = k;
244 }
245
246 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
247 int min = fClass[minidx];
248 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
249
250 }else{
251 // classes directly given
252 for (Int_t j=0;j<numdata;j++)
253 fClass[j] = TMath::Nint(fHadTrue[j]);
254 }
255}
256
257Double_t MRanForest::CalcHadroness()
258{
259 TVector event;
260 *fRules >> event;
261
262 return CalcHadroness(event);
263}
264
265Double_t MRanForest::CalcHadroness(const TVector &event)
266{
267 fTreeHad.Set(fNumTrees);
268
269 Double_t hadroness=0;
270 Int_t ntree =0;
271
272 TIter Next(fForest);
273
274 MRanTree *tree;
275 while ((tree=(MRanTree*)Next()))
276 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
277
278 return hadroness/ntree;
279}
280
281Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
282{
283 fRanTree = rantree ? rantree : fRanTree;
284
285 if (!fRanTree) return kFALSE;
286
287 MRanTree *newtree=new MRanTree(*fRanTree);
288 fForest->Add(newtree);
289
290 return kTRUE;
291}
292
293Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
294{
295 //-------------------------------------------------------------------
296 // access matrix, copy last column (target) preliminarily
297 // into fHadTrue
298 if (fMatrix)
299 delete fMatrix;
300 fMatrix = new TMatrix(mat->GetM());
301
302 int dim = fMatrix->GetNcols()-1;
303 int numdata = fMatrix->GetNrows();
304
305 fHadTrue.Set(numdata);
306 fHadTrue.Reset();
307
308 for (Int_t j=0;j<numdata;j++)
309 fHadTrue[j] = (*fMatrix)(j,dim);
310
311 // remove last col
312 fMatrix->ResizeTo(numdata,dim);
313
314 //-------------------------------------------------------------------
315 // setup labels for classification/regression
316 fClass.Set(numdata);
317 fClass.Reset();
318
319 if (fClassify)
320 PrepareClasses();
321
322 //-------------------------------------------------------------------
323 // allocating and initializing arrays
324 fHadEst.Set(numdata);
325 fHadEst.Reset();
326
327 fNTimesOutBag.Set(numdata);
328 fNTimesOutBag.Reset();
329
330 fDataSort.Set(dim*numdata);
331 fDataSort.Reset();
332
333 fDataRang.Set(dim*numdata);
334 fDataRang.Reset();
335
336 Bool_t useweights = fWeight.GetSize()==numdata;
337 if (!useweights)
338 {
339 fWeight.Set(numdata);
340 fWeight.Reset(1.);
341 *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
342 }
343
344 //-------------------------------------------------------------------
345 // setup rules to be used for classification/regression
346 const MDataArray *allrules=(MDataArray*)mat->GetColumns();
347 if (allrules==NULL)
348 {
349 *fLog << err <<"Rules of matrix == null, exiting"<< endl;
350 return kFALSE;
351 }
352
353 if (allrules->GetNumEntries()!=dim+1)
354 {
355 *fLog << err <<"Rules of matrix " << allrules->GetNumEntries() << " mismatch dimension+1 " << dim+1 << "...exiting."<< endl;
356 return kFALSE;
357 }
358
359 if (fRules)
360 delete fRules;
361
362 fRules = new MDataArray();
363 fRules->Reset();
364
365 const TString target_rule = (*allrules)[dim].GetRule();
366 for (Int_t i=0;i<dim;i++)
367 fRules->AddEntry((*allrules)[i].GetRule());
368
369 *fLog << inf << endl;
370 *fLog << "Setting up RF for training on target:" << endl;
371 *fLog << " " << target_rule.Data() << endl;
372 *fLog << "Following rules are used as input to RF:" << endl;
373 for (Int_t i=0;i<dim;i++)
374 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
375 *fLog << endl;
376
377 //-------------------------------------------------------------------
378 // prepare (sort) data for fast optimization algorithm
379 if (!CreateDataSort())
380 return kFALSE;
381
382 //-------------------------------------------------------------------
383 // access and init tree container
384 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
385 if(!fRanTree)
386 {
387 *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
388 return kFALSE;
389 }
390 //fRanTree->SetName(target_rule); // Is not stored anyhow
391
392 const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));
393
394 *fLog << inf << endl;
395 *fLog << "Following input for the tree growing are used:"<<endl;
396 *fLog << " Forest type : "<<(fClassify?"classification":"regression")<<endl;
397 *fLog << " Number of Trees : "<<fNumTrees<<endl;
398 *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
399 *fLog << " Final Node size : "<<fNdSize<<endl;
400 *fLog << " Using Grid : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
401 *fLog << " Using Weights : "<<(useweights?"Yes":"No")<<endl;
402 *fLog << " Number of Events: "<<numdata<<endl;
403 *fLog << " Number of Params: "<<dim<<endl;
404
405 if(fNumTry==0)
406 {
407 fNumTry=tryest;
408 *fLog << inf << endl;
409 *fLog << "Set no. of trials to the recommended value of round(";
410 *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
411 }
412
413 fRanTree->SetNumTry(fNumTry);
414 fRanTree->SetClassify(fClassify);
415 fRanTree->SetNdSize(fNdSize);
416
417 fTreeNo=0;
418
419 return kTRUE;
420}
421
422Bool_t MRanForest::GrowForest()
423{
424 if(!gRandom)
425 {
426 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
427 return kFALSE;
428 }
429
430 fTreeNo++;
431
432 //-------------------------------------------------------------------
433 // initialize running output
434
435 float minfloat=TMath::MinElement(fHadTrue.GetSize(),fHadTrue.GetArray());
436 Bool_t calcResolution=(minfloat>FLT_MIN);
437
438 if (fTreeNo==1)
439 {
440 *fLog << inf << endl << underline;
441
442 if(calcResolution)
443 *fLog << "TreeNum BagSize NumNodes TestSize Bias/% var/% res/% (from oob-data)" << endl;
444 else
445 *fLog << "TreeNum BagSize NumNodes TestSize Bias/au var/au rms/au (from oob-data)" << endl;
446 // 12345678901234567890123456789012345678901234567890
447 }
448
449 const Int_t numdata = GetNumData();
450 const Int_t nclass = GetNclass();
451
452 //-------------------------------------------------------------------
453 // bootstrap aggregating (bagging) -> sampling with replacement:
454
455 MArrayF classpopw(nclass);
456 MArrayI jinbag(numdata); // Initialization includes filling with 0
457 MArrayF winbag(numdata); // Initialization includes filling with 0
458
459 float square=0;
460 float mean=0;
461
462 for (Int_t n=0; n<numdata; n++)
463 {
464 // The integer k is randomly (uniformly) chosen from the set
465 // {0,1,...,numdata-1}, which is the set of the index numbers of
466 // all events in the training sample
467
468 const Int_t k = gRandom->Integer(numdata);
469
470 if(fClassify)
471 classpopw[fClass[k]]+=fWeight[k];
472 else
473 classpopw[0]+=fWeight[k];
474
475 mean +=fHadTrue[k]*fWeight[k];
476 square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];
477
478 winbag[k]+=fWeight[k]; // Increase weight if chosen more than once
479 jinbag[k]=1;
480 }
481
482 //-------------------------------------------------------------------
483 // modifying sorted-data array for in-bag data:
484
485 // In bagging procedure ca. 2/3 of all elements in the original
486 // training sample are used to build the in-bag data
487 const MArrayF hadtrue(fHadTrue.GetSize(), fHadTrue.GetArray());
488 const MArrayI fclass(fClass.GetSize(), fClass.GetArray());
489 const MArrayI datarang(fDataRang.GetSize(), fDataRang.GetArray());
490
491 MArrayI datsortinbag(fDataSort.GetSize(), fDataSort.GetArray());
492
493 ModifyDataSort(datsortinbag, jinbag);
494
495 fRanTree->GrowTree(fMatrix,hadtrue,fclass,datsortinbag,datarang,classpopw,mean,square,
496 jinbag,winbag,nclass);
497
498 const Double_t ferr = EstimateError(jinbag, calcResolution);
499
500 fRanTree->SetError(ferr);
501
502 // adding tree to forest
503 AddTree();
504
505 return fTreeNo<fNumTrees;
506}
507
508//-------------------------------------------------------------------
509// error-estimates from out-of-bag data (oob data):
510//
511// For a single tree the events not(!) contained in the bootstrap
512// sample of this tree can be used to obtain estimates for the
513// classification error of this tree.
514// If you take a certain event, it is contained in the oob-data of
515// 1/3 of the trees (see comment to ModifyData). This means that the
516// classification error determined from oob-data is underestimated,
517// but can still be taken as upper limit.
518//
519Double_t MRanForest::EstimateError(const MArrayI &jinbag, Bool_t calcResolution)
520{
521 const Int_t numdata = GetNumData();
522
523 Int_t ninbag = 0;
524 for (Int_t ievt=0;ievt<numdata;ievt++)
525 {
526 if (jinbag[ievt]>0)
527 {
528 ninbag++;
529 continue;
530 }
531
532 fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
533 fNTimesOutBag[ievt]++;
534 }
535
536 Int_t n=0;
537
538 Double_t sum=0;
539 Double_t sq =0;
540 for (Int_t i=0; i<numdata; i++)
541 {
542 if (fNTimesOutBag[i]==0)
543 continue;
544
545 const Float_t hadest = fHadEst[i]/fNTimesOutBag[i];
546
547 const Float_t val = calcResolution ?
548 hadest/fHadTrue[i] - 1 : hadest - fHadTrue[i];
549
550 sum += val;
551 sq += val*val;
552 n++;
553 }
554
555 if (calcResolution)
556 {
557 sum *= 100;
558 sq *= 10000;
559 }
560
561 sum /= n;
562 sq /= n;
563
564 const Double_t var = TMath::Sqrt(sq-sum*sum);
565 const Double_t ferr = TMath::Sqrt(sq);
566
567 //-------------------------------------------------------------------
568 // give running output
569 *fLog << setw(4) << fTreeNo;
570 *fLog << Form(" %8.1f", 100.*ninbag/numdata);
571 *fLog << setw(9) << fRanTree->GetNumEndNodes();
572 *fLog << Form(" %9.1f", 100.*n/numdata);
573 *fLog << Form(" %7.2f", sum);
574 *fLog << Form(" %7.2f", var);
575 *fLog << Form(" %7.2f", ferr);
576 *fLog << endl;
577
578 return ferr;
579}
580
581Bool_t MRanForest::CreateDataSort()
582{
583 // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
584 // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
585 // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
586 //
587 // There may be more then 1 event with rang r (due to bagging).
588
589 const Int_t numdata = GetNumData();
590 const Int_t dim = GetNumDim();
591
592 TArrayF v(numdata);
593 TArrayI isort(numdata);
594
595
596 for (Int_t mvar=0;mvar<dim;mvar++)
597 {
598
599 for(Int_t n=0;n<numdata;n++)
600 {
601 v[n]=(*fMatrix)(n,mvar);
602 //isort[n]=n;
603
604 if (!TMath::Finite(v[n]))
605 {
606 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
607 *fLog << err <<" has a non finite value (eg. NaN)."<<endl;
608 return kFALSE;
609 }
610 }
611
612 TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);
613
614 // this sorts the v[n] in ascending order. isort[n] is the
615 // event number of that v[n], which is the n-th from the
616 // lowest (assume the original event numbers are 0,1,...).
617
618 // control sorting
619 /*
620 for(int n=0;n<numdata-1;n++)
621 if(v[isort[n]]>v[isort[n+1]])
622 {
623 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
624 *fLog << err <<" not at correct sorting position."<<endl;
625 return kFALSE;
626 }
627 */
628
629 // DataRang is similar to the isort index starting from 0 it is
630 // increased by one for each event which is greater, but stays
631 // the same for the same value. (So to say it counts how many
632 // different values we have)
633 for(Int_t n=0;n<numdata-1;n++)
634 {
635 const Int_t n1=isort[n];
636 const Int_t n2=isort[n+1];
637
638 // FIXME: Copying isort[n] to fDataSort[mvar*numdata]
639 // can be accelerated!
640 fDataSort[mvar*numdata+n]=n1;
641 if(n==0) fDataRang[mvar*numdata+n1]=0;
642 if(v[n1]<v[n2])
643 {
644 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
645 }else{
646 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
647 }
648 }
649 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
650 }
651 return kTRUE;
652}
653
654// Reoves all indices which are not in the bag from the datsortinbag
655void MRanForest::ModifyDataSort(MArrayI &datsortinbag, const MArrayI &jinbag)
656{
657 const Int_t numdim=GetNumDim();
658 const Int_t numdata=GetNumData();
659
660 Int_t ninbag=0;
661 for (Int_t n=0;n<numdata;n++)
662 if(jinbag[n]==1) ninbag++;
663
664 for(Int_t m=0;m<numdim;m++)
665 {
666 Int_t *subsort = &datsortinbag[m*numdata];
667
668 Int_t k=0;
669 for(Int_t n=0;n<ninbag;n++)
670 {
671 if(jinbag[subsort[k]]==1)
672 {
673 subsort[n] = subsort[k];
674 k++;
675 }else{
676 for(Int_t j=k+1;j<numdata;j++)
677 {
678 if(jinbag[subsort[j]]==1)
679 {
680 subsort[n] = subsort[j];
681 k = j+1;
682 break;
683 }
684 }
685 }
686 }
687 }
688}
689
690Bool_t MRanForest::AsciiWrite(ostream &out) const
691{
692 Int_t n=0;
693 MRanTree *tree;
694 TIter forest(fForest);
695
696 while ((tree=(MRanTree*)forest.Next()))
697 {
698 tree->AsciiWrite(out);
699 n++;
700 }
701
702 return n==fNumTrees;
703}
Note: See TracBrowser for help on using the repository browser.