source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7685

Last change on this file since 7685 was 7685, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 17.9 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19!
20! Copyright: MAGIC Software Development, 2000-2005
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
31// A random forest can be grown by calling GrowForest.
32// In advance SetupGrow must be called in order to initialize arrays and
33// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
37// Essentially two random elements serve to provide a "random" forest,
38// namely bootstrap aggregating (which is done in GrowForest()) and random
39// split selection (which is subject to MRanTree::GrowTree())
40//
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TVector.h>
45#include <TRandom.h>
46
47#include "MHMatrix.h"
48#include "MRanTree.h"
49#include "MData.h"
50#include "MDataArray.h"
51#include "MParList.h"
52
53#include "MLog.h"
54#include "MLogManip.h"
55
56ClassImp(MRanForest);
57
58using namespace std;
59
60// --------------------------------------------------------------------------
61//
62// Default constructor.
63//
64MRanForest::MRanForest(const char *name, const char *title)
65 : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
66 fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
67{
68 fName = name ? name : "MRanForest";
69 fTitle = title ? title : "Storage container for Random Forest";
70
71 fForest=new TObjArray();
72 fForest->SetOwner(kTRUE);
73}
74
75MRanForest::MRanForest(const MRanForest &rf)
76{
77 // Copy constructor
78 fName = rf.fName;
79 fTitle = rf.fTitle;
80
81 fClassify = rf.fClassify;
82 fNumTrees = rf.fNumTrees;
83 fNumTry = rf.fNumTry;
84 fNdSize = rf.fNdSize;
85 fTreeNo = rf.fTreeNo;
86 fRanTree = NULL;
87
88 fRules=new MDataArray();
89 fRules->Reset();
90
91 MDataArray *newrules=rf.fRules;
92
93 for(Int_t i=0;i<newrules->GetNumEntries();i++)
94 {
95 MData &data=(*newrules)[i];
96 fRules->AddEntry(data.GetRule());
97 }
98
99 // trees
100 fForest=new TObjArray();
101 fForest->SetOwner(kTRUE);
102
103 TObjArray *newforest=rf.fForest;
104 for(Int_t i=0;i<newforest->GetEntries();i++)
105 {
106 MRanTree *rantree=(MRanTree*)newforest->At(i);
107
108 MRanTree *newtree=new MRanTree(*rantree);
109 fForest->Add(newtree);
110 }
111
112 fHadTrue = rf.fHadTrue;
113 fHadEst = rf.fHadEst;
114 fDataSort = rf.fDataSort;
115 fDataRang = rf.fDataRang;
116 fClassPop = rf.fClassPop;
117 fWeight = rf.fWeight;
118 fTreeHad = rf.fTreeHad;
119
120 fNTimesOutBag = rf.fNTimesOutBag;
121
122}
123
124// --------------------------------------------------------------------------
125// Destructor.
126MRanForest::~MRanForest()
127{
128 delete fForest;
129 if (fMatrix)
130 delete fMatrix;
131 if (fRules)
132 delete fRules;
133}
134
// --------------------------------------------------------------------------
//
// Print the setup of the forest to the log: the training target (taken
// from the title of the first tree, if any), the input rules, and the
// forest parameters.
//
void MRanForest::Print(Option_t *o) const
{
    *fLog << inf << GetDescriptor() << ": " << endl;
    // The first tree carries the training target in its title
    MRanTree *t = GetTree(0);
    if (t)
    {
        *fLog << "Setting up RF for training on target:" << endl;
        *fLog << " " << t->GetTitle() << endl;
    }
    if (fRules)
    {
        *fLog << "Following rules are used as input to RF:" << endl;
        for (Int_t i=0;i<fRules->GetNumEntries();i++)
            *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
    }
    *fLog << "Random forest parameters:" << endl;
    if (t)
    {
        // Per-tree settings are read back from the first tree
        *fLog << " - " << (t->IsClassify()?"classification":"regression") << " tree" << endl;
        *fLog << " - Number of trys: " << t->GetNumTry() << endl;
        *fLog << " - Node size: " << t->GetNdSize() << endl;
    }
    *fLog << " - Number of trees: " << fNumTrees << endl;
    *fLog << " - User value: " << fUserVal << endl;
    *fLog << endl;
}
161
162void MRanForest::SetNumTrees(Int_t n)
163{
164 //at least 1 tree
165 fNumTrees=TMath::Max(n,1);
166}
167
168void MRanForest::SetNumTry(Int_t n)
169{
170 fNumTry=TMath::Max(n,0);
171}
172
173void MRanForest::SetNdSize(Int_t n)
174{
175 fNdSize=TMath::Max(n,1);
176}
177
178void MRanForest::SetWeights(const TArrayF &weights)
179{
180 fWeight=weights;
181}
182
183void MRanForest::SetGrid(const TArrayD &grid)
184{
185 const int n=grid.GetSize();
186
187 for(int i=0;i<n-1;i++)
188 if(grid[i]>=grid[i+1])
189 {
190 *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
191 return;
192 }
193
194 fGrid=grid;
195
196 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
197 //for(int i=0;i<n;i++)
198 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
199}
200
201MRanTree *MRanForest::GetTree(Int_t i) const
202{
203 return static_cast<MRanTree*>(fForest->UncheckedAt(i));
204}
205
206Int_t MRanForest::GetNumDim() const
207{
208 return fMatrix ? fMatrix->GetNcols() : 0;
209}
210
211Int_t MRanForest::GetNumData() const
212{
213 return fMatrix ? fMatrix->GetNrows() : 0;
214}
215
216Int_t MRanForest::GetNclass() const
217{
218 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
219
220 return int(fClass[maxidx])+1;
221}
222
// --------------------------------------------------------------------------
//
// Translate the training target fHadTrue into integer class labels fClass:
// either by binning with the user-defined grid (if one was set via
// SetGrid), or by rounding the target values directly to the nearest
// integer. With a grid, the labels are afterwards shifted so that the
// smallest label is 0.
//
void MRanForest::PrepareClasses()
{
    const int numdata=fHadTrue.GetSize();

    if(fGrid.GetSize()>0)
    {
        // classes given by grid
        const int ngrid=fGrid.GetSize();

        for(int j=0;j<numdata;j++)
        {
            // Array is supposed to be sorted prior to this call.
            // If match is found, function returns position of element.
            // If no match found, function gives nearest element smaller
            // than value.
            // NOTE: values below the first grid point yield -1 here;
            // the shift below maps the smallest label back to 0.
            int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);

            fClass[j] = k;
        }

        // Normalize labels so that the smallest class index is 0
        int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
        int min = fClass[minidx];
        if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;

    }else{
        // classes directly given
        for (Int_t j=0;j<numdata;j++)
            fClass[j] = TMath::Nint(fHadTrue[j]);
    }
}
253
254Double_t MRanForest::CalcHadroness()
255{
256 TVector event;
257 *fRules >> event;
258
259 return CalcHadroness(event);
260}
261
262Double_t MRanForest::CalcHadroness(const TVector &event)
263{
264 fTreeHad.Set(fNumTrees);
265
266 Double_t hadroness=0;
267 Int_t ntree =0;
268
269 TIter Next(fForest);
270
271 MRanTree *tree;
272 while ((tree=(MRanTree*)Next()))
273 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
274
275 return hadroness/ntree;
276}
277
278Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
279{
280 fRanTree = rantree ? rantree : fRanTree;
281
282 if (!fRanTree) return kFALSE;
283
284 MRanTree *newtree=new MRanTree(*fRanTree);
285 fForest->Add(newtree);
286
287 return kTRUE;
288}
289
// --------------------------------------------------------------------------
//
// Initialize the forest for training: copy the training matrix from mat
// (its last column is the training target), set up class labels and all
// bookkeeping arrays, copy the input rules, pre-sort the data for the
// split optimization, and fetch/configure the MRanTree from the parameter
// list. Must be called once before the GrowForest() loop.
//
// Returns kFALSE (and logs an error) if the matrix has no rules, the rule
// count does not match the matrix dimension, the data contains NaNs, or
// no MRanTree could be created in the parameter list.
//
Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    if (fMatrix)
        delete fMatrix;
    fMatrix = new TMatrix(mat->GetM());

    int dim = fMatrix->GetNcols()-1;
    int numdata = fMatrix->GetNrows();

    fHadTrue.Set(numdata);
    fHadTrue.Reset(0);

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim);

    // remove last col
    fMatrix->ResizeTo(numdata,dim);

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset(0);

    if (fClassify)
        PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata);
    fHadEst.Reset(0);

    fNTimesOutBag.Set(numdata);
    fNTimesOutBag.Reset(0);

    // fDataSort/fDataRang are flat dim x numdata arrays, filled in
    // CreateDataSort below
    fDataSort.Set(dim*numdata);
    fDataSort.Reset(0);

    fDataRang.Set(dim*numdata);
    fDataRang.Reset(0);

    // User weights are only used if one weight per event was supplied;
    // otherwise fall back to uniform weights of 1
    Bool_t useweights = fWeight.GetSize()==numdata;
    if (!useweights)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    const MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if (allrules==NULL)
    {
        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
        return kFALSE;
    }

    if (allrules->GetNumEntries()!=dim+1)
    {
        *fLog << err <<"Rules of matrix " << allrules->GetNumEntries() << " mismatch dimension+1 " << dim+1 << "...exiting."<< endl;
        return kFALSE;
    }

    if (fRules)
        delete fRules;

    fRules = new MDataArray();
    fRules->Reset();

    // The last rule describes the training target; the first dim rules
    // are the inputs and are copied into fRules
    const TString target_rule = (*allrules)[dim].GetRule();
    for (Int_t i=0;i<dim;i++)
        fRules->AddEntry((*allrules)[i].GetRule());

    *fLog << inf << endl;
    *fLog << "Setting up RF for training on target:" << endl;
    *fLog << " " << target_rule.Data() << endl;
    *fLog << "Following rules are used as input to RF:" << endl;
    for (Int_t i=0;i<dim;i++)
        *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;
    *fLog << endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if (!CreateDataSort())
        return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }
    //fRanTree->SetName(target_rule); // Is not stored anyhow

    // Recommended number of trials per split: sqrt(number of inputs)
    const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));

    *fLog << inf << endl;
    *fLog << "Following input for the tree growing are used:"<<endl;
    *fLog << " Forest type : "<<(fClassify?"classification":"regression")<<endl;
    *fLog << " Number of Trees : "<<fNumTrees<<endl;
    *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
    *fLog << " Final Node size : "<<fNdSize<<endl;
    *fLog << " Using Grid : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
    *fLog << " Using Weights : "<<(useweights?"Yes":"No")<<endl;
    *fLog << " Number of Events: "<<numdata<<endl;
    *fLog << " Number of Params: "<<dim<<endl;

    // fNumTry==0 means "auto": use the recommended estimate
    if(fNumTry==0)
    {
        fNumTry=tryest;
        *fLog << inf << endl;
        *fLog << "Set no. of trials to the recommended value of round(";
        *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
    }

    fRanTree->SetNumTry(fNumTry);
    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    // Reset the tree counter used by GrowForest
    fTreeNo=0;

    return kTRUE;
}
418
// --------------------------------------------------------------------------
//
// Grow a single tree of the forest: draw a bootstrap sample of the
// training data (bagging), grow the tree on the in-bag events, update the
// out-of-bag (oob) error estimate, print a running status line and add
// the tree to the forest. SetupGrow must have been called before.
//
// Returns kTRUE as long as more trees remain to be grown (call in a loop
// until it returns kFALSE).
//
Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    // If all target values are clearly positive, a relative resolution
    // can be quoted; otherwise fall back to the rms
    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
    Bool_t calcResolution=(minfloat>0.001);

    if (fTreeNo==1)
    {
        *fLog << inf << endl << underline;

        if(calcResolution)
            *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl;
        else
            *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl;
        // 12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    TArrayF classpopw(nclass);
    TArrayI jinbag(numdata); // Initialization includes filling with 0
    TArrayF winbag(numdata); // Initialization includes filling with 0

    float square=0;
    float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = Int_t(gRandom->Rndm()*numdata);

        // Accumulate the weighted class population (one bin only in
        // regression mode)
        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        // Weighted first and second moments of the target
        mean +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        // winbag counts (weighted) how often event k was drawn;
        // jinbag flags it as in-bag
        winbag[k]+=fWeight[k];
        jinbag[k]=1;

    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    //-------------------------------------------------------------------
    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample of
    // this tree can be used to obtain estimates for the classification error of
    // this tree.
    // If you take a certain event, it is contained in the oob-data of 1/3 of
    // the trees (see comment to ModifyData). This means that the classification error
    // determined from oob-data is underestimated, but can still be taken as upper limit.

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0)
            continue;

        // Accumulate this tree's response for its oob events
        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;

    }

    // Average oob residual over all events seen oob at least once
    Int_t n=0;
    Float_t ferr=0;

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if(fNTimesOutBag[ievt]!=0)
        {
            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
            if(calcResolution) val/=fHadTrue[ievt];

            ferr += val*val;
            n++;
        }
    }
    // NOTE(review): if no event is oob yet (n==0) this divides by zero --
    // extremely unlikely after one bagging pass, but worth confirming.
    ferr = TMath::Sqrt(ferr/n);

    //-------------------------------------------------------------------
    // give running output
    *fLog << setw(5) << fTreeNo;
    *fLog << setw(18) << fRanTree->GetNumEndNodes();
    *fLog << Form("%18.2f", ferr*100.);
    *fLog << endl;

    fRanTree->SetError(ferr);

    // adding tree to forest
    AddTree();

    return fTreeNo<fNumTrees;
}
543
// --------------------------------------------------------------------------
//
// Pre-sort the training data column by column to speed up the split
// optimization in the tree growing.
//
// fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
// fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
// fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
//
// There may be more then 1 event with rang r (due to bagging).
//
// Returns kFALSE (with an error message) if a NaN is found in the data
// or the sort-consistency check fails.
//
Bool_t MRanForest::CreateDataSort()
{
    const Int_t numdata = GetNumData();
    const Int_t dim = GetNumDim();

    // Scratch arrays reused for every column
    TArrayF v(numdata);
    TArrayI isort(numdata);


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        // Copy column mvar and reject NaNs
        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            isort[n]=n;

            if(TMath::IsNaN(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has the value NaN."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        // control sorting
        for(int n=1;n<numdata;n++)
            if(v[isort[n-1]]>v[isort[n]])
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" not at correct sorting position."<<endl;
                return kFALSE;
            }

        // Fill the flat dim x numdata sort/rang arrays; equal values get
        // the same rang, strictly larger values increment it
        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        // The loop above leaves the last sort entry unset; fix it here
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}
608
// --------------------------------------------------------------------------
//
// Compress the pre-sorted index array to the in-bag events only: for each
// variable, the first ninbag entries of datsortinbag are overwritten with
// the (still sorted) indices of events flagged in jinbag; entries beyond
// ninbag are left untouched.
//
// NOTE(review): ninbag is passed by value, so the count computed here is
// not visible to the caller -- GrowForest ignores it, but confirm this
// was intended (a reference parameter would communicate it back).
//
void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    const Int_t numdim=GetNumDim();
    const Int_t numdata=GetNumData();

    // Count the distinct in-bag events
    ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        // k scans the original sorted order, nt writes the compacted one
        Int_t k=0;
        Int_t nt=0;
        for(Int_t n=0;n<numdata;n++)
        {
            if(jinbag[datsortinbag[m*numdata+k]]==1)
            {
                datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
                k++;
            }else{
                // Skip ahead to the next in-bag event
                for(Int_t j=1;j<numdata-k;j++)
                {
                    if(jinbag[datsortinbag[m*numdata+k+j]]==1)
                    {
                        datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            if(nt>=ninbag) break;
        }
    }
}
644
645Bool_t MRanForest::AsciiWrite(ostream &out) const
646{
647 Int_t n=0;
648 MRanTree *tree;
649 TIter forest(fForest);
650
651 while ((tree=(MRanTree*)forest.Next()))
652 {
653 tree->AsciiWrite(out);
654 n++;
655 }
656
657 return n==fNumTrees;
658}
Note: See TracBrowser for help on using the repository browser.