source: trunk/MagicSoft/Mars/mranforest/MRanTree.cc@ 2307

Last change on this file since 2307 was 2307, checked in by tbretz, 21 years ago
*** empty log message ***
File size: 15.5 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
19!
20! Copyright: MAGIC Software Development, 2000-2003
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26// //
27// MRanTree //
28// //
29// ParameterContainer for Tree structure //
30// //
31// //
32/////////////////////////////////////////////////////////////////////////////
33#include "MRanTree.h"
34
35#include <iostream>
36
37#include <TVector.h>
38#include <TMatrix.h>
39#include <TRandom.h>
40
41#include "MDataArray.h"
42
43#include "MLog.h"
44#include "MLogManip.h"
45
46ClassImp(MRanTree);
47
48using namespace std;
49
50// --------------------------------------------------------------------------
51//
52// Default constructor.
53//
54MRanTree::MRanTree(const char *name, const char *title):fNdSize(0), fNumTry(3), fData(NULL)
55{
56
57 fName = name ? name : "MRanTree";
58 fTitle = title ? title : "Storage container for structure of a single tree";
59}
60
61void MRanTree::SetNdSize(Int_t n)
62{
63 // threshold nodesize of terminal nodes, i.e. the training data is splitted
64 // until there is only pure date in the subsets(=terminal nodes) or the
65 // subset size is LE n
66
67 fNdSize=TMath::Max(1,n);//at least 1 event per node
68}
69
70void MRanTree::SetNumTry(Int_t n)
71{
72 // number of trials in random split selection:
73 // choose at least 1 variable to split in
74
75 fNumTry=TMath::Max(1,n);
76}
77
78void MRanTree::GrowTree(const TMatrix &mhad, const TMatrix &mgam,
79 const TArrayI &hadtrue, TArrayI &datasort,
80 const TArrayI &datarang, TArrayF &tclasspop, TArrayI &jinbag,
81 const TArrayF &winbag)
82{
83 // arrays have to be initialized with generous size, so number of total nodes (nrnodes)
84 // is estimated for worst case
85 const Int_t numdim =mhad.GetNcols();
86 const Int_t numdata=winbag.GetSize();
87 const Int_t nrnodes=2*numdata+1;
88
89 // number of events in bootstrap sample
90 Int_t ninbag=0;
91 for (Int_t n=0;n<numdata;n++)
92 if(jinbag[n]==1) ninbag++;
93
94 TArrayI bestsplit(nrnodes);
95 TArrayI bestsplitnext(nrnodes);
96
97 fBestVar.Set(nrnodes);
98 fTreeMap1.Set(nrnodes);
99 fTreeMap2.Set(nrnodes);
100 fBestSplit.Set(nrnodes);
101
102 fTreeMap1.Reset();
103 fTreeMap2.Reset();
104 fBestSplit.Reset();
105
106 fGiniDec.Set(numdim);
107 fGiniDec.Reset();
108
109 // tree growing
110 BuildTree(datasort,datarang,hadtrue,bestsplit,
111 bestsplitnext,tclasspop,winbag,ninbag);
112
113 // post processing, determine cut (or split) values fBestSplit
114 Int_t nhad=mhad.GetNrows();
115
116 for(Int_t k=0; k<nrnodes; k++)
117 {
118 if (GetNodeStatus(k)==-1)
119 continue;
120
121 const Int_t &bsp =bestsplit[k];
122 const Int_t &bspn=bestsplitnext[k];
123 const Int_t &msp =fBestVar[k];
124
125 fBestSplit[k] = bsp<nhad ? mhad(bsp, msp):mgam(bsp-nhad, msp);
126 fBestSplit[k] += bspn<nhad ? mhad(bspn,msp):mgam(bspn-nhad,msp);
127 fBestSplit[k] /= 2;
128 }
129
130 // resizing arrays to save memory
131 fBestVar.Set(fNumNodes);
132 fTreeMap1.Set(fNumNodes);
133 fTreeMap2.Set(fNumNodes);
134 fBestSplit.Set(fNumNodes);
135}
136
137Int_t MRanTree::FindBestSplit(const TArrayI &datasort,const TArrayI &datarang,
138 const TArrayI &hadtrue,Int_t ndstart,Int_t ndend,TArrayF &tclasspop,
139 Int_t &msplit,Float_t &decsplit,Int_t &nbest,
140 const TArrayF &winbag)
141{
142 const Int_t nrnodes = fBestSplit.GetSize();
143 const Int_t numdata = (nrnodes-1)/2;
144 const Int_t mdim = fGiniDec.GetSize();
145
146 // weighted class populations after split
147 TArrayF wc(2);
148 TArrayF wr(2); // right node
149
150 // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...)
151 // split on. decsplit is the decreae in impurity measured by Gini-index.
152 // nsplit is the case number of value of msplit split on,
153 // and nsplitnext is the case number of the next larger value of msplit.
154
155 Int_t nbestvar=0;
156
157 // compute initial values of numerator and denominator of Gini-index,
158 // Gini index= pno/dno
159 Double_t pno=0;
160 Double_t pdo=0;
161 for (Int_t j=0;j<2;j++)
162 {
163 pno+=tclasspop[j]*tclasspop[j];
164 pdo+=tclasspop[j];
165 }
166
167 const Double_t crit0=pno/pdo;
168 Int_t jstat=0;
169
170 // start main loop through variables to find best split,
171 // (Gini-index as criterium crit)
172
173 Double_t critmax=-1.0e20; // FIXME: Replace by a constant from limits.h
174
175 // random split selection, number of trials = fNumTry
176 for(Int_t mt=0;mt<fNumTry;mt++)
177 {
178 const Int_t mvar=Int_t(gRandom->Rndm()*mdim);
179 const Int_t mn = mvar*numdata;
180
181 // Gini index = rrn/rrd+rln/rld
182 Double_t rrn=pno;
183 Double_t rrd=pdo;
184 Double_t rln=0;
185 Double_t rld=0;
186
187 TArrayF wl(2); // left node
188 wr = tclasspop;
189
190 Double_t critvar=-1.0e20;
191
192 for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
193 {
194 const Int_t &nc=datasort[mn+nsp];
195 const Int_t &k=hadtrue[nc];
196
197 const Float_t &u=winbag[nc];
198
199 rln+=u*(2*wl[k]+u);
200 rrn+=u*(-2*wr[k]+u);
201 rld+=u;
202 rrd-=u;
203
204 wl[k]+=u;
205 wr[k]-=u;
206
207 if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]])
208 continue;
209 if (TMath::Min(rrd,rld)<=1.0e-5)
210 continue;
211
212 const Double_t crit=(rln/rld)+(rrn/rrd);
213 if (crit<=critvar)
214 continue;
215
216 nbestvar=nsp;
217 critvar=crit;
218 }
219
220 if (critvar<=critmax)
221 continue;
222
223 msplit=mvar;
224 nbest=nbestvar;
225 critmax=critvar;
226 }
227
228 decsplit=critmax-crit0;
229
230 return critmax<-1.0e10 ? 1 : jstat;
231}
232
233void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart,
234 Int_t ndend,TArrayI &idmove,TArrayI &ncase,Int_t msplit,
235 Int_t nbest,Int_t &ndendl)
236{
237 // This is the heart of the BuildTree construction. Based on the best split
238 // the data in the part of datasort corresponding to the current node is moved to the
239 // left if it belongs to the left child and right if it belongs to the right child-node.
240 const Int_t numdata = ncase.GetSize();
241 const Int_t mdim = fGiniDec.GetSize();
242
243 TArrayI tdatasort(numdata);
244
245 // compute idmove = indicator of case nos. going left
246
247 for (Int_t nsp=ndstart;nsp<=ndend;nsp++)
248 {
249 const Int_t &nc=datasort[msplit*numdata+nsp];
250 idmove[nc]= nsp<=nbest?1:0;
251 }
252 ndendl=nbest;
253
254 // shift case. nos. right and left for numerical variables.
255
256 for(Int_t msh=0;msh<mdim;msh++)
257 {
258 Int_t k=ndstart-1;
259 for (Int_t n=ndstart;n<=ndend;n++)
260 {
261 const Int_t &ih=datasort[msh*numdata+n];
262 if (idmove[ih]==1)
263 tdatasort[++k]=datasort[msh*numdata+n];
264 }
265
266 for (Int_t n=ndstart;n<=ndend;n++)
267 {
268 const Int_t &ih=datasort[msh*numdata+n];
269 if (idmove[ih]==0)
270 tdatasort[++k]=datasort[msh*numdata+n];
271 }
272
273 for(Int_t m=ndstart;m<=ndend;m++)
274 datasort[msh*numdata+m]=tdatasort[m];
275 }
276
277 // compute case nos. for right and left nodes.
278
279 for(Int_t n=ndstart;n<=ndend;n++)
280 ncase[n]=datasort[msplit*numdata+n];
281}
282
283void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang,
284 const TArrayI &hadtrue, TArrayI &bestsplit,
285 TArrayI &bestsplitnext, TArrayF &tclasspop,
286 const TArrayF &winbag, Int_t ninbag)
287{
288 // Buildtree consists of repeated calls to two void functions, FindBestSplit and MoveData.
289 // Findbestsplit does just that--it finds the best split of the current node.
290 // MoveData moves the data in the split node right and left so that the data
291 // corresponding to each child node is contiguous.
292 //
293 // buildtree bookkeeping:
294 // ncur is the total number of nodes to date. nodestatus(k)=1 if the kth node has been split.
295 // nodestatus(k)=2 if the node exists but has not yet been split, and =-1 if the node is
296 // terminal. A node is terminal if its size is below a threshold value, or if it is all
297 // one class, or if all the data-values are equal. If the current node k is split, then its
298 // children are numbered ncur+1 (left), and ncur+2(right), ncur increases to ncur+2 and
299 // the next node to be split is numbered k+1. When no more nodes can be split, buildtree
300 // returns.
301 const Int_t mdim = fGiniDec.GetSize();
302 const Int_t nrnodes = fBestSplit.GetSize();
303 const Int_t numdata = (nrnodes-1)/2;
304
305 TArrayI nodepop(nrnodes);
306 TArrayI nodestart(nrnodes);
307 TArrayI parent(nrnodes);
308
309 TArrayI ncase(numdata);
310 TArrayI idmove(numdata);
311 TArrayI iv(mdim);
312
313 TArrayF classpop(nrnodes*2);
314 TArrayI nodestatus(nrnodes);
315
316 for (Int_t j=0;j<2;j++)
317 classpop[j*nrnodes+0]=tclasspop[j];
318
319 Int_t ncur=0;
320 nodepop[0]=ninbag;
321 nodestatus[0]=2;
322
323 // start main loop
324 for (Int_t kbuild=0; kbuild<nrnodes; kbuild++)
325 {
326 if (kbuild>ncur) break;
327 if (nodestatus[kbuild]!=2) continue;
328
329 // initialize for next call to FindBestSplit
330
331 const Int_t ndstart=nodestart[kbuild];
332 const Int_t ndend=ndstart+nodepop[kbuild]-1;
333 for (Int_t j=0;j<2;j++)
334 tclasspop[j]=classpop[j*nrnodes+kbuild];
335
336 Int_t msplit, nbest;
337 Float_t decsplit=0;
338 const Int_t jstat=FindBestSplit(datasort,datarang,hadtrue,
339 ndstart,ndend,tclasspop,msplit,
340 decsplit,nbest,winbag);
341
342 if (jstat==1)
343 {
344 nodestatus[kbuild]=-1;
345 continue;
346 }
347
348 fBestVar[kbuild]=msplit;
349 fGiniDec[msplit]+=decsplit;
350
351 bestsplit[kbuild]=datasort[msplit*numdata+nbest];
352 bestsplitnext[kbuild]=datasort[msplit*numdata+nbest+1];
353
354 Int_t ndendl;
355 MoveData(datasort,ndstart,ndend,idmove,ncase,
356 msplit,nbest,ndendl);
357
358 // leftnode no.= ncur+1, rightnode no. = ncur+2.
359
360 nodepop[ncur+1]=ndendl-ndstart+1;
361 nodepop[ncur+2]=ndend-ndendl;
362 nodestart[ncur+1]=ndstart;
363 nodestart[ncur+2]=ndendl+1;
364
365 // find class populations in both nodes
366
367 for (Int_t n=ndstart;n<=ndendl;n++)
368 {
369 const Int_t &nc=ncase[n];
370 const Int_t &j=hadtrue[nc];
371 classpop[j*nrnodes+ncur+1]+=winbag[nc];
372 }
373
374 for (Int_t n=ndendl+1;n<=ndend;n++)
375 {
376 const Int_t &nc=ncase[n];
377 const Int_t &j=hadtrue[nc];
378 classpop[j*nrnodes+ncur+2]+=winbag[nc];
379 }
380
381 // check on nodestatus
382
383 nodestatus[ncur+1]=2;
384 nodestatus[ncur+2]=2;
385 if (nodepop[ncur+1]<=fNdSize) nodestatus[ncur+1]=-1;
386 if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1;
387
388 Double_t popt1=0;
389 Double_t popt2=0;
390 for (Int_t j=0;j<2;j++)
391 {
392 popt1+=classpop[j*nrnodes+ncur+1];
393 popt2+=classpop[j*nrnodes+ncur+2];
394 }
395
396 for (Int_t j=0;j<2;j++)
397 {
398 if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
399 if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
400 }
401
402 fTreeMap1[kbuild]=ncur+1;
403 fTreeMap2[kbuild]=ncur+2;
404 parent[ncur+1]=kbuild;
405 parent[ncur+2]=kbuild;
406 nodestatus[kbuild]=1;
407 ncur+=2;
408 if (ncur>=nrnodes) break;
409 }
410
411 // determine number of nodes
412 fNumNodes=nrnodes;
413 for (Int_t k=nrnodes-1;k>=0;k--)
414 {
415 if (nodestatus[k]==0) fNumNodes-=1;
416 if (nodestatus[k]==2) nodestatus[k]=-1;
417 }
418
419 fNumEndNodes=0;
420 for (Int_t kn=0;kn<fNumNodes;kn++)
421 if(nodestatus[kn]==-1)
422 {
423 fNumEndNodes++;
424 Double_t pp=0;
425 for (Int_t j=0;j<2;j++)
426 {
427 if(classpop[j*nrnodes+kn]>pp)
428 {
429 // class + status of node kn coded into fBestVar[kn]
430 fBestVar[kn]=j-2;
431 pp=classpop[j*nrnodes+kn];
432 }
433 }
434 fBestSplit[kn] =classpop[1*nrnodes+kn];
435 fBestSplit[kn]/=(classpop[0*nrnodes+kn]+classpop[1*nrnodes+kn]);
436 }
437}
438
439void MRanTree::SetRules(MDataArray *rules)
440{
441 fData=rules;
442}
443
444Double_t MRanTree::TreeHad(const TVector &event)
445{
446 Int_t kt=0;
447 // to optimize on storage space node status and node class
448 // are coded into fBestVar:
449 // status of node kt = TMath::Sign(1,fBestVar[kt])
450 // class of node kt = fBestVar[kt]+2 (class defined by larger
451 // node population, actually not used)
452 // hadronness assigned to node kt = fBestSplit[kt]
453
454 for (Int_t k=0;k<fNumNodes;k++)
455 {
456 if (fBestVar[kt]<0)
457 break;
458
459 const Int_t m=fBestVar[kt];
460 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
461 }
462
463 return fBestSplit[kt];
464}
465
466Double_t MRanTree::TreeHad(const TMatrixRow &event)
467{
468 Int_t kt=0;
469 // to optimize on storage space node status and node class
470 // are coded into fBestVar:
471 // status of node kt = TMath::Sign(1,fBestVar[kt])
472 // class of node kt = fBestVar[kt]+2 (class defined by larger
473 // node population, actually not used)
474 // hadronness assigned to node kt = fBestSplit[kt]
475
476 for (Int_t k=0;k<fNumNodes;k++)
477 {
478 if (fBestVar[kt]<0)
479 break;
480
481 const Int_t m=fBestVar[kt];
482 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
483 }
484
485 return fBestSplit[kt];
486}
487
488Double_t MRanTree::TreeHad(const TMatrix &m, Int_t ievt)
489{
490 return TreeHad(TMatrixRow(m, ievt));
491}
492
493Double_t MRanTree::TreeHad()
494{
495 TVector event;
496 *fData >> event;
497
498 return TreeHad(event);
499}
500
501Bool_t MRanTree::AsciiWrite(ostream &out) const
502{
503 TString str;
504 Int_t k;
505
506 out.width(5);out<<fNumNodes<<endl;
507
508 for (k=0;k<fNumNodes;k++)
509 {
510 str=Form("%f",GetBestSplit(k));
511
512 out.width(5); out << k;
513 out.width(5); out << GetNodeStatus(k);
514 out.width(5); out << GetTreeMap1(k);
515 out.width(5); out << GetTreeMap2(k);
516 out.width(5); out << GetBestVar(k);
517 out.width(15); out << str<<endl;
518 out.width(5); out << GetNodeClass(k);
519 }
520 out<<endl;
521
522 return k==fNumNodes;
523}
Note: See TracBrowser for help on using the repository browser.