source: trunk/Mars/mranforest/MRanForest.cc@ 10083

Last change on this file since 10083 was 8988, checked in by tbretz, 16 years ago
*** empty log message ***
File size: 19.4 KB
Line 
1/* ======================================================================== *\
2! $Name: not supported by cvs2svn $:$Id: MRanForest.cc,v 1.29 2008-06-30 09:36:39 tbretz Exp $
3! --------------------------------------------------------------------------
4!
5! *
6! * This file is part of MARS, the MAGIC Analysis and Reconstruction
7! * Software. It is distributed to you in the hope that it can be a useful
8! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
9! * It is distributed WITHOUT ANY WARRANTY.
10! *
11! * Permission to use, copy, modify and distribute this software and its
12! * documentation for any purpose is hereby granted without fee,
13! * provided that the above copyright notice appear in all copies and
14! * that both that copyright notice and this permission notice appear
15! * in supporting documentation. It is provided "as is" without express
16! * or implied warranty.
17! *
18!
19!
20! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
21! Author(s): Thomas Bretz <mailto:tbretz@astro.uni-wuerzburg.de>
22!
23! Copyright: MAGIC Software Development, 2000-2006
24!
25!
26\* ======================================================================== */
27
28/////////////////////////////////////////////////////////////////////////////
29//
30// MRanForest
31//
32// ParameterContainer for Forest structure
33//
34// A random forest can be grown by calling GrowForest.
35// In advance SetupGrow must be called in order to initialize arrays and
36// do some preprocessing.
37// GrowForest() provides the training data for a single tree (bootstrap
38// aggregate procedure)
39//
40// Essentially two random elements serve to provide a "random" forest,
41// namely bootstrap aggregating (which is done in GrowForest()) and random
42// split selection (which is subject to MRanTree::GrowTree())
43//
44/////////////////////////////////////////////////////////////////////////////
45#include "MRanForest.h"
46
47#include <TMath.h>
48#include <TRandom.h>
49
50#include "MHMatrix.h"
51#include "MRanTree.h"
52#include "MData.h"
53#include "MDataArray.h"
54#include "MParList.h"
55
56#include "MArrayI.h"
57#include "MArrayF.h"
58
59#include "MLog.h"
60#include "MLogManip.h"
61
62ClassImp(MRanForest);
63
64using namespace std;
65
66// --------------------------------------------------------------------------
67//
68// Default constructor.
69//
70MRanForest::MRanForest(const char *name, const char *title)
71 : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
72 fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
73{
74 fName = name ? name : "MRanForest";
75 fTitle = title ? title : "Storage container for Random Forest";
76
77 fForest=new TObjArray();
78 fForest->SetOwner(kTRUE);
79}
80
81MRanForest::MRanForest(const MRanForest &rf)
82{
83 // Copy constructor
84 fName = rf.fName;
85 fTitle = rf.fTitle;
86
87 fClassify = rf.fClassify;
88 fNumTrees = rf.fNumTrees;
89 fNumTry = rf.fNumTry;
90 fNdSize = rf.fNdSize;
91 fTreeNo = rf.fTreeNo;
92 fRanTree = NULL;
93
94 fRules=new MDataArray();
95 fRules->Reset();
96
97 MDataArray *newrules=rf.fRules;
98
99 for(Int_t i=0;i<newrules->GetNumEntries();i++)
100 {
101 MData &data=(*newrules)[i];
102 fRules->AddEntry(data.GetRule());
103 }
104
105 // trees
106 fForest=new TObjArray();
107 fForest->SetOwner(kTRUE);
108
109 TObjArray *newforest=rf.fForest;
110 for(Int_t i=0;i<newforest->GetEntries();i++)
111 {
112 MRanTree *rantree=(MRanTree*)newforest->At(i);
113
114 MRanTree *newtree=new MRanTree(*rantree);
115 fForest->Add(newtree);
116 }
117
118 fHadTrue = rf.fHadTrue;
119 fHadEst = rf.fHadEst;
120 fDataSort = rf.fDataSort;
121 fDataRang = rf.fDataRang;
122 fClassPop = rf.fClassPop;
123 fWeight = rf.fWeight;
124 fTreeHad = rf.fTreeHad;
125
126 fNTimesOutBag = rf.fNTimesOutBag;
127}
128
129// --------------------------------------------------------------------------
130// Destructor.
131MRanForest::~MRanForest()
132{
133 delete fForest;
134 if (fMatrix)
135 delete fMatrix;
136 if (fRules)
137 delete fRules;
138}
139
140void MRanForest::Print(Option_t *o) const
141{
142 *fLog << inf << GetDescriptor() << ": " << endl;
143 MRanTree *t = GetTree(0);
144 if (t)
145 {
146 *fLog << "Setting up RF for training on target:" << endl;
147 *fLog << " " << t->GetTitle() << endl;
148 }
149 if (fRules)
150 {
151 *fLog << "Following rules are used as input to RF:" << endl;
152 for (Int_t i=0;i<fRules->GetNumEntries();i++)
153 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
154 }
155 *fLog << "Random forest parameters:" << endl;
156 if (t)
157 {
158 *fLog << " - " << (t->IsClassify()?"classification":"regression") << " tree" << endl;
159 *fLog << " - Number of trys: " << t->GetNumTry() << endl;
160 *fLog << " - Node size: " << t->GetNdSize() << endl;
161 }
162 *fLog << " - Number of trees: " << fNumTrees << endl;
163 *fLog << " - User value: " << fUserVal << endl;
164 *fLog << endl;
165}
166
167void MRanForest::SetNumTrees(Int_t n)
168{
169 //at least 1 tree
170 fNumTrees=TMath::Max(n,1);
171}
172
173void MRanForest::SetNumTry(Int_t n)
174{
175 fNumTry=TMath::Max(n,0);
176}
177
178void MRanForest::SetNdSize(Int_t n)
179{
180 fNdSize=TMath::Max(n,1);
181}
182
183void MRanForest::SetWeights(const TArrayF &weights)
184{
185 fWeight=weights;
186}
187
188void MRanForest::SetGrid(const TArrayD &grid)
189{
190 const int n=grid.GetSize();
191
192 for(int i=0;i<n-1;i++)
193 if(grid[i]>=grid[i+1])
194 {
195 *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
196 return;
197 }
198
199 fGrid=grid;
200
201 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
202 //for(int i=0;i<n;i++)
203 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
204}
205
206MRanTree *MRanForest::GetTree(Int_t i) const
207{
208 return static_cast<MRanTree*>(fForest->UncheckedAt(i));
209}
210
211Int_t MRanForest::GetNumDim() const
212{
213 return fMatrix ? fMatrix->GetNcols() : 0;
214}
215
216Int_t MRanForest::GetNumData() const
217{
218 return fMatrix ? fMatrix->GetNrows() : 0;
219}
220
221Int_t MRanForest::GetNclass() const
222{
223 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
224
225 return int(fClass[maxidx])+1;
226}
227
228void MRanForest::PrepareClasses()
229{
230 const int numdata=fHadTrue.GetSize();
231
232 if(fGrid.GetSize()>0)
233 {
234 // classes given by grid
235 const int ngrid=fGrid.GetSize();
236
237 for(int j=0;j<numdata;j++)
238 {
239 // Array is supposed to be sorted prior to this call.
240 // If match is found, function returns position of element.
241 // If no match found, function gives nearest element smaller
242 // than value.
243 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), (Double_t)fHadTrue[j]);
244
245 fClass[j] = k;
246 }
247
248 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
249 int min = fClass[minidx];
250 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
251
252 }else{
253 // classes directly given
254 for (Int_t j=0;j<numdata;j++)
255 fClass[j] = TMath::Nint(fHadTrue[j]);
256 }
257}
258
259Double_t MRanForest::CalcHadroness()
260{
261 TVector event;
262 *fRules >> event;
263
264 return CalcHadroness(event);
265}
266
267Double_t MRanForest::CalcHadroness(const TVector &event)
268{
269 fTreeHad.Set(fNumTrees);
270
271 Double_t hadroness=0;
272 Int_t ntree =0;
273
274 TIter Next(fForest);
275
276 MRanTree *tree;
277 while ((tree=(MRanTree*)Next()))
278 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
279
280 return hadroness/ntree;
281}
282
283Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
284{
285 fRanTree = rantree ? rantree : fRanTree;
286
287 if (!fRanTree) return kFALSE;
288
289 MRanTree *newtree=new MRanTree(*fRanTree);
290 fForest->Add(newtree);
291
292 return kTRUE;
293}
294
295Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
296{
297 //-------------------------------------------------------------------
298 // access matrix, copy last column (target) preliminarily
299 // into fHadTrue
300 if (fMatrix)
301 delete fMatrix;
302 fMatrix = new TMatrix(mat->GetM());
303
304 int dim = fMatrix->GetNcols()-1;
305 int numdata = fMatrix->GetNrows();
306
307 fHadTrue.Set(numdata);
308 fHadTrue.Reset();
309
310 for (Int_t j=0;j<numdata;j++)
311 fHadTrue[j] = (*fMatrix)(j,dim);
312
313 // remove last col
314 fMatrix->ResizeTo(numdata,dim);
315
316 //-------------------------------------------------------------------
317 // setup labels for classification/regression
318 fClass.Set(numdata);
319 fClass.Reset();
320
321 if (fClassify)
322 PrepareClasses();
323
324 //-------------------------------------------------------------------
325 // allocating and initializing arrays
326 fHadEst.Set(numdata);
327 fHadEst.Reset();
328
329 fNTimesOutBag.Set(numdata);
330 fNTimesOutBag.Reset();
331
332 fDataSort.Set(dim*numdata);
333 fDataSort.Reset();
334
335 fDataRang.Set(dim*numdata);
336 fDataRang.Reset();
337
338 Bool_t useweights = fWeight.GetSize()==numdata;
339 if (!useweights)
340 {
341 fWeight.Set(numdata);
342 fWeight.Reset(1.);
343 *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
344 }
345
346 //-------------------------------------------------------------------
347 // setup rules to be used for classification/regression
348 const MDataArray *allrules=(MDataArray*)mat->GetColumns();
349 if (allrules==NULL)
350 {
351 *fLog << err <<"Rules of matrix == null, exiting"<< endl;
352 return kFALSE;
353 }
354
355 if (allrules->GetNumEntries()!=dim+1)
356 {
357 *fLog << err <<"Rules of matrix " << allrules->GetNumEntries() << " mismatch dimension+1 " << dim+1 << "...exiting."<< endl;
358 return kFALSE;
359 }
360
361 if (fRules)
362 delete fRules;
363
364 fRules = new MDataArray();
365 fRules->Reset();
366
367 const TString target_rule = (*allrules)[dim].GetRule();
368 for (Int_t i=0;i<dim;i++)
369 fRules->AddEntry((*allrules)[i].GetRule());
370
371 *fLog << inf << endl;
372 *fLog << "Setting up RF for training on target:" << endl;
373 *fLog << " " << target_rule.Data() << endl;
374 *fLog << "Following rules are used as input to RF:" << endl;
375 for (Int_t i=0;i<dim;i++)
376 *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
377 *fLog << endl;
378
379 //-------------------------------------------------------------------
380 // prepare (sort) data for fast optimization algorithm
381 if (!CreateDataSort())
382 return kFALSE;
383
384 //-------------------------------------------------------------------
385 // access and init tree container
386 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
387 if(!fRanTree)
388 {
389 *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
390 return kFALSE;
391 }
392 //fRanTree->SetName(target_rule); // Is not stored anyhow
393
394 const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));
395
396 *fLog << inf << endl;
397 *fLog << "Following input for the tree growing are used:"<<endl;
398 *fLog << " Forest type : "<<(fClassify?"classification":"regression")<<endl;
399 *fLog << " Number of Trees : "<<fNumTrees<<endl;
400 *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
401 *fLog << " Final Node size : "<<fNdSize<<endl;
402 *fLog << " Using Grid : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
403 *fLog << " Using Weights : "<<(useweights?"Yes":"No")<<endl;
404 *fLog << " Number of Events: "<<numdata<<endl;
405 *fLog << " Number of Params: "<<dim<<endl;
406
407 if(fNumTry==0)
408 {
409 fNumTry=tryest;
410 *fLog << inf << endl;
411 *fLog << "Set no. of trials to the recommended value of round(";
412 *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
413 }
414
415 fRanTree->SetNumTry(fNumTry);
416 fRanTree->SetClassify(fClassify);
417 fRanTree->SetNdSize(fNdSize);
418
419 fTreeNo=0;
420
421 return kTRUE;
422}
423
424Bool_t MRanForest::GrowForest()
425{
426 if(!gRandom)
427 {
428 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
429 return kFALSE;
430 }
431
432 fTreeNo++;
433
434 //-------------------------------------------------------------------
435 // initialize running output
436
437 float minfloat=TMath::MinElement(fHadTrue.GetSize(),fHadTrue.GetArray());
438 Bool_t calcResolution=(minfloat>FLT_MIN);
439
440 if (fTreeNo==1)
441 {
442 *fLog << inf << endl << underline;
443
444 if(calcResolution)
445 *fLog << "TreeNum BagSize NumNodes TestSize Bias/% var/% res/% (from oob-data)" << endl;
446 else
447 *fLog << "TreeNum BagSize NumNodes TestSize Bias/au var/au rms/au (from oob-data)" << endl;
448 // 12345678901234567890123456789012345678901234567890
449 }
450
451 const Int_t numdata = GetNumData();
452 const Int_t nclass = GetNclass();
453
454 //-------------------------------------------------------------------
455 // bootstrap aggregating (bagging) -> sampling with replacement:
456
457 MArrayF classpopw(nclass);
458 MArrayI jinbag(numdata); // Initialization includes filling with 0
459 MArrayF winbag(numdata); // Initialization includes filling with 0
460
461 float square=0;
462 float mean=0;
463
464 for (Int_t n=0; n<numdata; n++)
465 {
466 // The integer k is randomly (uniformly) chosen from the set
467 // {0,1,...,numdata-1}, which is the set of the index numbers of
468 // all events in the training sample
469
470 const Int_t k = gRandom->Integer(numdata);
471
472 if(fClassify)
473 classpopw[fClass[k]]+=fWeight[k];
474 else
475 classpopw[0]+=fWeight[k];
476
477 mean +=fHadTrue[k]*fWeight[k];
478 square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];
479
480 winbag[k]+=fWeight[k]; // Increase weight if chosen more than once
481 jinbag[k]=1;
482 }
483
484 //-------------------------------------------------------------------
485 // modifying sorted-data array for in-bag data:
486
487 // In bagging procedure ca. 2/3 of all elements in the original
488 // training sample are used to build the in-bag data
489 const MArrayF hadtrue(fHadTrue.GetSize(), fHadTrue.GetArray());
490 const MArrayI fclass(fClass.GetSize(), fClass.GetArray());
491 const MArrayI datarang(fDataRang.GetSize(), fDataRang.GetArray());
492
493 MArrayI datsortinbag(fDataSort.GetSize(), fDataSort.GetArray());
494
495 ModifyDataSort(datsortinbag, jinbag);
496
497 fRanTree->GrowTree(fMatrix,hadtrue,fclass,datsortinbag,datarang,classpopw,mean,square,
498 jinbag,winbag,nclass);
499
500 const Double_t ferr = EstimateError(jinbag, calcResolution);
501
502 fRanTree->SetError(ferr);
503
504 // adding tree to forest
505 AddTree();
506
507 return fTreeNo<fNumTrees;
508}
509
510//-------------------------------------------------------------------
511// error-estimates from out-of-bag data (oob data):
512//
513// For a single tree the events not(!) contained in the bootstrap
514// sample of this tree can be used to obtain estimates for the
515// classification error of this tree.
516// If you take a certain event, it is contained in the oob-data of
517// 1/3 of the trees (see comment to ModifyData). This means that the
518// classification error determined from oob-data is underestimated,
519// but can still be taken as upper limit.
520//
521Double_t MRanForest::EstimateError(const MArrayI &jinbag, Bool_t calcResolution)
522{
523 const Int_t numdata = GetNumData();
524
525 Int_t ninbag = 0;
526 for (Int_t ievt=0;ievt<numdata;ievt++)
527 {
528 if (jinbag[ievt]>0)
529 {
530 ninbag++;
531 continue;
532 }
533
534 fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
535 fNTimesOutBag[ievt]++;
536 }
537
538 Int_t n=0;
539
540 Double_t sum=0;
541 Double_t sq =0;
542 for (Int_t i=0; i<numdata; i++)
543 {
544 if (fNTimesOutBag[i]==0)
545 continue;
546
547 const Float_t hadest = fHadEst[i]/fNTimesOutBag[i];
548
549 const Float_t val = calcResolution ?
550 hadest/fHadTrue[i] - 1 : hadest - fHadTrue[i];
551
552 sum += val;
553 sq += val*val;
554 n++;
555 }
556
557 if (calcResolution)
558 {
559 sum *= 100;
560 sq *= 10000;
561 }
562
563 sum /= n;
564 sq /= n;
565
566 const Double_t var = TMath::Sqrt(sq-sum*sum);
567 const Double_t ferr = TMath::Sqrt(sq);
568
569 //-------------------------------------------------------------------
570 // give running output
571 *fLog << setw(4) << fTreeNo;
572 *fLog << Form(" %8.1f", 100.*ninbag/numdata);
573 *fLog << setw(9) << fRanTree->GetNumEndNodes();
574 *fLog << Form(" %9.1f", 100.*n/numdata);
575 *fLog << Form(" %7.2f", sum);
576 *fLog << Form(" %7.2f", var);
577 *fLog << Form(" %7.2f", ferr);
578 *fLog << endl;
579
580 return ferr;
581}
582
583Bool_t MRanForest::CreateDataSort()
584{
585 // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
586 // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
587 // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
588 //
589 // There may be more then 1 event with rang r (due to bagging).
590
591 const Int_t numdata = GetNumData();
592 const Int_t dim = GetNumDim();
593
594 TArrayF v(numdata);
595 TArrayI isort(numdata);
596
597
598 for (Int_t mvar=0;mvar<dim;mvar++)
599 {
600
601 for(Int_t n=0;n<numdata;n++)
602 {
603 v[n]=(*fMatrix)(n,mvar);
604 //isort[n]=n;
605
606 if (!TMath::Finite(v[n]))
607 {
608 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
609 *fLog << err <<" has a non finite value (eg. NaN)."<<endl;
610 return kFALSE;
611 }
612 }
613
614 TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);
615
616 // this sorts the v[n] in ascending order. isort[n] is the
617 // event number of that v[n], which is the n-th from the
618 // lowest (assume the original event numbers are 0,1,...).
619
620 // control sorting
621 /*
622 for(int n=0;n<numdata-1;n++)
623 if(v[isort[n]]>v[isort[n+1]])
624 {
625 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
626 *fLog << err <<" not at correct sorting position."<<endl;
627 return kFALSE;
628 }
629 */
630
631 // DataRang is similar to the isort index starting from 0 it is
632 // increased by one for each event which is greater, but stays
633 // the same for the same value. (So to say it counts how many
634 // different values we have)
635 for(Int_t n=0;n<numdata-1;n++)
636 {
637 const Int_t n1=isort[n];
638 const Int_t n2=isort[n+1];
639
640 // FIXME: Copying isort[n] to fDataSort[mvar*numdata]
641 // can be accelerated!
642 fDataSort[mvar*numdata+n]=n1;
643 if(n==0) fDataRang[mvar*numdata+n1]=0;
644 if(v[n1]<v[n2])
645 {
646 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
647 }else{
648 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
649 }
650 }
651 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
652 }
653 return kTRUE;
654}
655
656// Reoves all indices which are not in the bag from the datsortinbag
657void MRanForest::ModifyDataSort(MArrayI &datsortinbag, const MArrayI &jinbag)
658{
659 const Int_t numdim=GetNumDim();
660 const Int_t numdata=GetNumData();
661
662 Int_t ninbag=0;
663 for (Int_t n=0;n<numdata;n++)
664 if(jinbag[n]==1) ninbag++;
665
666 for(Int_t m=0;m<numdim;m++)
667 {
668 Int_t *subsort = &datsortinbag[m*numdata];
669
670 Int_t k=0;
671 for(Int_t n=0;n<ninbag;n++)
672 {
673 if(jinbag[subsort[k]]==1)
674 {
675 subsort[n] = subsort[k];
676 k++;
677 }else{
678 for(Int_t j=k+1;j<numdata;j++)
679 {
680 if(jinbag[subsort[j]]==1)
681 {
682 subsort[n] = subsort[j];
683 k = j+1;
684 break;
685 }
686 }
687 }
688 }
689 }
690}
691
692Bool_t MRanForest::AsciiWrite(ostream &out) const
693{
694 Int_t n=0;
695 MRanTree *tree;
696 TIter forest(fForest);
697
698 while ((tree=(MRanTree*)forest.Next()))
699 {
700 tree->AsciiWrite(out);
701 n++;
702 }
703
704 return n==fNumTrees;
705}
Note: See TracBrowser for help on using the repository browser.