source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7398

Last change on this file since 7398 was 7398, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 16.8 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19!
20! Copyright: MAGIC Software Development, 2000-2005
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
31// A random forest can be grown by calling GrowForest.
32// In advance SetupGrow must be called in order to initialize arrays and
33// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
37// Essentially two random elements serve to provide a "random" forest,
38// namely bootstrap aggregating (which is done in GrowForest()) and random
39// split selection (which is subject to MRanTree::GrowTree())
40//
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TVector.h>
45#include <TRandom.h>
46
47#include "MHMatrix.h"
48#include "MRanTree.h"
49#include "MData.h"
50#include "MDataArray.h"
51#include "MParList.h"
52
53#include "MLog.h"
54#include "MLogManip.h"
55
56ClassImp(MRanForest);
57
58using namespace std;
59
60// --------------------------------------------------------------------------
61//
62// Default constructor.
63//
64MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1)
65{
66 fName = name ? name : "MRanForest";
67 fTitle = title ? title : "Storage container for Random Forest";
68
69 fForest=new TObjArray();
70 fForest->SetOwner(kTRUE);
71}
72
// --------------------------------------------------------------------------
//
// Copy constructor. Deep-copies the rules and all trees of the forest;
// the scalar settings and the per-event bookkeeping arrays are copied as
// well.
//
// NOTE(review): fGrid and fMatrix are not copied here -- confirm against
// the class declaration whether that is intended. fRanTree is reset to
// NULL on purpose: it points into a parameter list and is not owned by
// this object.
//
MRanForest::MRanForest(const MRanForest &rf)
{
    // Copy constructor
    fName  = rf.fName;
    fTitle = rf.fTitle;

    fClassify = rf.fClassify;
    fNumTrees = rf.fNumTrees;
    fNumTry   = rf.fNumTry;
    fNdSize   = rf.fNdSize;
    fTreeNo   = rf.fTreeNo;
    fRanTree  = NULL;   // not owned, deliberately not copied

    // Deep-copy the rules by re-adding each rule's textual representation
    fRules=new MDataArray();
    fRules->Reset();

    MDataArray *newrules=rf.fRules;

    for(Int_t i=0;i<newrules->GetNumEntries();i++)
    {
        MData &data=(*newrules)[i];
        fRules->AddEntry(data.GetRule());
    }

    // trees: deep-copy every tree into a new owning array
    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);

    TObjArray *newforest=rf.fForest;
    for(Int_t i=0;i<newforest->GetEntries();i++)
    {
        MRanTree *rantree=(MRanTree*)newforest->At(i);

        MRanTree *newtree=new MRanTree(*rantree);
        fForest->Add(newtree);
    }

    // per-event arrays (TArray assignment copies size and contents)
    fHadTrue  = rf.fHadTrue;
    fHadEst   = rf.fHadEst;
    fDataSort = rf.fDataSort;
    fDataRang = rf.fDataRang;
    fClassPop = rf.fClassPop;
    fWeight   = rf.fWeight;
    fTreeHad  = rf.fTreeHad;

    fNTimesOutBag = rf.fNTimesOutBag;

}
121
// --------------------------------------------------------------------------
// Destructor. Deletes the tree array; since the array owns its entries
// (SetOwner(kTRUE) in the constructors) this also deletes all trees.
//
// NOTE(review): fMatrix and fRules are allocated with new in SetupGrow()
// (fRules also in the copy constructor) but are not deleted here -- looks
// like a memory leak. They cannot safely be deleted unless every
// constructor guarantees the pointers are initialized (e.g. to NULL);
// confirm against the class declaration.
MRanForest::~MRanForest()
{
    delete fForest;
}
128
129MRanTree *MRanForest::GetTree(Int_t i)
130{
131 return (MRanTree*)(fForest->At(i));
132}
133
134void MRanForest::SetNumTrees(Int_t n)
135{
136 //at least 1 tree
137 fNumTrees=TMath::Max(n,1);
138 fTreeHad.Set(fNumTrees);
139 fTreeHad.Reset();
140}
141
142void MRanForest::SetNumTry(Int_t n)
143{
144 fNumTry=TMath::Max(n,0);
145}
146
147void MRanForest::SetNdSize(Int_t n)
148{
149 fNdSize=TMath::Max(n,1);
150}
151
152void MRanForest::SetWeights(const TArrayF &weights)
153{
154 const int n=weights.GetSize();
155 fWeight.Set(n);
156 fWeight=weights;
157
158 return;
159}
160
161void MRanForest::SetGrid(const TArrayD &grid)
162{
163 const int n=grid.GetSize();
164
165 for(int i=0;i<n-1;i++)
166 if(grid[i]>=grid[i+1])
167 {
168 *fLog<<inf<<"Grid points must be in increasing order! Ignoring grid."<<endl;
169 return;
170 }
171
172 fGrid=grid;
173
174 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
175 //for(int i=0;i<n;i++)
176 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
177
178 return;
179}
180
181Int_t MRanForest::GetNumDim() const
182{
183 return fMatrix ? fMatrix->GetNcols() : 0;
184}
185
186Int_t MRanForest::GetNumData() const
187{
188 return fMatrix ? fMatrix->GetNrows() : 0;
189}
190
191Int_t MRanForest::GetNclass() const
192{
193 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
194
195 return int(fClass[maxidx])+1;
196}
197
198void MRanForest::PrepareClasses()
199{
200 const int numdata=fHadTrue.GetSize();
201
202 if(fGrid.GetSize()>0)
203 {
204 // classes given by grid
205 const int ngrid=fGrid.GetSize();
206
207 for(int j=0;j<numdata;j++)
208 {
209 // Array is supposed to be sorted prior to this call.
210 // If match is found, function returns position of element.
211 // If no match found, function gives nearest element smaller
212 // than value.
213 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);
214
215 fClass[j] = k;
216 }
217
218 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
219 int min = fClass[minidx];
220 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
221
222 }else{
223 // classes directly given
224 for (Int_t j=0;j<numdata;j++)
225 fClass[j] = int(fHadTrue[j]+0.5);
226 }
227
228 return;
229}
230
231/*
232Bool_t MRanForest::PreProcess(MParList *plist)
233{
234 if (!fRules)
235 {
236 *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl;
237 return kFALSE;
238 }
239
240 if (!fRules->PreProcess(plist))
241 {
242 *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl;
243 return kFALSE;
244 }
245
246 return kTRUE;
247}
248*/
249
250Double_t MRanForest::CalcHadroness()
251{
252 TVector event;
253 *fRules >> event;
254
255 return CalcHadroness(event);
256}
257
258Double_t MRanForest::CalcHadroness(const TVector &event)
259{
260 Double_t hadroness=0;
261 Int_t ntree=0;
262
263 TIter Next(fForest);
264
265 MRanTree *tree;
266 while ((tree=(MRanTree*)Next()))
267 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
268
269 return hadroness/ntree;
270}
271
272Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
273{
274 fRanTree = rantree ? rantree:fRanTree;
275
276 if (!fRanTree) return kFALSE;
277
278 MRanTree *newtree=new MRanTree(*fRanTree);
279 fForest->Add(newtree);
280
281 return kTRUE;
282}
283
// --------------------------------------------------------------------------
//
// Prepare the forest for training:
//  - copy the training matrix (its last column is the target value and
//    goes to fHadTrue, the remaining columns are the input variables),
//  - set up class labels, weights and the per-event bookkeeping arrays,
//  - copy the rules describing the input columns,
//  - pre-sort the data (CreateDataSort) for the fast split search,
//  - fetch/create the tree template (MRanTree) from the parameter list
//    and configure it.
//
// Returns kFALSE if the matrix has no rules attached, the data contains
// NaNs (see CreateDataSort), or the tree container cannot be created.
//
// NOTE(review): fMatrix and fRules are allocated with new on every call
// without deleting a previous allocation -- repeated calls leak. Confirm
// ownership/lifetime against the class declaration before changing.
//
Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    TMatrix mat_tmp = mat->GetM();
    int dim = mat_tmp.GetNcols();
    int numdata = mat_tmp.GetNrows();

    fMatrix=new TMatrix(mat_tmp);

    fHadTrue.Set(numdata);
    fHadTrue.Reset(0);

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim-1);

    // remove last col (the target is not an input variable)
    fMatrix->ResizeTo(numdata,dim-1);
    dim=fMatrix->GetNcols();

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset(0);

    if(fClassify) PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata); fHadEst.Reset(0);             // summed oob estimates per event
    fNTimesOutBag.Set(numdata); fNTimesOutBag.Reset(0); // oob counter per event
    fDataSort.Set(dim*numdata); fDataSort.Reset(0);     // sorted event indices per column
    fDataRang.Set(dim*numdata); fDataRang.Reset(0);     // rank of each value per column

    // default to unit weights if no weight per event was supplied
    if(fWeight.GetSize()!=numdata)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if(allrules==NULL)
    {
        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
        return kFALSE;
    }

    // the first dim rules describe the inputs, the last one the target
    fRules=new MDataArray(); fRules->Reset();
    TString target_rule;

    for(Int_t i=0;i<dim+1;i++)
    {
        MData &data=(*allrules)[i];
        if(i<dim)
            fRules->AddEntry(data.GetRule());
        else
            target_rule=data.GetRule();
    }

    *fLog << inf <<endl;
    *fLog << inf <<"Setting up RF for training on target:"<<endl<<" "<<target_rule.Data()<<endl;
    *fLog << inf <<"Following rules are used as input to RF:"<<endl;

    for(Int_t i=0;i<dim;i++)
    {
        MData &data=(*fRules)[i];
        *fLog<<inf<<" "<<i<<") "<<data.GetRule()<<endl<<flush;
    }

    *fLog << inf <<endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if(!CreateDataSort()) return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container (from the parameter list; not owned)
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");

    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }

    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    // fNumTry==0 means: use the recommended default round(sqrt(dim))
    if(fNumTry==0)
    {
        double ddim = double(dim);

        fNumTry=int(sqrt(ddim)+0.5);
        *fLog<<inf<<endl;
        *fLog<<inf<<"Set no. of trials to the recommended value of round("<<sqrt(ddim)<<") = ";
        *fLog<<inf<<fNumTry<<endl;

    }
    fRanTree->SetNumTry(fNumTry);

    *fLog<<inf<<endl;
    *fLog<<inf<<"Following settings for the tree growing are used:"<<endl;
    *fLog<<inf<<" Number of Trees : "<<fNumTrees<<endl;
    *fLog<<inf<<" Number of Trials: "<<fNumTry<<endl;
    *fLog<<inf<<" Final Node size : "<<fNdSize<<endl;

    // reset tree counter; GrowForest() increments it once per tree
    fTreeNo=0;

    return kTRUE;
}
398
// --------------------------------------------------------------------------
//
// Grow a single tree; called repeatedly until the forest is complete:
//  1. draw a bootstrap sample (sampling with replacement) from the
//     training events,
//  2. grow one tree on that sample (random split selection happens
//     inside MRanTree::GrowTree),
//  3. update the out-of-bag (oob) error estimate from the events NOT
//     contained in the bootstrap sample,
//  4. print one progress line and append the tree to the forest.
//
// Returns kTRUE as long as more trees remain to be grown
// (fTreeNo<fNumTrees), kFALSE when done or if gRandom is missing.
//
Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    // If the smallest true target value is >0.001 a relative resolution
    // is printed, otherwise the absolute rms (avoids dividing by ~0 in
    // the oob loop below)
    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
    Bool_t calcResolution=(minfloat>0.001);

    if (fTreeNo==1)
    {
        *fLog << inf << endl;
        *fLog << underline;

        if(calcResolution)
            *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl;
        else
            *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl;
        // 12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    TArrayF classpopw(nclass); // summed weight per class in the bag
    TArrayI jinbag(numdata); // flag: event in bag; Initialization includes filling with 0
    TArrayF winbag(numdata); // summed weight per event in bag; Initialization includes filling with 0

    float square=0; float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = Int_t(gRandom->Rndm()*numdata);

        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        // weighted first and second moment of the drawn targets
        mean +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        winbag[k]+=fWeight[k]; // accumulates if k is drawn more than once
        jinbag[k]=1;

    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    //-------------------------------------------------------------------
    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample of
    // this tree can be used to obtain estimates for the classification error of
    // this tree.
    // If you take a certain event, it is contained in the oob-data of 1/3 of
    // the trees (see comment to ModifyData). This means that the classification error
    // determined from oob-data is underestimated, but can still be taken as upper limit.

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0) continue;

        // accumulate this tree's estimate for its oob events
        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;

    }

    Int_t n=0;
    double ferr=0;

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if(fNTimesOutBag[ievt]!=0)
        {
            // mean oob estimate minus truth; relative if resolution mode
            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
            if(calcResolution) val/=fHadTrue[ievt];

            ferr += val*val;
            n++;
        }
    }
    ferr = TMath::Sqrt(ferr/n);

    //-------------------------------------------------------------------
    // give running output
    *fLog << inf << setw(5) << fTreeNo;
    *fLog << inf << setw(20) << fRanTree->GetNumEndNodes();
    *fLog << inf << Form("%20.2f", ferr*100.);
    *fLog << inf << endl;

    // adding tree to forest (clones the freshly grown fRanTree)
    AddTree();

    return fTreeNo<fNumTrees;
}
520
// --------------------------------------------------------------------------
//
// Build the sorted-index (fDataSort) and rank (fDataRang) arrays used by
// the fast split-point search during tree growing:
//
// fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
// fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
// fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
//
// There may be more then 1 event with rang r (due to bagging).
//
// Returns kFALSE if the matrix contains a NaN or the sort check fails.
//
Bool_t MRanForest::CreateDataSort()
{
    const Int_t numdata = GetNumData();
    const Int_t dim = GetNumDim();

    TArrayF v(numdata);     // one column of the matrix
    TArrayI isort(numdata); // event indices, sorted by value


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        // copy column mvar and reject NaNs (they would break the sort)
        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            isort[n]=n;

            if(TMath::IsNaN(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has the value NaN."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        // control sorting (paranoia check of the sort result)
        for(int n=1;n<numdata;n++)
            if(v[isort[n-1]]>v[isort[n]])
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" not at correct sorting position."<<endl;
                return kFALSE;
            }

        // fill the sorted event numbers; equal values share the same rank
        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                // strictly larger value -> next rank
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                // equal value -> same rank
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        // the loop above stops one short; store the last sorted index
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}
585
// --------------------------------------------------------------------------
//
// Compress the sorted-index array to the in-bag events only: for every
// dimension the events flagged in jinbag are kept (in sorted order) in
// the first ninbag slots of datsortinbag; the remaining slots are left
// unchanged.
//
// NOTE(review): ninbag is recomputed inside and used only locally; it is
// passed by value, so the caller's variable is NOT updated -- confirm
// whether the parameter was meant to be a reference.
//
void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    const Int_t numdim=GetNumDim();
    const Int_t numdata=GetNumData();

    // count the distinct in-bag events
    ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        Int_t k=0;  // read position in the sorted column
        Int_t nt=0; // write position (next free in-bag slot)
        for(Int_t n=0;n<numdata;n++)
        {
            if(jinbag[datsortinbag[m*numdata+k]]==1)
            {
                datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
                k++;
            }else{
                // current event is out-of-bag: scan forward to the next
                // in-bag event and take that one instead
                for(Int_t j=1;j<numdata-k;j++)
                {
                    if(jinbag[datsortinbag[m*numdata+k+j]]==1)
                    {
                        datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            if(nt>=ninbag) break; // all in-bag slots filled for this column
        }
    }
}
621
622Bool_t MRanForest::AsciiWrite(ostream &out) const
623{
624 Int_t n=0;
625 MRanTree *tree;
626 TIter forest(fForest);
627
628 while ((tree=(MRanTree*)forest.Next()))
629 {
630 tree->AsciiWrite(out);
631 n++;
632 }
633
634 return n==fNumTrees;
635}
Note: See TracBrowser for help on using the repository browser.