source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7420

Last change on this file since 7420 was 7420, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 16.6 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
19!
20! Copyright: MAGIC Software Development, 2000-2005
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
31// A random forest can be grown by calling GrowForest.
32// In advance SetupGrow must be called in order to initialize arrays and
33// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
37// Essentially two random elements serve to provide a "random" forest,
38// namely bootstrap aggregating (which is done in GrowForest()) and random
39// split selection (which is subject to MRanTree::GrowTree())
40//
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TVector.h>
45#include <TRandom.h>
46
47#include "MHMatrix.h"
48#include "MRanTree.h"
49#include "MData.h"
50#include "MDataArray.h"
51#include "MParList.h"
52
53#include "MLog.h"
54#include "MLogManip.h"
55
56ClassImp(MRanForest);
57
58using namespace std;
59
60// --------------------------------------------------------------------------
61//
62// Default constructor.
63//
64MRanForest::MRanForest(const char *name, const char *title)
65 : fClassify(kTRUE), fNumTrees(100), fNumTry(0), fNdSize(1),
66 fRanTree(NULL), fRules(NULL), fMatrix(NULL), fUserVal(-1)
67{
68 fName = name ? name : "MRanForest";
69 fTitle = title ? title : "Storage container for Random Forest";
70
71 fForest=new TObjArray();
72 fForest->SetOwner(kTRUE);
73}
74
75MRanForest::MRanForest(const MRanForest &rf)
76{
77 // Copy constructor
78 fName = rf.fName;
79 fTitle = rf.fTitle;
80
81 fClassify = rf.fClassify;
82 fNumTrees = rf.fNumTrees;
83 fNumTry = rf.fNumTry;
84 fNdSize = rf.fNdSize;
85 fTreeNo = rf.fTreeNo;
86 fRanTree = NULL;
87
88 fRules=new MDataArray();
89 fRules->Reset();
90
91 MDataArray *newrules=rf.fRules;
92
93 for(Int_t i=0;i<newrules->GetNumEntries();i++)
94 {
95 MData &data=(*newrules)[i];
96 fRules->AddEntry(data.GetRule());
97 }
98
99 // trees
100 fForest=new TObjArray();
101 fForest->SetOwner(kTRUE);
102
103 TObjArray *newforest=rf.fForest;
104 for(Int_t i=0;i<newforest->GetEntries();i++)
105 {
106 MRanTree *rantree=(MRanTree*)newforest->At(i);
107
108 MRanTree *newtree=new MRanTree(*rantree);
109 fForest->Add(newtree);
110 }
111
112 fHadTrue = rf.fHadTrue;
113 fHadEst = rf.fHadEst;
114 fDataSort = rf.fDataSort;
115 fDataRang = rf.fDataRang;
116 fClassPop = rf.fClassPop;
117 fWeight = rf.fWeight;
118 fTreeHad = rf.fTreeHad;
119
120 fNTimesOutBag = rf.fNTimesOutBag;
121
122}
123
124// --------------------------------------------------------------------------
125// Destructor.
126MRanForest::~MRanForest()
127{
128 delete fForest;
129 if (fMatrix)
130 delete fMatrix;
131 if (fRules)
132 delete fRules;
133}
134
135MRanTree *MRanForest::GetTree(Int_t i)
136{
137 return (MRanTree*)(fForest->At(i));
138}
139
140void MRanForest::SetNumTrees(Int_t n)
141{
142 //at least 1 tree
143 fNumTrees=TMath::Max(n,1);
144 fTreeHad.Set(fNumTrees);
145 fTreeHad.Reset();
146}
147
148void MRanForest::SetNumTry(Int_t n)
149{
150 fNumTry=TMath::Max(n,0);
151}
152
153void MRanForest::SetNdSize(Int_t n)
154{
155 fNdSize=TMath::Max(n,1);
156}
157
158void MRanForest::SetWeights(const TArrayF &weights)
159{
160 fWeight=weights;
161}
162
163void MRanForest::SetGrid(const TArrayD &grid)
164{
165 const int n=grid.GetSize();
166
167 for(int i=0;i<n-1;i++)
168 if(grid[i]>=grid[i+1])
169 {
170 *fLog<<warn<<"Grid points must be in increasing order! Ignoring grid."<<endl;
171 return;
172 }
173
174 fGrid=grid;
175
176 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
177 //for(int i=0;i<n;i++)
178 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
179}
180
181Int_t MRanForest::GetNumDim() const
182{
183 return fMatrix ? fMatrix->GetNcols() : 0;
184}
185
186Int_t MRanForest::GetNumData() const
187{
188 return fMatrix ? fMatrix->GetNrows() : 0;
189}
190
191Int_t MRanForest::GetNclass() const
192{
193 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
194
195 return int(fClass[maxidx])+1;
196}
197
void MRanForest::PrepareClasses()
{
    // Derive the integer class label fClass[j] for every training
    // event from its target value fHadTrue[j]:
    //  - with a grid (SetGrid): the label is the grid bin containing
    //    the value, shifted so that the smallest occurring label is 0;
    //  - without a grid: the target is rounded to the nearest integer.
    const int numdata=fHadTrue.GetSize();

    if(fGrid.GetSize()>0)
    {
        // classes given by grid
        const int ngrid=fGrid.GetSize();

        for(int j=0;j<numdata;j++)
        {
            // Array is supposed to be sorted prior to this call.
            // If match is found, function returns position of element.
            // If no match found, function gives nearest element smaller
            // than value.
            int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);

            fClass[j] = k;
        }

        // shift labels so that the smallest one is 0 (BinarySearch
        // returns -1 for values below the first grid point)
        int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
        int min = fClass[minidx];
        if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;

    }else{
        // classes directly given
        for (Int_t j=0;j<numdata;j++)
            fClass[j] = int(fHadTrue[j]+0.5);
    }
}
228
229Double_t MRanForest::CalcHadroness()
230{
231 TVector event;
232 *fRules >> event;
233
234 return CalcHadroness(event);
235}
236
237Double_t MRanForest::CalcHadroness(const TVector &event)
238{
239 Double_t hadroness=0;
240 Int_t ntree=0;
241
242 TIter Next(fForest);
243
244 MRanTree *tree;
245 while ((tree=(MRanTree*)Next()))
246 hadroness += (fTreeHad[ntree++]=tree->TreeHad(event));
247
248 return hadroness/ntree;
249}
250
251Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
252{
253 fRanTree = rantree ? rantree : fRanTree;
254
255 if (!fRanTree) return kFALSE;
256
257 MRanTree *newtree=new MRanTree(*fRanTree);
258 fForest->Add(newtree);
259
260 return kTRUE;
261}
262
Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
{
    // Initialize all internal arrays and the tree container so that
    // subsequent calls to GrowForest() can grow single trees.
    //
    // The last column of the training matrix is taken as the training
    // target; all other columns are the input parameters.
    //
    // Returns kFALSE if the matrix has no rules attached, if the data
    // contains NaNs (see CreateDataSort) or if no MRanTree can be
    // created in the parameter list.

    //-------------------------------------------------------------------
    // access matrix, copy last column (target) preliminarily
    // into fHadTrue
    if (fMatrix)
        delete fMatrix;
    fMatrix = new TMatrix(mat->GetM());

    int dim = fMatrix->GetNcols()-1;
    int numdata = fMatrix->GetNrows();

    fHadTrue.Set(numdata);
    fHadTrue.Reset(0);

    for (Int_t j=0;j<numdata;j++)
        fHadTrue[j] = (*fMatrix)(j,dim);

    // remove last col
    fMatrix->ResizeTo(numdata,dim);

    //-------------------------------------------------------------------
    // setup labels for classification/regression
    fClass.Set(numdata);
    fClass.Reset(0);

    if (fClassify)
        PrepareClasses();

    //-------------------------------------------------------------------
    // allocating and initializing arrays
    fHadEst.Set(numdata);
    fHadEst.Reset(0);

    fNTimesOutBag.Set(numdata);
    fNTimesOutBag.Reset(0);

    fDataSort.Set(dim*numdata);
    fDataSort.Reset(0);

    fDataRang.Set(dim*numdata);
    fDataRang.Reset(0);

    // default to unit weights if the user supplied none (or a wrong size)
    if(fWeight.GetSize()!=numdata)
    {
        fWeight.Set(numdata);
        fWeight.Reset(1.);
        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
    }

    //-------------------------------------------------------------------
    // setup rules to be used for classification/regression
    const MDataArray *allrules=(MDataArray*)mat->GetColumns();
    if(allrules==NULL)
    {
        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
        return kFALSE;
    }

    if (fRules)
        delete fRules;
    fRules = new MDataArray();
    fRules->Reset();

    // the last rule describes the training target, the others the inputs
    const TString target_rule = (*allrules)[dim].GetRule();
    for (Int_t i=0;i<dim;i++)
        fRules->AddEntry((*allrules)[i].GetRule());

    *fLog << inf << endl;
    *fLog << "Setting up RF for training on target:" << endl;
    *fLog << " " << target_rule.Data() << endl;
    *fLog << "Following rules are used as input to RF:" << endl;
    for (Int_t i=0;i<dim;i++)
        *fLog << " " << i << ") " << (*fRules)[i].GetRule() << endl;

    *fLog << endl;

    //-------------------------------------------------------------------
    // prepare (sort) data for fast optimization algorithm
    if (!CreateDataSort())
        return kFALSE;

    //-------------------------------------------------------------------
    // access and init tree container
    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }

    // recommended number of split trials: round(sqrt(#input parameters))
    const Int_t tryest = TMath::Nint(TMath::Sqrt(dim));

    *fLog << inf << endl;
    *fLog << "Following input for the tree growing are used:"<<endl;
    *fLog << " Number of Trees : "<<fNumTrees<<endl;
    *fLog << " Number of Trials: "<<(fNumTry==0?tryest:fNumTry)<<(fNumTry==0?" (auto)":"")<<endl;
    *fLog << " Final Node size : "<<fNdSize<<endl;
    *fLog << " Using Grid : "<<(fGrid.GetSize()>0?"Yes":"No")<<endl;
    *fLog << " Number of Events: "<<numdata<<endl;
    *fLog << " Number of Params: "<<dim<<endl;

    // fNumTry==0 means "auto": use the recommended value
    if(fNumTry==0)
    {
        fNumTry=tryest;
        *fLog << inf << endl;
        *fLog << "Set no. of trials to the recommended value of round(";
        *fLog << TMath::Sqrt(dim) << ") = " << fNumTry << endl;
    }

    // forward the growing parameters to the tree container
    fRanTree->SetNumTry(fNumTry);
    fRanTree->SetClassify(fClassify);
    fRanTree->SetNdSize(fNdSize);

    // restart the tree counter used by GrowForest
    fTreeNo=0;

    return kTRUE;
}
381
Bool_t MRanForest::GrowForest()
{
    // Grow a single tree of the forest:
    //  1) draw a bootstrap sample (sampling with replacement) from the
    //     training events,
    //  2) grow a tree from the in-bag events (MRanTree::GrowTree),
    //  3) update the error estimate from the out-of-bag (oob) events,
    //  4) add the new tree to the forest.
    //
    // Returns kTRUE while more trees remain to be grown, kFALSE after
    // the last tree (or on error).
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    //-------------------------------------------------------------------
    // initialize running output

    // If the smallest target value is clearly above zero a relative
    // resolution is printed, otherwise the rms (avoids dividing by ~0
    // in the per-event deviation below).
    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
    Bool_t calcResolution=(minfloat>0.001);

    if (fTreeNo==1)
    {
        *fLog << inf << endl << underline;

        if(calcResolution)
            *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl;
        else
            *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl;
        // 12345678901234567890123456789012345678901234567890
    }

    const Int_t numdata = GetNumData();
    const Int_t nclass = GetNclass();

    //-------------------------------------------------------------------
    // bootstrap aggregating (bagging) -> sampling with replacement:

    TArrayF classpopw(nclass);
    TArrayI jinbag(numdata); // Initialization includes filling with 0
    TArrayF winbag(numdata); // Initialization includes filling with 0

    float square=0;
    float mean=0;

    for (Int_t n=0; n<numdata; n++)
    {
        // The integer k is randomly (uniformly) chosen from the set
        // {0,1,...,numdata-1}, which is the set of the index numbers of
        // all events in the training sample

        const Int_t k = Int_t(gRandom->Rndm()*numdata);

        // accumulate the weighted class populations resp. the weighted
        // mean and sum of squares of the target, as needed by GrowTree
        if(fClassify)
            classpopw[fClass[k]]+=fWeight[k];
        else
            classpopw[0]+=fWeight[k];

        mean +=fHadTrue[k]*fWeight[k];
        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];

        winbag[k]+=fWeight[k]; // total weight of event k in the sample
        jinbag[k]=1;           // flag: event k is in-bag

    }

    //-------------------------------------------------------------------
    // modifying sorted-data array for in-bag data:

    // In bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
                       jinbag,winbag,nclass);

    //-------------------------------------------------------------------
    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample of
    // this tree can be used to obtain estimates for the classification error of
    // this tree.
    // If you take a certain event, it is contained in the oob-data of 1/3 of
    // the trees (see comment to ModifyData). This means that the classification error
    // determined from oob-data is underestimated, but can still be taken as upper limit.

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if (jinbag[ievt]>0)
            continue;

        // accumulate this tree's estimate for its oob events
        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
        fNTimesOutBag[ievt]++;

    }

    Int_t n=0;
    Float_t ferr=0;

    for (Int_t ievt=0;ievt<numdata;ievt++)
    {
        if(fNTimesOutBag[ievt]!=0)
        {
            // deviation of the oob-averaged estimate from the truth
            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
            if(calcResolution) val/=fHadTrue[ievt];

            ferr += val*val;
            n++;
        }
    }
    // NOTE(review): if no event was ever out-of-bag (n==0) this divides
    // by zero; in practice ~1/3 of events are oob per tree, so n>0
    // after the first tree -- confirm before relying on it.
    ferr = TMath::Sqrt(ferr/n);

    //-------------------------------------------------------------------
    // give running output
    *fLog << setw(5) << fTreeNo;
    *fLog << setw(18) << fRanTree->GetNumEndNodes();
    *fLog << Form("%18.2f", ferr*100.);
    *fLog << endl;

    fRanTree->SetError(ferr);

    // adding tree to forest
    AddTree();

    // kTRUE while the forest is not yet complete
    return fTreeNo<fNumTrees;
}
506
Bool_t MRanForest::CreateDataSort()
{
    // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
    // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
    // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
    //
    // There may be more then 1 event with rang r (due to bagging).
    //
    // Returns kFALSE if the matrix contains a NaN or if the sorting
    // cross-check fails.

    const Int_t numdata = GetNumData();
    const Int_t dim = GetNumDim();

    // scratch buffers: one column of values and the sort-index array
    TArrayF v(numdata);
    TArrayI isort(numdata);


    for (Int_t mvar=0;mvar<dim;mvar++)
    {

        // copy column mvar and initialize the index array; reject NaNs
        for(Int_t n=0;n<numdata;n++)
        {
            v[n]=(*fMatrix)(n,mvar);
            isort[n]=n;

            if(TMath::IsNaN(v[n]))
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" has the value NaN."<<endl;
                return kFALSE;
            }
        }

        TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        // control sorting
        for(int n=1;n<numdata;n++)
            if(v[isort[n-1]]>v[isort[n]])
            {
                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
                *fLog << err <<" not at correct sorting position."<<endl;
                return kFALSE;
            }

        // fill fDataSort with the sorted event numbers and assign each
        // event its rang; events with equal values share the same rang
        for(Int_t n=0;n<numdata-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*numdata+n]=n1;
            if(n==0) fDataRang[mvar*numdata+n1]=0;
            if(v[n1]<v[n2])
            {
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
            }else{
                fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
            }
        }
        // the loop above writes numdata-1 entries; add the last one
        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    }
    return kTRUE;
}
571
void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    // Compress the sorted-index array so that, for every input
    // dimension, the first ninbag entries refer only to in-bag events
    // (jinbag[event]==1), preserving the sort order.
    //
    // NOTE(review): ninbag is passed by value and recomputed here, so
    // the caller's variable is NOT updated -- GrowForest currently
    // ignores it after the call.
    const Int_t numdim=GetNumDim();
    const Int_t numdata=GetNumData();

    // count the in-bag events
    ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<numdim;m++)
    {
        Int_t k=0;  // read position in the sorted array
        Int_t nt=0; // write position (entries kept so far)
        for(Int_t n=0;n<numdata;n++)
        {
            if(jinbag[datsortinbag[m*numdata+k]]==1)
            {
                datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
                k++;
            }else{
                // skip ahead to the next in-bag event
                for(Int_t j=1;j<numdata-k;j++)
                {
                    if(jinbag[datsortinbag[m*numdata+k+j]]==1)
                    {
                        datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            // stop once all in-bag entries of this dimension are placed
            if(nt>=ninbag) break;
        }
    }
}
607
608Bool_t MRanForest::AsciiWrite(ostream &out) const
609{
610 Int_t n=0;
611 MRanTree *tree;
612 TIter forest(fForest);
613
614 while ((tree=(MRanTree*)forest.Next()))
615 {
616 tree->AsciiWrite(out);
617 n++;
618 }
619
620 return n==fNumTrees;
621}
Note: See TracBrowser for help on using the repository browser.