source: trunk/MagicSoft/Mars/mranforest/MRanTree.cc@ 2296

Last change on this file since 2296 was 2296, checked in by tbretz, 21 years ago
*** empty log message ***
File size: 15.8 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
19!
20! Copyright: MAGIC Software Development, 2000-2003
21!
22!
23\* ======================================================================== */
24
25/////////////////////////////////////////////////////////////////////////////
26// //
27// MRanTree //
28// //
29// ParameterContainer for Tree structure //
30// //
31// //
32/////////////////////////////////////////////////////////////////////////////
33#include "MRanTree.h"
34
35#include <iostream>
36
37#include <TVector.h>
38#include <TMatrix.h>
39#include <TRandom.h>
40
41#include "MDataArray.h"
42
43#include "MLog.h"
44#include "MLogManip.h"
45
46ClassImp(MRanTree);
47
48using namespace std;
49
50// --------------------------------------------------------------------------
51//
52// Default constructor.
53//
54MRanTree::MRanTree(const char *name, const char *title):fNdSize(0), fNumTry(3), fData(NULL)
55{
56
57 fName = name ? name : "MRanTree";
58 fTitle = title ? title : "Storage container for structure of a single tree";
59}
60
61void MRanTree::SetNdSize(Int_t n)
62{
63 // threshold nodesize of terminal nodes, i.e. the training data is splitted
64 // until there is only pure date in the subsets(=terminal nodes) or the
65 // subset size is LE n
66
67 fNdSize=TMath::Max(1,n);//at least 1 event per node
68}
69
70void MRanTree::SetNumTry(Int_t n)
71{
72 // number of trials in random split selection:
73 // choose at least 1 variable to split in
74
75 fNumTry=TMath::Max(1,n);
76}
77
78void MRanTree::GrowTree(const TMatrix &mhad, const TMatrix &mgam,Int_t numdata, Int_t numdim,TArrayI &hadtrue,
79 TArrayI &datasort,TArrayI &datarang,TArrayF &tclasspop,TArrayI &jinbag,
80 TArrayF &winbag,TArrayF &weight)
81{
82 // arrays have to be initialized with generous size, so number of total nodes (nrnodes)
83 // is estimated for worst case
84 Int_t nrnodes=2*numdata+1;
85
86 // number of events in bootstrap sample
87 Int_t ninbag=0;
88 for (Int_t n=0;n<numdata;n++)
89 if(jinbag[n]==1) ninbag++;
90
91 TArrayI bestsplit(nrnodes);
92 TArrayI bestsplitnext(nrnodes);
93 TArrayI nodepop(nrnodes);
94 TArrayI parent(nrnodes);
95 TArrayI nodex(numdata);
96 TArrayI nodestart(nrnodes);
97
98 TArrayI ncase(numdata);
99 TArrayI iv(numdim);
100 TArrayI idmove(numdata);
101
102 fBestVar.Set(nrnodes);
103 fTreeMap1.Set(nrnodes);
104 fTreeMap2.Set(nrnodes);
105 fBestSplit.Set(nrnodes);
106
107 fTreeMap1.Reset();
108 fTreeMap2.Reset();
109 fBestSplit.Reset();
110
111 fGiniDec.Set(numdim);
112 fGiniDec.Reset();
113
114 // tree growing
115 BuildTree(datasort,datarang,hadtrue,numdim,numdata,bestsplit,
116 bestsplitnext,nodepop,nodestart,tclasspop,nrnodes,
117 idmove,ncase,parent,jinbag,iv,winbag,ninbag);
118
119 // post processing, determine cut (or split) values fBestSplit
120 Int_t nhad=mhad.GetNrows();
121
122 for(Int_t k=0; k<nrnodes; k++)
123 {
124 if (GetNodeStatus(k)==-1)
125 continue;
126
127 const Int_t &bsp =bestsplit[k];
128 const Int_t &bspn=bestsplitnext[k];
129 const Int_t &msp =fBestVar[k];
130
131 fBestSplit[k] = bsp<nhad ? mhad(bsp, msp):mgam(bsp-nhad, msp);
132 fBestSplit[k] += bspn<nhad ? mhad(bspn,msp):mgam(bspn-nhad,msp);
133 fBestSplit[k] /= 2;
134 }
135
136 // resizing arrays to save memory
137 fBestVar.Set(fNumNodes);
138 fTreeMap1.Set(fNumNodes);
139 fTreeMap2.Set(fNumNodes);
140 fBestSplit.Set(fNumNodes);
141}
142
143Int_t MRanTree::FindBestSplit(TArrayI &datasort,TArrayI &datarang,TArrayI &hadtrue,Int_t mdim,
144 Int_t numdata,Int_t ndstart,Int_t ndend,TArrayF &tclasspop,
145 Int_t &msplit,Float_t &decsplit,Int_t &nbest,TArrayI &ncase,
146 TArrayI &jinbag,TArrayI &iv,TArrayF &winbag,Int_t kbuild)
147{
148 if(!gRandom)
149 {
150 *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
151 return kFALSE;
152 }
153
154 // weighted class populations after split
155 TArrayF wc(2);
156 TArrayF wr(2); // right node
157
158 // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...)
159 // split on. decsplit is the decreae in impurity measured by Gini-index.
160 // nsplit is the case number of value of msplit split on,
161 // and nsplitnext is the case number of the next larger value of msplit.
162
163 Int_t nc,nbestvar=0,k;
164 Float_t crit;
165 Float_t rrn, rrd, rln, rld, u;
166
167 // compute initial values of numerator and denominator of Gini-index,
168 // Gini index= pno/dno
169 Float_t pno=0;
170 Float_t pdo=0;
171
172 for (Int_t j=0;j<2;j++)
173 {
174 pno+=tclasspop[j]*tclasspop[j];
175 pdo+=tclasspop[j];
176 }
177
178 const Double_t crit0=pno/pdo;
179 Int_t jstat=0;
180
181 // start main loop through variables to find best split,
182 // (Gini-index as criterium crit)
183
184 Double_t critmax=-1.0e20; // FIXME: Replace by a constant from limits.h
185
186 // random split selection, number of trials = fNumTry
187 for(Int_t mt=0;mt<fNumTry;mt++)
188 {
189 Int_t mvar=Int_t(gRandom->Rndm()*mdim);
190
191 // Gini index = rrn/rrd+rln/rld
192 rrn=pno;
193 rrd=pdo;
194 rln=0;
195 rld=0;
196
197 TArrayF wl(2); // left node
198 wr = tclasspop;
199
200 Double_t critvar=-1.0e20;
201
202 for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
203 {
204 nc=datasort[mvar*numdata+nsp];
205
206 u=winbag[nc];
207 k=hadtrue[nc];
208
209 rln=rln+u*(2*wl[k]+u);
210 rrn=rrn+u*(-2*wr[k]+u);
211 rld=rld+u;
212 rrd=rrd-u;
213
214 wl[k]=wl[k]+u;
215 wr[k]=wr[k]-u;
216
217 if (datarang[mvar*numdata+nc]<datarang[mvar*numdata+datasort[mvar*numdata+nsp+1]])
218 {
219 if (TMath::Min(rrd,rld)>1.0e-5)
220 {
221 crit=(rln/rld)+(rrn/rrd);
222 if (crit>critvar)
223 {
224 nbestvar=nsp;
225 critvar=crit;
226 }
227 }
228 }
229 }
230
231 if (critvar>critmax) {
232 msplit=mvar;
233 nbest=nbestvar;
234 critmax=critvar;
235 }
236 }
237
238 decsplit=critmax-crit0;
239 if (critmax<-1.0e10) jstat=1;
240
241 return jstat;
242}
243
244void MRanTree::MoveData(TArrayI &datasort,Int_t mdim,Int_t numdata,Int_t ndstart,
245 Int_t ndend,TArrayI &idmove,TArrayI &ncase,Int_t msplit,
246 Int_t nbest,Int_t &ndendl)
247{
248 // This is the heart of the BuildTree construction. Based on the best split
249 // the data in the part of datasort corresponding to the current node is moved to the
250 // left if it belongs to the left child and right if it belongs to the right child-node.
251
252 Int_t nc,k,ih;
253 TArrayI tdatasort(numdata);
254
255 // compute idmove = indicator of case nos. going left
256
257 for (Int_t nsp=ndstart;nsp<=nbest;nsp++)
258 {
259 nc=datasort[msplit*numdata+nsp];
260 idmove[nc]=1;
261 }
262 for (Int_t nsp=nbest+1;nsp<=ndend;nsp++)
263 {
264 nc=datasort[msplit*numdata+nsp];
265 idmove[nc]=0;
266 }
267 ndendl=nbest;
268
269 // shift case. nos. right and left for numerical variables.
270
271 for(Int_t msh=0;msh<mdim;msh++)
272 {
273 k=ndstart-1;
274 for (Int_t n=ndstart;n<=ndend;n++)
275 {
276 ih=datasort[msh*numdata+n];
277 if (idmove[ih]==1) {
278 k++;
279 tdatasort[k]=datasort[msh*numdata+n];
280 }
281 }
282
283 for (Int_t n=ndstart;n<=ndend;n++)
284 {
285 ih=datasort[msh*numdata+n];
286 if (idmove[ih]==0){
287 k++;
288 tdatasort[k]=datasort[msh*numdata+n];
289 }
290 }
291 for(Int_t k=ndstart;k<=ndend;k++)
292 datasort[msh*numdata+k]=tdatasort[k];
293 }
294
295 // compute case nos. for right and left nodes.
296
297 for(Int_t n=ndstart;n<=ndend;n++)
298 ncase[n]=datasort[msplit*numdata+n];
299}
300
301void MRanTree::BuildTree(TArrayI &datasort,TArrayI &datarang,TArrayI &hadtrue,Int_t mdim,
302 Int_t numdata,TArrayI &bestsplit,TArrayI &bestsplitnext,
303 TArrayI &nodepop,TArrayI &nodestart,TArrayF &tclasspop,
304 Int_t nrnodes,TArrayI &idmove,TArrayI &ncase,TArrayI &parent,
305 TArrayI &jinbag,TArrayI &iv,TArrayF &winbag,Int_t ninbag)
306{
307 // Buildtree consists of repeated calls to two void functions, FindBestSplit and MoveData.
308 // Findbestsplit does just that--it finds the best split of the current node.
309 // MoveData moves the data in the split node right and left so that the data
310 // corresponding to each child node is contiguous.
311 //
312 // buildtree bookkeeping:
313 // ncur is the total number of nodes to date. nodestatus(k)=1 if the kth node has been split.
314 // nodestatus(k)=2 if the node exists but has not yet been split, and =-1 if the node is
315 // terminal. A node is terminal if its size is below a threshold value, or if it is all
316 // one class, or if all the data-values are equal. If the current node k is split, then its
317 // children are numbered ncur+1 (left), and ncur+2(right), ncur increases to ncur+2 and
318 // the next node to be split is numbered k+1. When no more nodes can be split, buildtree
319 // returns.
320
321 Int_t msplit, nbest, ndendl;
322 Float_t decsplit=0;
323 TArrayF classpop(2*nrnodes);
324 TArrayI nodestatus(nrnodes);
325
326 nodestart.Reset();
327 nodepop.Reset();
328
329 for (Int_t j=0;j<2;j++)
330 classpop[j*nrnodes+0]=tclasspop[j];
331
332 Int_t ncur=0;
333 nodestart[0]=0;
334 nodepop[0]=ninbag;
335 nodestatus[0]=2;
336
337 // start main loop
338 for (Int_t kbuild=0;kbuild<nrnodes;kbuild++)
339 {
340 if (kbuild>ncur) break;
341 if (nodestatus[kbuild]!=2) continue;
342
343 // initialize for next call to FindBestSplit
344
345 const Int_t ndstart=nodestart[kbuild];
346 const Int_t ndend=ndstart+nodepop[kbuild]-1;
347 for (Int_t j=0;j<2;j++)
348 tclasspop[j]=classpop[j*nrnodes+kbuild];
349
350 const Int_t jstat=FindBestSplit(datasort,datarang,hadtrue,mdim,numdata,
351 ndstart,ndend,tclasspop,msplit,decsplit,
352 nbest,ncase,jinbag,iv,winbag,kbuild);
353
354 if(jstat==1) {
355 nodestatus[kbuild]=-1;
356 continue;
357 }else{
358 fBestVar[kbuild]=msplit;
359 fGiniDec[msplit]+=decsplit;
360
361 bestsplit[kbuild]=datasort[msplit*numdata+nbest];
362 bestsplitnext[kbuild]=datasort[msplit*numdata+nbest+1];
363 }
364
365 MoveData(datasort,mdim,numdata,ndstart,ndend,idmove,ncase,
366 msplit,nbest,ndendl);
367
368 // leftnode no.= ncur+1, rightnode no. = ncur+2.
369
370 nodepop[ncur+1]=ndendl-ndstart+1;
371 nodepop[ncur+2]=ndend-ndendl;
372 nodestart[ncur+1]=ndstart;
373 nodestart[ncur+2]=ndendl+1;
374
375 // find class populations in both nodes
376
377 for (Int_t n=ndstart;n<=ndendl;n++)
378 {
379 const Int_t nc=ncase[n];
380 const Int_t j=hadtrue[nc];
381 classpop[j*nrnodes+ncur+1]+=winbag[nc];
382 }
383
384 for (Int_t n=ndendl+1;n<=ndend;n++)
385 {
386 const Int_t nc=ncase[n];
387 const Int_t j=hadtrue[nc];
388 classpop[j*nrnodes+ncur+2]+=winbag[nc];
389 }
390
391 // check on nodestatus
392
393 nodestatus[ncur+1]=2;
394 nodestatus[ncur+2]=2;
395 if (nodepop[ncur+1]<=fNdSize) nodestatus[ncur+1]=-1;
396 if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1;
397
398 Double_t popt1=0;
399 Double_t popt2=0;
400 for (Int_t j=0;j<2;j++)
401 {
402 popt1+=classpop[j*nrnodes+ncur+1];
403 popt2+=classpop[j*nrnodes+ncur+2];
404 }
405
406 for (Int_t j=0;j<2;j++)
407 {
408 if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
409 if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
410 }
411
412 fTreeMap1[kbuild]=ncur+1;
413 fTreeMap2[kbuild]=ncur+2;
414 parent[ncur+1]=kbuild;
415 parent[ncur+2]=kbuild;
416 nodestatus[kbuild]=1;
417 ncur+=2;
418 if (ncur>=nrnodes) break;
419 }
420
421 // determine number of nodes
422 fNumNodes=nrnodes;
423 for (Int_t k=nrnodes-1;k>=0;k--)
424 {
425 if (nodestatus[k]==0) fNumNodes-=1;
426 if (nodestatus[k]==2) nodestatus[k]=-1;
427 }
428
429 fNumEndNodes=0;
430 for (Int_t kn=0;kn<fNumNodes;kn++)
431 if(nodestatus[kn]==-1)
432 {
433 fNumEndNodes++;
434 Double_t pp=0;
435 for (Int_t j=0;j<2;j++)
436 {
437 if(classpop[j*nrnodes+kn]>pp)
438 {
439 // class + status of node kn coded into fBestVar[kn]
440 fBestVar[kn]=j-2;
441 pp=classpop[j*nrnodes+kn];
442 }
443 }
444 fBestSplit[kn] =classpop[1*nrnodes+kn];
445 fBestSplit[kn]/=(classpop[0*nrnodes+kn]+classpop[1*nrnodes+kn]);
446 }
447}
448
449void MRanTree::SetRules(MDataArray *rules)
450{
451 fData=rules;
452}
453
454Double_t MRanTree::TreeHad(const TVector &event)
455{
456 Int_t kt=0;
457 // to optimize on storage space node status and node class
458 // are coded into fBestVar:
459 // status of node kt = TMath::Sign(1,fBestVar[kt])
460 // class of node kt = fBestVar[kt]+2 (class defined by larger
461 // node population, actually not used)
462 // hadronness assigned to node kt = fBestSplit[kt]
463
464 for (Int_t k=0;k<fNumNodes;k++)
465 {
466 if (fBestVar[kt]<0)
467 break;
468
469 const Int_t m=fBestVar[kt];
470 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
471 }
472
473 return fBestSplit[kt];
474}
475
476Double_t MRanTree::TreeHad(const TMatrixRow &event)
477{
478 Int_t kt=0;
479 // to optimize on storage space node status and node class
480 // are coded into fBestVar:
481 // status of node kt = TMath::Sign(1,fBestVar[kt])
482 // class of node kt = fBestVar[kt]+2 (class defined by larger
483 // node population, actually not used)
484 // hadronness assigned to node kt = fBestSplit[kt]
485
486 for (Int_t k=0;k<fNumNodes;k++)
487 {
488 if (fBestVar[kt]<0)
489 break;
490
491 const Int_t m=fBestVar[kt];
492 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
493 }
494
495 return fBestSplit[kt];
496}
497
498Double_t MRanTree::TreeHad(const TMatrix &m, Int_t ievt)
499{
500 return TreeHad(TMatrixRow(m, ievt));
501}
502
503Double_t MRanTree::TreeHad()
504{
505 TVector event;
506 *fData >> event;
507
508 return TreeHad(event);
509}
510
511Bool_t MRanTree::AsciiWrite(ostream &out) const
512{
513 TString str;
514 Int_t k;
515
516 out.width(5);out<<fNumNodes<<endl;
517
518 for (k=0;k<fNumNodes;k++)
519 {
520 str=Form("%f",GetBestSplit(k));
521
522 out.width(5); out << k;
523 out.width(5); out << GetNodeStatus(k);
524 out.width(5); out << GetTreeMap1(k);
525 out.width(5); out << GetTreeMap2(k);
526 out.width(5); out << GetBestVar(k);
527 out.width(15); out << str<<endl;
528 out.width(5); out << GetNodeClass(k);
529 }
530 out<<endl;
531
532 return k==fNumNodes;
533}
Note: See TracBrowser for help on using the repository browser.