source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 4778

Last change on this file since 4778 was 2307, checked in by tbretz, 21 years ago
*** empty log message ***
File size: 12.5 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
19!
20! Copyright: MAGIC Software Development, 2000-2003
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26// //
27// MRanForest //
28// //
29// ParameterContainer for Forest structure //
30// //
// A random forest is grown by calling GrowForest() repeatedly (once per   //
// tree). SetupGrow() must be called beforehand to initialize the arrays   //
// and do some preprocessing.                                              //
// GrowForest() provides the training data for a single tree (bootstrap    //
// aggregating procedure).                                                 //
36// //
37// Essentially two random elements serve to provide a "random" forest, //
38// namely bootstrap aggregating (which is done in GrowForest()) and random //
39// split selection (which is subject to MRanTree::GrowTree()) //
40// //
41/////////////////////////////////////////////////////////////////////////////
42#include "MRanForest.h"
43
44#include <TMatrix.h>
45#include <TRandom3.h>
46
47#include "MHMatrix.h"
48#include "MRanTree.h"
49
50#include "MLog.h"
51#include "MLogManip.h"
52
53ClassImp(MRanForest);
54
55using namespace std;
56
57// --------------------------------------------------------------------------
58//
59// Default constructor.
60//
61MRanForest::MRanForest(const char *name, const char *title) : fNumTrees(100), fRanTree(NULL),fUsePriors(kFALSE)
62{
63 fName = name ? name : "MRanForest";
64 fTitle = title ? title : "Storage container for Random Forest";
65
66 fForest=new TObjArray();
67 fForest->SetOwner(kTRUE);
68}
69
70// --------------------------------------------------------------------------
71//
72// Destructor.
73//
74MRanForest::~MRanForest()
75{
76 delete fForest;
77}
78
79void MRanForest::SetNumTrees(Int_t n)
80{
81 //at least 1 tree
82 fNumTrees=TMath::Max(n,1);
83 fTreeHad.Set(fNumTrees);
84 fTreeHad.Reset();
85}
86
87void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam)
88{
89 const Float_t sum=prior_gam+prior_had;
90
91 prior_gam/=sum;
92 prior_had/=sum;
93
94 fPrior[0]=prior_had;
95 fPrior[1]=prior_gam;
96
97 fUsePriors=kTRUE;
98}
99
100Int_t MRanForest::GetNumDim() const
101{
102 return fGammas ? fGammas->GetM().GetNcols() : 0;
103}
104
105
106Double_t MRanForest::CalcHadroness(const TVector &event)
107{
108 Double_t hadroness=0;
109 Int_t ntree=0;
110
111 TIter Next(fForest);
112
113 MRanTree *tree;
114 while ((tree=(MRanTree*)Next()))
115 {
116 fTreeHad[ntree]=tree->TreeHad(event);
117 hadroness+=fTreeHad[ntree];
118 ntree++;
119 }
120 return hadroness/ntree;
121}
122
123Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
124{
125 if (rantree)
126 fRanTree=rantree;
127 if (!fRanTree)
128 return kFALSE;
129
130 fForest->Add((MRanTree*)fRanTree->Clone());
131
132 return kTRUE;
133}
134
135Int_t MRanForest::GetNumData() const
136{
137 return fHadrons && fGammas ? fHadrons->GetM().GetNrows()+fGammas->GetM().GetNrows() : 0;
138}
139
140Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam)
141{
142 // pointer to training data
143 fHadrons=mhad;
144 fGammas=mgam;
145
146 // determine data entries and dimension of Hillas-parameter space
147 //fNumHad=fHadrons->GetM().GetNrows();
148 //fNumGam=fGammas->GetM().GetNrows();
149
150 const Int_t dim = GetNumDim();
151
152 if (dim!=fGammas->GetM().GetNcols())
153 return kFALSE;
154
155 const Int_t numdata = GetNumData();
156
157 // allocating and initializing arrays
158 fHadTrue.Set(numdata);
159 fHadTrue.Reset();
160 fHadEst.Set(numdata);
161
162 fPrior.Set(2);
163 fClassPop.Set(2);
164 fWeight.Set(numdata);
165 fNTimesOutBag.Set(numdata);
166 fNTimesOutBag.Reset();
167
168 fDataSort.Set(dim*numdata);
169 fDataRang.Set(dim*numdata);
170
171 // calculating class populations (= no. of gammas and hadrons)
172 fClassPop.Reset();
173 for(Int_t n=0;n<numdata;n++)
174 fClassPop[fHadTrue[n]]++;
175
176 // setting weights taking into account priors
177 if (!fUsePriors)
178 fWeight.Reset(1.);
179 else
180 {
181 for(Int_t j=0;j<2;j++)
182 fPrior[j] *= (fClassPop[j]>=1) ? (Float_t)numdata/fClassPop[j]:0;
183
184 for(Int_t n=0;n<numdata;n++)
185 fWeight[n]=fPrior[fHadTrue[n]];
186 }
187
188 // create fDataSort
189 CreateDataSort();
190
191 if(!fRanTree)
192 {
193 *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
194 return kFALSE;
195 }
196 fRanTree->SetRules(fGammas->GetColumns());
197 fTreeNo=0;
198
199 return kTRUE;
200}
201
202void MRanForest::InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag)
203{
204 for (Int_t ievt=from;ievt<to;ievt++)
205 {
206 if (jinbag[ievt]>0)
207 continue;
208 fHadEst[ievt] += fRanTree->TreeHad(m, ievt-from);
209 fNTimesOutBag[ievt]++;
210 }
211}
212
213Bool_t MRanForest::GrowForest()
214{
215 if(!gRandom)
216 {
217 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
218 return kFALSE;
219 }
220
221 fTreeNo++;
222
223 // initialize running output
224 if (fTreeNo==1)
225 {
226 *fLog << inf << endl;
227 *fLog << underline; // << "1st col 2nd col" << endl;
228 *fLog << "no. of tree error in % (calulated using oob-data -> overestim. of error)" << endl;
229 }
230
231 const Int_t numdata = GetNumData();
232
233 // bootstrap aggregating (bagging) -> sampling with replacement:
234 //
235 // The integer k is randomly (uniformly) chosen from the set
236 // {0,1,...,fNumData-1}, which is the set of the index numbers of
237 // all events in the training sample
238 TArrayF classpopw(2);
239 TArrayI jinbag(numdata); // Initialization includes filling with 0
240 TArrayF winbag(numdata); // Initialization includes filling with 0
241
242 for (Int_t n=0; n<numdata; n++)
243 {
244 const Int_t k = Int_t(gRandom->Rndm()*numdata);
245
246 classpopw[fHadTrue[k]]+=fWeight[k];
247 winbag[k]+=fWeight[k];
248 jinbag[k]=1;
249 }
250
251 // modifying sorted-data array for in-bag data:
252 //
253 // In bagging procedure ca. 2/3 of all elements in the original
254 // training sample are used to build the in-bag data
255 TArrayI datsortinbag=fDataSort;
256 Int_t ninbag=0;
257
258 ModifyDataSort(datsortinbag, ninbag, jinbag);
259
260 const TMatrix &hadrons=fHadrons->GetM();
261 const TMatrix &gammas =fGammas->GetM();
262
263 // growing single tree
264 fRanTree->GrowTree(hadrons,gammas,fHadTrue,datsortinbag,
265 fDataRang,classpopw,jinbag,winbag);
266
267 // error-estimates from out-of-bag data (oob data):
268 //
269 // For a single tree the events not(!) contained in the bootstrap sample of
270 // this tree can be used to obtain estimates for the classification error of
271 // this tree.
272 // If you take a certain event, it is contained in the oob-data of 1/3 of
273 // the trees (see comment to ModifyData). This means that the classification error
274 // determined from oob-data is underestimated, but can still be taken as upper limit.
275
276 const Int_t numhad = hadrons.GetNrows();
277 InitHadEst(0, numhad, hadrons, jinbag);
278 InitHadEst(numhad, numdata, gammas, jinbag);
279 /*
280 for (Int_t ievt=0;ievt<numhad;ievt++)
281 {
282 if (jinbag[ievt]>0)
283 continue;
284 fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt);
285 fNTimesOutBag[ievt]++;
286 }
287
288 for (Int_t ievt=numhad;ievt<numdata;ievt++)
289 {
290 if (jinbag[ievt]>0)
291 continue;
292 fHadEst[ievt] += fRanTree->TreeHad(gammas, ievt-numhad);
293 fNTimesOutBag[ievt]++;
294 }
295 */
296 Int_t n=0;
297 Double_t ferr=0;
298 for (Int_t ievt=0;ievt<numdata;ievt++)
299 if (fNTimesOutBag[ievt]!=0)
300 {
301 const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt];
302 ferr += val*val;
303 n++;
304 }
305
306 ferr = TMath::Sqrt(ferr/n);
307
308 // give running output
309 *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl;
310
311 // adding tree to forest
312 AddTree();
313
314 return fTreeNo<fNumTrees;
315}
316
317void MRanForest::CreateDataSort()
318{
319 // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData,
320 // which does actually not exist) are sorted from lowest to highest.
321 //
322 //
323 // fHadrons(0,0) ... fHadrons(0,nhad-1) fGammas(0,0) ... fGammas(0,ngam-1)
324 // . . . .
325 // fData(m,n) = . . . .
326 // . . . .
327 // fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1)
328 //
329 //
330 // Then fDataSort(m,n) is the event number in which fData(m,n) occurs.
331 // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest
332 // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging).
333 const Int_t numdata = GetNumData();
334
335 TArrayF v(numdata);
336 TArrayI isort(numdata);
337
338 const TMatrix &hadrons=fHadrons->GetM();
339 const TMatrix &gammas=fGammas->GetM();
340
341 const Int_t numgam = gammas.GetNrows();
342 const Int_t numhad = hadrons.GetNrows();
343
344 for (Int_t j=0;j<numhad;j++)
345 fHadTrue[j]=1;
346
347 for (Int_t j=0;j<numgam;j++)
348 fHadTrue[j+numhad]=0;
349
350 const Int_t dim = GetNumDim();
351 for (Int_t mvar=0;mvar<dim;mvar++)
352 {
353 for(Int_t n=0;n<numhad;n++)
354 {
355 v[n]=hadrons(n,mvar);
356 isort[n]=n;
357 }
358
359 for(Int_t n=0;n<numgam;n++)
360 {
361 v[n+numhad]=gammas(n,mvar);
362 isort[n+numhad]=n;
363 }
364
365 TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);
366
367 // this sorts the v[n] in ascending order. isort[n] is the event number
368 // of that v[n], which is the n-th from the lowest (assume the original
369 // event numbers are 0,1,...).
370
371 for(Int_t n=0;n<numdata-1;n++)
372 {
373 const Int_t n1=isort[n];
374 const Int_t n2=isort[n+1];
375
376 fDataSort[mvar*numdata+n]=n1;
377 if(n==0) fDataRang[mvar*numdata+n1]=0;
378 if(v[n1]<v[n2])
379 {
380 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
381 }else{
382 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
383 }
384 }
385 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
386 }
387}
388
389void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
390{
391 const Int_t numdim=GetNumDim();
392 const Int_t numdata=GetNumData();
393
394 ninbag=0;
395 for (Int_t n=0;n<numdata;n++)
396 if(jinbag[n]==1) ninbag++;
397
398 for(Int_t m=0;m<numdim;m++)
399 {
400 Int_t k=0;
401 Int_t nt=0;
402 for(Int_t n=0;n<numdata;n++)
403 {
404 if(jinbag[datsortinbag[m*numdata+k]]==1)
405 {
406 datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
407 k++;
408 }else{
409 for(Int_t j=1;j<numdata-k;j++)
410 {
411 if(jinbag[datsortinbag[m*numdata+k+j]]==1)
412 {
413 datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
414 k+=j+1;
415 break;
416 }
417 }
418 }
419 nt++;
420 if(nt>=ninbag) break;
421 }
422 }
423}
424
425Bool_t MRanForest::AsciiWrite(ostream &out) const
426{
427 Int_t n=0;
428 MRanTree *tree;
429 TIter forest(fForest);
430
431 while ((tree=(MRanTree*)forest.Next()))
432 {
433 tree->AsciiWrite(out);
434 n++;
435 }
436
437 return n==fNumTrees;
438}
Note: See TracBrowser for help on using the repository browser.