source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7419

/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
!
!   Copyright: MAGIC Software Development, 2000-2005
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//
// MRanForest
//
// Parameter container for the forest structure
//
// A random forest is grown by calling GrowForest() repeatedly, once per
// tree. Beforehand, SetupGrow() must be called to initialize the arrays
// and do some preprocessing. Each call to GrowForest() provides the
// training data for a single tree (bootstrap aggregating).
//
// Essentially two random elements make the forest "random": bootstrap
// aggregating (done in GrowForest()) and random split selection (done in
// MRanTree::GrowTree()).
//
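// A minimal usage sketch (an illustration only; the object names and the
// exact setup of the parameter list and matrix are assumptions, not part
// of this class):
//
//    MParList   plist;
//    MHMatrix   mat;                     // last column holds the target value
//    MRanForest forest;
//    forest.SetNumTrees(100);
//    forest.SetNdSize(5);
//    forest.SetupGrow(&mat, &plist);     // preprocessing, sorting, rules
//    while (forest.GrowForest());        // grows one tree per call
//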
/////////////////////////////////////////////////////////////////////////////
#include "MRanForest.h"

#include <TVector.h>
#include <TRandom.h>

#include "MHMatrix.h"
#include "MRanTree.h"
#include "MData.h"
#include "MDataArray.h"
#include "MParList.h"

#include "MLog.h"
#include "MLogManip.h"

ClassImp(MRanForest);

using namespace std;
// --------------------------------------------------------------------------
//
// Default constructor.
//
MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1)
{
    fName  = name  ? name  : "MRanForest";
    fTitle = title ? title : "Storage container for Random Forest";

    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);
}

MRanForest::MRanForest(const MRanForest &rf)
{
    // Copy constructor
    fName  = rf.fName;
    fTitle = rf.fTitle;

    fClassify = rf.fClassify;
    fNumTrees = rf.fNumTrees;
    fNumTry   = rf.fNumTry;
    fNdSize   = rf.fNdSize;
    fTreeNo   = rf.fTreeNo;
    fRanTree  = NULL;

    fRules=new MDataArray();
    fRules->Reset();

    MDataArray *newrules=rf.fRules;

    for(Int_t i=0;i<newrules->GetNumEntries();i++)
    {
        MData &data=(*newrules)[i];
        fRules->AddEntry(data.GetRule());
    }

    // trees
    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);

    TObjArray *newforest=rf.fForest;
    for(Int_t i=0;i<newforest->GetEntries();i++)
    {
        MRanTree *rantree=(MRanTree*)newforest->At(i);

        MRanTree *newtree=new MRanTree(*rantree);
        fForest->Add(newtree);
    }

    fHadTrue  = rf.fHadTrue;
    fHadEst   = rf.fHadEst;
    fDataSort = rf.fDataSort;
    fDataRang = rf.fDataRang;
    fClassPop = rf.fClassPop;
    fWeight   = rf.fWeight;
    fTreeHad  = rf.fTreeHad;

    fNTimesOutBag = rf.fNTimesOutBag;
}

// --------------------------------------------------------------------------
// Destructor.
MRanForest::~MRanForest()
{
    delete fForest;
}

MRanTree *MRanForest::GetTree(Int_t i)
{
    return (MRanTree*)(fForest->At(i));
}

void MRanForest::SetNumTrees(Int_t n)
{
    //at least 1 tree
    fNumTrees=TMath::Max(n,1);
    fTreeHad.Set(fNumTrees);
    fTreeHad.Reset();
}

void MRanForest::SetNumTry(Int_t n)
{
    fNumTry=TMath::Max(n,0);
}

void MRanForest::SetNdSize(Int_t n)
{
    fNdSize=TMath::Max(n,1);
}

void MRanForest::SetWeights(const TArrayF &weights)
{
    const int n=weights.GetSize();
    fWeight.Set(n);
    fWeight=weights;

    return;
}

void MRanForest::SetGrid(const TArrayD &grid)
{
    const int n=grid.GetSize();

    for(int i=0;i<n-1;i++)
        if(grid[i]>=grid[i+1])
        {
            *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
            return;
        }

    fGrid=grid;

    //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
    //for(int i=0;i<n;i++)
    //    *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;

    return;
}

Int_t MRanForest::GetNumDim() const
{
    return fMatrix ? fMatrix->GetNcols() : 0;
}

Int_t MRanForest::GetNumData() const
{
    return fMatrix ? fMatrix->GetNrows() : 0;
}

Int_t MRanForest::GetNclass() const
{
    int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());

    return int(fClass[maxidx])+1;
}

void MRanForest::PrepareClasses()
{
    const int numdata=fHadTrue.GetSize();

    if(fGrid.GetSize()>0)
    {
        // classes given by grid
        const int ngrid=fGrid.GetSize();

        for(int j=0;j<numdata;j++)
        {
            // Array is supposed to be sorted prior to this call.
            // If match is found, function returns position of element.
            // If no match found, function gives nearest element smaller
            // than value.
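            //
            // Illustration (example numbers only): with a grid {0.1, 0.5, 1.0}
            // a target value of 0.7 yields k=1, a value of 0.05 yields k=-1
            // (smaller than all grid points; the shift to non-negative classes
            // below takes care of this), and a value of 1.3 yields k=2.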
            int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);

            fClass[j] = k;
        }

        int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
        int min    = fClass[minidx];
        if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;

    }else{
        // classes directly given
        for (Int_t j=0;j<numdata;j++)
            fClass[j] = int(fHadTrue[j]+0.5);
    }
}

/*
Bool_t MRanForest::PreProcess(MParList *plist)
{
    if (!fRules)
    {
        *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl;
        return kFALSE;
    }

    if (!fRules->PreProcess(plist))
    {
        *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl;
        return kFALSE;
    }

    return kTRUE;
}
*/

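// --------------------------------------------------------------------------
//
// Evaluate the forest for one event: the hadroness is the mean of the
// outputs of all trees in the forest (the per-tree values are cached in
// fTreeHad). The parameterless version first fills the event vector from
// the current values of fRules.
//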
Double_t MRanForest::CalcHadroness()
{
    TVector event;
    *fRules >> event;

    return CalcHadroness(event);
}

Double_t MRanForest::CalcHadroness(const TVector &event)
{
    Double_t hadroness=0;
    Int_t    ntree=0;

    TIter Next(fForest);

    MRanTree *tree;
    while ((tree=(MRanTree*)Next()))
        hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));

    return hadroness/ntree;
}

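// --------------------------------------------------------------------------
//
// Add a copy of rantree to the forest. If rantree is NULL, the tree stored
// in fRanTree (e.g. the one grown last by GrowForest) is copied instead.
// Returns kFALSE if no tree is available.
//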
Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
{
    fRanTree = rantree ? rantree : fRanTree;

    if (!fRanTree) return kFALSE;

    MRanTree *newtree=new MRanTree(*fRanTree);
    fForest->Add(newtree);

    return kTRUE;
}

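// --------------------------------------------------------------------------
//
// Set up the forest for training: copy the matrix from mat, take its last
// column as the training target (fHadTrue) and the remaining columns as
// input rules, prepare class labels, weights and the sorted-data arrays,
// and get the MRanTree container from the parameter list. Must be called
// before GrowForest().
//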
Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    fMatrix = new TMatrix(mat->GetM());

    int dim     = fMatrix->GetNcols()-1;
    int numdata = fMatrix->GetNrows();

    fHadTrue.Set(numdata);
    fHadTrue.Reset(0);

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim);

    // remove last col
    fMatrix->ResizeTo(numdata,dim);

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset(0);

    if (fClassify)
        PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata);
    fHadEst.Reset(0);

    fNTimesOutBag.Set(numdata);
    fNTimesOutBag.Reset(0);

    fDataSort.Set(dim*numdata);
    fDataSort.Reset(0);

    fDataRang.Set(dim*numdata);
    fDataRang.Reset(0);

    if(fWeight.GetSize()!=numdata)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    const MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if(allrules==NULL)
    {
        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
        return kFALSE;
    }

    fRules = new MDataArray();
    fRules->Reset();

    const TString target_rule = (*allrules)[dim].GetRule();
    for (Int_t i=0;i<dim;i++)
        fRules->AddEntry((*allrules)[i].GetRule());

    *fLog << inf << endl;
    *fLog << "Setting up RF for training on target:" << endl;
    *fLog << " " << target_rule.Data() << endl;
    *fLog << "Following rules are used as input to RF:" << endl;
    for (Int_t i=0;i<dim;i++)
        *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;

    *fLog << endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if (!CreateDataSort())
        return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }

    const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));

    *fLog << inf << endl;
    *fLog << "Following inputs are used for the tree growing:"<<endl;
    *fLog << " Number of Trees : "<<fNumTrees<<endl;
    *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
    *fLog << " Final Node size : "<<fNdSize<<endl;
    *fLog << " Using Grid      : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
    *fLog << " Number of Events: "<<numdata<<endl;
    *fLog << " Number of Params: "<<dim<<endl;

    if(fNumTry==0)
    {
        fNumTry=tryest;
        *fLog << inf << endl;
        *fLog << "Set no. of trials to the recommended value of round(";
        *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
    }

    fRanTree->SetNumTry(fNumTry);
    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    fTreeNo=0;

    return kTRUE;
}

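// --------------------------------------------------------------------------
//
// Grow a single tree from a bootstrap sample (drawn with replacement) of
// the training events, estimate the error from the out-of-bag events,
// print one line of running output and add the tree to the forest.
// Returns kTRUE as long as the forest is not yet complete
// (fTreeNo<fNumTrees), so it can be used as a loop condition.
//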
Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
    Bool_t calcResolution=(minfloat>0.001);

    if (fTreeNo==1)
    {
        *fLog << inf << endl << underline;

        if(calcResolution)
            *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl;
        else
            *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl;
        //        12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass  = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    TArrayF classpopw(nclass);
    TArrayI jinbag(numdata); // Initialization includes filling with 0
    TArrayF winbag(numdata); // Initialization includes filling with 0

    float square=0;
    float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = Int_t(gRandom->Rndm()*numdata);

        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        mean  +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        winbag[k]+=fWeight[k];
        jinbag[k]=1;
    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In the bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
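    // (a sketch of why it is roughly 2/3: each of the numdata draws misses
    // a given event with probability (1-1/numdata), so the chance of never
    // drawing it is (1-1/numdata)^numdata -> exp(-1) ~ 0.37 for large
    // samples; hence ~63%, i.e. about 2/3, of the events end up in the bag)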
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    //-------------------------------------------------------------------
    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample
    // of this tree can be used to obtain estimates for the classification
    // error of this tree.
    // A given event is contained in the oob-data of about 1/3 of the trees
    // (see the comment above the call to ModifyDataSort). The error
    // determined from the oob-data therefore overestimates the error of the
    // full forest, but can be taken as an upper limit.
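    //
    // A sketch of the quantity printed below (what the following loops
    // compute): with <est_i> = fHadEst[i]/fNTimesOutBag[i] the mean oob
    // estimate of event i,
    //
    //    err = sqrt( 1/n * sum_i (<est_i> - true_i)^2 )             ("rms")
    //    err = sqrt( 1/n * sum_i ((<est_i> - true_i)/true_i)^2 )    ("resolution")
    //
    // where the sum runs over the n events that have been oob at least once.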

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0)
            continue;

        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;
    }

    Int_t   n=0;
    Float_t ferr=0;

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if(fNTimesOutBag[ievt]!=0)
        {
            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
            if(calcResolution) val/=fHadTrue[ievt];

            ferr += val*val;
            n++;
        }
    }
    ferr = TMath::Sqrt(ferr/n);

    //-------------------------------------------------------------------
    // give running output
    *fLog << setw(5)  << fTreeNo;
    *fLog << setw(18) << fRanTree->GetNumEndNodes();
    *fLog << Form("%18.2f", ferr*100.);
    *fLog << endl;

    fRanTree->SetError(ferr);

    // adding tree to forest
    AddTree();

    return fTreeNo<fNumTrees;
}

Bool_t MRanForest::CreateDataSort()
{
    // fDataSort[m*numdata+n] is the event number of the n-th lowest value
    // of column (i.e. variable) m of fMatrix, i.e. each column is sorted
    // in ascending order.
    // fDataRang[m*numdata+j] is the rank of fMatrix(j,m): rank r means
    // that fMatrix(j,m) is the r-th lowest value of column m (counting
    // from 0).
    //
    // More than one event may have the same rank r (events with equal
    // values share a rank).
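    //
    // A small illustration (example numbers only): for a column with the
    // values {0.5, 0.2, 0.5, 0.1} (events 0..3) the sorted event numbers
    // are fDataSort = {3, 1, 0, 2} (or {3, 1, 2, 0}, since events 0 and 2
    // are tied) and the ranks are fDataRang = {2, 1, 2, 0}.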

    const Int_t numdata = GetNumData();
    const Int_t dim     = GetNumDim();

    TArrayF v(numdata);
    TArrayI isort(numdata);


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            isort[n]=n;

            if(TMath::IsNaN(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has the value NaN."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        // control sorting
        for(int n=1;n<numdata;n++)
            if(v[isort[n-1]]>v[isort[n]])
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" not at correct sorting position."<<endl;
                return kFALSE;
            }

        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}

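// --------------------------------------------------------------------------
//
// Compress each column of the sorted-index array to the in-bag events only
// (those with jinbag[event]==1), keeping their ascending order: the first
// ninbag entries of each column of datsortinbag are overwritten with the
// in-bag event numbers, where ninbag is recomputed from jinbag.
//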
void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    const Int_t numdim=GetNumDim();
    const Int_t numdata=GetNumData();

    ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        Int_t k=0;
        Int_t nt=0;
        for(Int_t n=0;n<numdata;n++)
        {
            if(jinbag[datsortinbag[m*numdata+k]]==1)
            {
                datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
                k++;
            }else{
                for(Int_t j=1;j<numdata-k;j++)
                {
                    if(jinbag[datsortinbag[m*numdata+k+j]]==1)
                    {
                        datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            if(nt>=ninbag) break;
        }
    }
}

Bool_t MRanForest::AsciiWrite(ostream &out) const
{
    Int_t n=0;
    MRanTree *tree;
    TIter forest(fForest);

    while ((tree=(MRanTree*)forest.Next()))
    {
        tree->AsciiWrite(out);
        n++;
    }

    return n==fNumTrees;
}