source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7396

Last change on this file since 7396 was 7396, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 16.8 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19!
20! Copyright: MAGIC Software Development, 2000-2005
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
31// A random forest can be grown by calling GrowForest.
32// In advance SetupGrow must be called in order to initialize arrays and
33// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
37// Essentially two random elements serve to provide a "random" forest,
38// namely bootstrap aggregating (which is done in GrowForest()) and random
39// split selection (which is subject to MRanTree::GrowTree())
40//
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TVector.h>
45#include <TRandom.h>
46
47#include "MHMatrix.h"
48#include "MRanTree.h"
49#include "MData.h"
50#include "MDataArray.h"
51#include "MParList.h"
52
53#include "MLog.h"
54#include "MLogManip.h"
55
56ClassImp(MRanForest);
57
58using namespace std;
59
60// --------------------------------------------------------------------------
61//
62// Default constructor.
63//
64MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1)
65{
66 fName = name ? name : "MRanForest";
67 fTitle = title ? title : "Storage container for Random Forest";
68
69 fForest=new TObjArray();
70 fForest->SetOwner(kTRUE);
71}
72
73MRanForest::MRanForest(const MRanForest &rf)
74{
75 // Copy constructor
76 fName = rf.fName;
77 fTitle = rf.fTitle;
78
79 fClassify = rf.fClassify;
80 fNumTrees = rf.fNumTrees;
81 fNumTry = rf.fNumTry;
82 fNdSize = rf.fNdSize;
83 fTreeNo = rf.fTreeNo;
84 fRanTree = NULL;
85
86 fRules=new MDataArray();
87 fRules->Reset();
88
89 MDataArray *newrules=rf.fRules;
90
91 for(Int_t i=0;i<newrules->GetNumEntries();i++)
92 {
93 MData &data=(*newrules)[i];
94 fRules->AddEntry(data.GetRule());
95 }
96
97 // trees
98 fForest=new TObjArray();
99 fForest->SetOwner(kTRUE);
100
101 TObjArray *newforest=rf.fForest;
102 for(Int_t i=0;i<newforest->GetEntries();i++)
103 {
104 MRanTree *rantree=(MRanTree*)newforest->At(i);
105
106 MRanTree *newtree=new MRanTree(*rantree);
107 fForest->Add(newtree);
108 }
109
110 fHadTrue = rf.fHadTrue;
111 fHadEst = rf.fHadEst;
112 fDataSort = rf.fDataSort;
113 fDataRang = rf.fDataRang;
114 fClassPop = rf.fClassPop;
115 fWeight = rf.fWeight;
116 fTreeHad = rf.fTreeHad;
117
118 fNTimesOutBag = rf.fNTimesOutBag;
119
120}
121
// --------------------------------------------------------------------------
// Destructor. Deletes the forest; the trees themselves are deleted by
// the TObjArray because SetOwner(kTRUE) was set in the constructors.
//
// NOTE(review): fMatrix and fRules are allocated with new in
// SetupGrow()/the copy constructor and are never freed here -- this
// looks like a leak, but deleting them would crash for objects whose
// pointers were never initialized. TODO: confirm ownership against the
// class definition before changing this.
MRanForest::~MRanForest()
{
    delete fForest;
}
128
129MRanTree *MRanForest::GetTree(Int_t i)
130{
131 return (MRanTree*)(fForest->At(i));
132}
133
134void MRanForest::SetNumTrees(Int_t n)
135{
136 //at least 1 tree
137 fNumTrees=TMath::Max(n,1);
138 fTreeHad.Set(fNumTrees);
139 fTreeHad.Reset();
140}
141
142void MRanForest::SetNumTry(Int_t n)
143{
144 fNumTry=TMath::Max(n,0);
145}
146
147void MRanForest::SetNdSize(Int_t n)
148{
149 fNdSize=TMath::Max(n,1);
150}
151
152void MRanForest::SetWeights(const TArrayF &weights)
153{
154 const int n=weights.GetSize();
155 fWeight.Set(n);
156 fWeight=weights;
157
158 return;
159}
160
161void MRanForest::SetGrid(const TArrayD &grid)
162{
163 const int n=grid.GetSize();
164
165 for(int i=0;i<n-1;i++)
166 if(grid[i]>=grid[i+1])
167 {
168 *fLog<<inf<<"Grid points must be in increasing order! Ignoring grid."<<endl;
169 return;
170 }
171
172 fGrid=grid;
173
174 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
175 //for(int i=0;i<n;i++)
176 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
177
178 return;
179}
180
181Int_t MRanForest::GetNumDim() const
182{
183 return fMatrix ? fMatrix->GetNcols() : 0;
184}
185
186Int_t MRanForest::GetNumData() const
187{
188 return fMatrix ? fMatrix->GetNrows() : 0;
189}
190
191Int_t MRanForest::GetNclass() const
192{
193 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
194
195 return int(fClass[maxidx])+1;
196}
197
198void MRanForest::PrepareClasses()
199{
200 const int numdata=fHadTrue.GetSize();
201
202 if(fGrid.GetSize()>0)
203 {
204 // classes given by grid
205 const int ngrid=fGrid.GetSize();
206
207 for(int j=0;j<numdata;j++)
208 {
209 // Array is supposed to be sorted prior to this call.
210 // If match is found, function returns position of element.
211 // If no match found, function gives nearest element smaller
212 // than value.
213 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);
214
215 fClass[j] = k;
216 }
217
218 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
219 int min = fClass[minidx];
220 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
221
222 }else{
223 // classes directly given
224 for (Int_t j=0;j<numdata;j++)
225 fClass[j] = int(fHadTrue[j]+0.5);
226 }
227
228 return;
229}
230
231/*
232Bool_t MRanForest::PreProcess(MParList *plist)
233{
234 if (!fRules)
235 {
236 *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl;
237 return kFALSE;
238 }
239
240 if (!fRules->PreProcess(plist))
241 {
242 *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl;
243 return kFALSE;
244 }
245
246 return kTRUE;
247}
248*/
249
250Double_t MRanForest::CalcHadroness()
251{
252 TVector event;
253 *fRules >> event;
254
255 return CalcHadroness(event);
256}
257
258Double_t MRanForest::CalcHadroness(const TVector &event)
259{
260 Double_t hadroness=0;
261 Int_t ntree=0;
262
263 TIter Next(fForest);
264
265 MRanTree *tree;
266 while ((tree=(MRanTree*)Next()))
267 {
268 hadroness+=(fTreeHad[ntree]=tree->TreeHad(event));
269 ntree++;
270 }
271 return hadroness/ntree;
272}
273
274Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
275{
276 fRanTree = rantree ? rantree:fRanTree;
277
278 if (!fRanTree) return kFALSE;
279
280 MRanTree *newtree=new MRanTree(*fRanTree);
281 fForest->Add(newtree);
282
283 return kTRUE;
284}
285
// --------------------------------------------------------------------------
//
// Initialize the forest for training. The last column of the matrix is
// taken as training target, the remaining columns as input variables.
// Must be called once before the GrowForest() loop.
//
// Returns kFALSE if the matrix provides no rules, if the data contains
// NaN values (cf. CreateDataSort) or if the tree container cannot be
// created in the parameter list.
//
Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    TMatrix mat_tmp = mat->GetM();
    int dim = mat_tmp.GetNcols();
    int numdata = mat_tmp.GetNrows();

    // NOTE(review): a previously allocated fMatrix is not freed here --
    // calling SetupGrow twice leaks. TODO: confirm intended usage.
    fMatrix=new TMatrix(mat_tmp);

    fHadTrue.Set(numdata);
    fHadTrue.Reset(0);

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim-1);

    // remove last col
    fMatrix->ResizeTo(numdata,dim-1);
    dim=fMatrix->GetNcols();

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset(0);

    if(fClassify) PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata); fHadEst.Reset(0);
    fNTimesOutBag.Set(numdata); fNTimesOutBag.Reset(0);
    fDataSort.Set(dim*numdata); fDataSort.Reset(0);
    fDataRang.Set(dim*numdata); fDataRang.Reset(0);

    // default: every event gets weight 1 if the user supplied no (or a
    // wrongly sized) weight array
    if(fWeight.GetSize()!=numdata)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if(allrules==NULL)
    {
        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
        return kFALSE;
    }

    // copy all rules except the last one (the target rule) into fRules
    fRules=new MDataArray(); fRules->Reset();
    TString target_rule;

    for(Int_t i=0;i<dim+1;i++)
    {
        MData &data=(*allrules)[i];
        if(i<dim)
            fRules->AddEntry(data.GetRule());
        else
            target_rule=data.GetRule();
    }

    *fLog << inf <<endl;
    *fLog << inf <<"Setting up RF for training on target:"<<endl<<" "<<target_rule.Data()<<endl;
    *fLog << inf <<"Following rules are used as input to RF:"<<endl;

    for(Int_t i=0;i<dim;i++)
    {
        MData &data=(*fRules)[i];
        *fLog<<inf<<" "<<i<<") "<<data.GetRule()<<endl<<flush;
    }

    *fLog << inf <<endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if(!CreateDataSort()) return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");

    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }

    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    // if the user did not fix the number of split trials, use the
    // common random-forest default round(sqrt(number of variables))
    if(fNumTry==0)
    {
        double ddim = double(dim);

        fNumTry=int(sqrt(ddim)+0.5);
        *fLog<<inf<<endl;
        *fLog<<inf<<"Set no. of trials to the recommended value of round("<<sqrt(ddim)<<") = ";
        *fLog<<inf<<fNumTry<<endl;

    }
    fRanTree->SetNumTry(fNumTry);

    *fLog<<inf<<endl;
    *fLog<<inf<<"Following settings for the tree growing are used:"<<endl;
    *fLog<<inf<<" Number of Trees : "<<fNumTrees<<endl;
    *fLog<<inf<<" Number of Trials: "<<fNumTry<<endl;
    *fLog<<inf<<" Final Node size : "<<fNdSize<<endl;

    // reset tree counter; GrowForest() increments it per grown tree
    fTreeNo=0;

    return kTRUE;
}
400
// --------------------------------------------------------------------------
//
// Grow one single tree of the forest on a bootstrap sample of the
// training data (bagging) and add it to the forest. Prints one line
// of running output per tree including an error estimate obtained
// from the out-of-bag events.
//
// Returns kTRUE as long as more trees remain to be grown, kFALSE
// after the last tree (so it can be used as a loop condition).
//
Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    // If the smallest true target value is well above zero a relative
    // resolution can be quoted; otherwise only an absolute rms is safe
    // (avoids dividing by ~0 below).
    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
    Bool_t calcResolution=(minfloat>0.001);

    if (fTreeNo==1)
    {
        *fLog << inf << endl;
        *fLog << underline;

        if(calcResolution)
            *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl;
        else
            *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl;
        // 12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    TArrayF classpopw(nclass);
    TArrayI jinbag(numdata); // Initialization includes filling with 0
    TArrayF winbag(numdata); // Initialization includes filling with 0

    float square=0; float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = Int_t(gRandom->Rndm()*numdata);

        // accumulate the (weighted) class population of the bag;
        // for regression everything goes into class 0
        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        // weighted mean and sum of squares of the targets in the bag
        mean +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        winbag[k]+=fWeight[k]; // total weight of event k in the bag
        jinbag[k]=1;           // flag: event k is in the bag

    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    //-------------------------------------------------------------------
    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample of
    // this tree can be used to obtain estimates for the classification error of
    // this tree.
    // If you take a certain event, it is contained in the oob-data of 1/3 of
    // the trees (see comment to ModifyData). This means that the classification error
    // determined from oob-data is underestimated, but can still be taken as upper limit.

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0) continue;

        // accumulate this tree's response for its out-of-bag events
        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;

    }

    // rms (or relative resolution) over all events that have been
    // out-of-bag at least once
    Int_t n=0;
    double ferr=0;

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if(fNTimesOutBag[ievt]!=0)
        {
            // deviation of the averaged oob estimate from the truth
            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
            if(calcResolution) val/=fHadTrue[ievt];

            ferr += val*val;
            n++;
        }
    }
    ferr = TMath::Sqrt(ferr/n);

    //-------------------------------------------------------------------
    // give running output
    *fLog << inf << setw(5) << fTreeNo;
    *fLog << inf << setw(20) << fRanTree->GetNumEndNodes();
    *fLog << inf << Form("%20.2f", ferr*100.);
    *fLog << inf << endl;

    // adding tree to forest
    AddTree();

    return fTreeNo<fNumTrees;
}
522
// --------------------------------------------------------------------------
//
// Sort the training data once per input variable so the split-point
// search during tree growing can run over pre-sorted indices.
//
// fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
// fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
// fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
//
// There may be more then 1 event with rang r (due to bagging).
//
// Returns kFALSE if the matrix contains NaN values or the sort check
// fails.
//
Bool_t MRanForest::CreateDataSort()
{
    const Int_t numdata = GetNumData();
    const Int_t dim = GetNumDim();

    // v: values of one matrix column, isort: corresponding event numbers
    TArrayF v(numdata);
    TArrayI isort(numdata);


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            isort[n]=n;

            // NaN values would silently break the sorting below, so
            // refuse the matrix explicitly
            if(TMath::IsNaN(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has the value NaN."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        // control sorting
        for(int n=1;n<numdata;n++)
            if(v[isort[n-1]]>v[isort[n]])
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" not at correct sorting position."<<endl;
                return kFALSE;
            }

        // fill fDataSort with the sorted event numbers and assign a
        // rang to every event; events with equal values share a rang
        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        // the loop above fills only the first numdata-1 slots; add the
        // last sorted event number by hand
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}
587
// --------------------------------------------------------------------------
//
// Compact the sorted-index array (datsortinbag) in place so that, for
// every variable, only events flagged in jinbag (the bootstrap sample)
// remain -- still in sorted order. Events can occur multiple times in
// the bag (sampling with replacement), but each distinct event keeps a
// single slot here; its multiplicity is carried by winbag.
//
// NOTE(review): ninbag is passed by value, so the count recomputed
// here is not visible to the caller (GrowForest currently ignores the
// value it passed in afterwards).
//
void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    const Int_t numdim=GetNumDim();
    const Int_t numdata=GetNumData();

    // count the number of distinct events in the bag
    ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        // k: read position in the sorted array, nt: write position
        Int_t k=0;
        Int_t nt=0;
        for(Int_t n=0;n<numdata;n++)
        {
            if(jinbag[datsortinbag[m*numdata+k]]==1)
            {
                // event at the read position is in the bag: keep it
                datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
                k++;
            }else{
                // skip forward to the next in-bag event
                for(Int_t j=1;j<numdata-k;j++)
                {
                    if(jinbag[datsortinbag[m*numdata+k+j]]==1)
                    {
                        datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            // stop once all in-bag events of this variable are placed
            if(nt>=ninbag) break;
        }
    }
}
623
624Bool_t MRanForest::AsciiWrite(ostream &out) const
625{
626 Int_t n=0;
627 MRanTree *tree;
628 TIter forest(fForest);
629
630 while ((tree=(MRanTree*)forest.Next()))
631 {
632 tree->AsciiWrite(out);
633 n++;
634 }
635
636 return n==fNumTrees;
637}
Note: See TracBrowser for help on using the repository browser.