source: trunk/MagicSoft/Mars/mranforest/MRanForest.cc@ 7360

Last change on this file since 7360 was 7170, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 11.9 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
19!
20! Copyright: MAGIC Software Development, 2000-2003
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26//
27// MRanForest
28//
29// ParameterContainer for Forest structure
30//
// A random forest is grown by calling GrowForest() repeatedly, once per tree.
// SetupGrow() must be called beforehand in order to initialize the arrays and
// do some preprocessing.
34// GrowForest() provides the training data for a single tree (bootstrap
35// aggregate procedure)
36//
// Essentially two random elements serve to provide a "random" forest,
// namely bootstrap aggregating (which is done in GrowForest()) and random
// split selection (which is done in MRanTree::GrowTree()).
40//
41// Version 2:
42// + fUserVal
43//
44/////////////////////////////////////////////////////////////////////////////
45#include "MRanForest.h"
46
47#include <TMatrix.h>
48#include <TRandom3.h>
49
50#include "MHMatrix.h"
51#include "MRanTree.h"
52
53#include "MLog.h"
54#include "MLogManip.h"
55
56ClassImp(MRanForest);
57
58using namespace std;
59
60// --------------------------------------------------------------------------
61//
62// Default constructor.
63//
64MRanForest::MRanForest(const char *name, const char *title) : fNumTrees(100), fRanTree(NULL),fUsePriors(kFALSE), fUserVal(-1)
65{
66 fName = name ? name : "MRanForest";
67 fTitle = title ? title : "Storage container for Random Forest";
68
69 fForest=new TObjArray();
70 fForest->SetOwner(kTRUE);
71}
72
73// --------------------------------------------------------------------------
74//
75// Destructor.
76//
77MRanForest::~MRanForest()
78{
79 delete fForest;
80}
81
82void MRanForest::SetNumTrees(Int_t n)
83{
84 //at least 1 tree
85 fNumTrees=TMath::Max(n,1);
86 fTreeHad.Set(fNumTrees);
87 fTreeHad.Reset();
88}
89
90void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam)
91{
92 const Float_t sum=prior_gam+prior_had;
93
94 prior_gam/=sum;
95 prior_had/=sum;
96
97 fPrior[0]=prior_had;
98 fPrior[1]=prior_gam;
99
100 fUsePriors=kTRUE;
101}
102
103Int_t MRanForest::GetNumDim() const
104{
105 return fGammas ? fGammas->GetM().GetNcols() : 0;
106}
107
108
109Double_t MRanForest::CalcHadroness(const TVector &event)
110{
111 Double_t hadroness=0;
112 Int_t ntree=0;
113
114 TIter Next(fForest);
115
116 MRanTree *tree;
117 while ((tree=(MRanTree*)Next()))
118 {
119 fTreeHad[ntree]=tree->TreeHad(event);
120 hadroness+=fTreeHad[ntree];
121 ntree++;
122 }
123 return hadroness/ntree;
124}
125
126Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
127{
128 if (rantree)
129 fRanTree=rantree;
130 if (!fRanTree)
131 return kFALSE;
132
133 fForest->Add((MRanTree*)fRanTree->Clone());
134
135 return kTRUE;
136}
137
138Int_t MRanForest::GetNumData() const
139{
140 return fHadrons && fGammas ? fHadrons->GetM().GetNrows()+fGammas->GetM().GetNrows() : 0;
141}
142
143Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam)
144{
145 // pointer to training data
146 fHadrons=mhad;
147 fGammas=mgam;
148
149 // determine data entries and dimension of Hillas-parameter space
150 //fNumHad=fHadrons->GetM().GetNrows();
151 //fNumGam=fGammas->GetM().GetNrows();
152
153 const Int_t dim = GetNumDim();
154
155 if (dim!=fGammas->GetM().GetNcols())
156 return kFALSE;
157
158 const Int_t numdata = GetNumData();
159
160 // allocating and initializing arrays
161 fHadTrue.Set(numdata);
162 fHadTrue.Reset();
163 fHadEst.Set(numdata);
164
165 fPrior.Set(2);
166 fClassPop.Set(2);
167 fWeight.Set(numdata);
168 fNTimesOutBag.Set(numdata);
169 fNTimesOutBag.Reset();
170
171 fDataSort.Set(dim*numdata);
172 fDataRang.Set(dim*numdata);
173
174 // calculating class populations (= no. of gammas and hadrons)
175 fClassPop.Reset();
176 for(Int_t n=0;n<numdata;n++)
177 fClassPop[fHadTrue[n]]++;
178
179 // setting weights taking into account priors
180 if (!fUsePriors)
181 fWeight.Reset(1.);
182 else
183 {
184 for(Int_t j=0;j<2;j++)
185 fPrior[j] *= (fClassPop[j]>=1) ? (Float_t)numdata/fClassPop[j]:0;
186
187 for(Int_t n=0;n<numdata;n++)
188 fWeight[n]=fPrior[fHadTrue[n]];
189 }
190
191 // create fDataSort
192 CreateDataSort();
193
194 if(!fRanTree)
195 {
196 *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
197 return kFALSE;
198 }
199 fRanTree->SetRules(fGammas->GetColumns());
200 fTreeNo=0;
201
202 return kTRUE;
203}
204
205void MRanForest::InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag)
206{
207 for (Int_t ievt=from;ievt<to;ievt++)
208 {
209 if (jinbag[ievt]>0)
210 continue;
211 fHadEst[ievt] += fRanTree->TreeHad(m, ievt-from);
212 fNTimesOutBag[ievt]++;
213 }
214}
215
216Bool_t MRanForest::GrowForest()
217{
218 if(!gRandom)
219 {
220 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
221 return kFALSE;
222 }
223
224 fTreeNo++;
225
226 // initialize running output
227 if (fTreeNo==1)
228 {
229 *fLog << inf << endl;
230 *fLog << underline; // << "1st col 2nd col" << endl;
231 *fLog << "no. of tree error in % (calulated using oob-data -> overestim. of error)" << endl;
232 }
233
234 const Int_t numdata = GetNumData();
235
236 // bootstrap aggregating (bagging) -> sampling with replacement:
237 //
238 // The integer k is randomly (uniformly) chosen from the set
239 // {0,1,...,fNumData-1}, which is the set of the index numbers of
240 // all events in the training sample
241 TArrayF classpopw(2);
242 TArrayI jinbag(numdata); // Initialization includes filling with 0
243 TArrayF winbag(numdata); // Initialization includes filling with 0
244
245 for (Int_t n=0; n<numdata; n++)
246 {
247 const Int_t k = Int_t(gRandom->Rndm()*numdata);
248
249 classpopw[fHadTrue[k]]+=fWeight[k];
250 winbag[k]+=fWeight[k];
251 jinbag[k]=1;
252 }
253
254 // modifying sorted-data array for in-bag data:
255 //
256 // In bagging procedure ca. 2/3 of all elements in the original
257 // training sample are used to build the in-bag data
258 TArrayI datsortinbag=fDataSort;
259 Int_t ninbag=0;
260
261 ModifyDataSort(datsortinbag, ninbag, jinbag);
262
263 const TMatrix &hadrons=fHadrons->GetM();
264 const TMatrix &gammas =fGammas->GetM();
265
266 // growing single tree
267 fRanTree->GrowTree(hadrons,gammas,fHadTrue,datsortinbag,
268 fDataRang,classpopw,jinbag,winbag);
269
270 // error-estimates from out-of-bag data (oob data):
271 //
272 // For a single tree the events not(!) contained in the bootstrap sample of
273 // this tree can be used to obtain estimates for the classification error of
274 // this tree.
275 // If you take a certain event, it is contained in the oob-data of 1/3 of
276 // the trees (see comment to ModifyData). This means that the classification error
277 // determined from oob-data is underestimated, but can still be taken as upper limit.
278
279 const Int_t numhad = hadrons.GetNrows();
280 InitHadEst(0, numhad, hadrons, jinbag);
281 InitHadEst(numhad, numdata, gammas, jinbag);
282 /*
283 for (Int_t ievt=0;ievt<numhad;ievt++)
284 {
285 if (jinbag[ievt]>0)
286 continue;
287 fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt);
288 fNTimesOutBag[ievt]++;
289 }
290
291 for (Int_t ievt=numhad;ievt<numdata;ievt++)
292 {
293 if (jinbag[ievt]>0)
294 continue;
295 fHadEst[ievt] += fRanTree->TreeHad(gammas, ievt-numhad);
296 fNTimesOutBag[ievt]++;
297 }
298 */
299 Int_t n=0;
300 Double_t ferr=0;
301 for (Int_t ievt=0;ievt<numdata;ievt++)
302 if (fNTimesOutBag[ievt]!=0)
303 {
304 const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt];
305 ferr += val*val;
306 n++;
307 }
308
309 ferr = TMath::Sqrt(ferr/n);
310
311 // give running output
312 *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl;
313
314 // adding tree to forest
315 AddTree();
316
317 return fTreeNo<fNumTrees;
318}
319
320void MRanForest::CreateDataSort()
321{
322 // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData,
323 // which does actually not exist) are sorted from lowest to highest.
324 //
325 //
326 // fHadrons(0,0) ... fHadrons(0,nhad-1) fGammas(0,0) ... fGammas(0,ngam-1)
327 // . . . .
328 // fData(m,n) = . . . .
329 // . . . .
330 // fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1)
331 //
332 //
333 // Then fDataSort(m,n) is the event number in which fData(m,n) occurs.
334 // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest
335 // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging).
336 const Int_t numdata = GetNumData();
337
338 TArrayF v(numdata);
339 TArrayI isort(numdata);
340
341 const TMatrix &hadrons=fHadrons->GetM();
342 const TMatrix &gammas=fGammas->GetM();
343
344 const Int_t numgam = gammas.GetNrows();
345 const Int_t numhad = hadrons.GetNrows();
346
347 for (Int_t j=0;j<numhad;j++)
348 fHadTrue[j]=1;
349
350 for (Int_t j=0;j<numgam;j++)
351 fHadTrue[j+numhad]=0;
352
353 const Int_t dim = GetNumDim();
354 for (Int_t mvar=0;mvar<dim;mvar++)
355 {
356 for(Int_t n=0;n<numhad;n++)
357 {
358 v[n]=hadrons(n,mvar);
359 isort[n]=n;
360 }
361
362 for(Int_t n=0;n<numgam;n++)
363 {
364 v[n+numhad]=gammas(n,mvar);
365 isort[n+numhad]=n;
366 }
367
368 TMath::Sort(numdata,v.GetArray(),isort.GetArray(),kFALSE);
369
370 // this sorts the v[n] in ascending order. isort[n] is the event number
371 // of that v[n], which is the n-th from the lowest (assume the original
372 // event numbers are 0,1,...).
373
374 for(Int_t n=0;n<numdata-1;n++)
375 {
376 const Int_t n1=isort[n];
377 const Int_t n2=isort[n+1];
378
379 fDataSort[mvar*numdata+n]=n1;
380 if(n==0) fDataRang[mvar*numdata+n1]=0;
381 if(v[n1]<v[n2])
382 {
383 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1]+1;
384 }else{
385 fDataRang[mvar*numdata+n2]=fDataRang[mvar*numdata+n1];
386 }
387 }
388 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
389 }
390}
391
392void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
393{
394 const Int_t numdim=GetNumDim();
395 const Int_t numdata=GetNumData();
396
397 ninbag=0;
398 for (Int_t n=0;n<numdata;n++)
399 if(jinbag[n]==1) ninbag++;
400
401 for(Int_t m=0;m<numdim;m++)
402 {
403 Int_t k=0;
404 Int_t nt=0;
405 for(Int_t n=0;n<numdata;n++)
406 {
407 if(jinbag[datsortinbag[m*numdata+k]]==1)
408 {
409 datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k];
410 k++;
411 }else{
412 for(Int_t j=1;j<numdata-k;j++)
413 {
414 if(jinbag[datsortinbag[m*numdata+k+j]]==1)
415 {
416 datsortinbag[m*numdata+nt]=datsortinbag[m*numdata+k+j];
417 k+=j+1;
418 break;
419 }
420 }
421 }
422 nt++;
423 if(nt>=ninbag) break;
424 }
425 }
426}
427
428Bool_t MRanForest::AsciiWrite(ostream &out) const
429{
430 Int_t n=0;
431 MRanTree *tree;
432 TIter forest(fForest);
433
434 while ((tree=(MRanTree*)forest.Next()))
435 {
436 tree->AsciiWrite(out);
437 n++;
438 }
439
440 return n==fNumTrees;
441}
Note: See TracBrowser for help on using the repository browser.