source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 8699

Last change on this file since 8699 was 8698, checked in by tbretz, 17 years ago
*** empty log message ***
File size: 13.3 KB
Line 
1/* ======================================================================== *\
2! $Name: not supported by cvs2svn $:$Id: MRanForestCalc.cc,v 1.28 2007-08-23 10:25:08 tbretz Exp $
3! --------------------------------------------------------------------------
4!
5! *
6! * This file is part of MARS, the MAGIC Analysis and Reconstruction
7! * Software. It is distributed to you in the hope that it can be a useful
8! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
9! * It is distributed WITHOUT ANY WARRANTY.
10! *
11! * Permission to use, copy, modify and distribute this software and its
12! * documentation for any purpose is hereby granted without fee,
13! * provided that the above copyright notice appear in all copies and
14! * that both that copyright notice and this permission notice appear
15! * in supporting documentation. It is provided "as is" without express
16! * or implied warranty.
17! *
18!
19!
20! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
21! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
22!
23! Copyright: MAGIC Software Development, 2000-2006
24!
25!
26\* ======================================================================== */
27
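/////////////////////////////////////////////////////////////////////////////
//
// MRanForestCalc
//
// Task to train random forests (Train) and to apply them. In PreProcess
// the forests and the rules describing their input variables are read
// from the file set by SetFileName (resource "FileName"). In Process the
// forest output is evaluated for each event, the function set by
// SetFunction (default: identity "x") is applied to it and the result is
// stored in an MParameterD (default name "RanForestOut", resource
// "NameOutput").
//
////////////////////////////////////////////////////////////////////////////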
34#include "MRanForestCalc.h"
35
36#include <TF1.h>
37#include <TFile.h>
38#include <TGraph.h>
39#include <TVector.h>
40
41#include "MHMatrix.h"
42
43#include "MLog.h"
44#include "MLogManip.h"
45
46#include "MData.h"
47#include "MDataArray.h"
48
49#include "MRanForest.h"
50#include "MParameters.h"
51
52#include "MParList.h"
53#include "MTaskList.h"
54#include "MEvtLoop.h"
55#include "MRanForestGrow.h"
56#include "MFillH.h"
57
58ClassImp(MRanForestCalc);
59
60using namespace std;
61
62const TString MRanForestCalc::gsDefName = "MRanForestCalc";
63const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
64
65const TString MRanForestCalc::gsNameOutput = "RanForestOut";
66
67MRanForestCalc::MRanForestCalc(const char *name, const char *title)
68 : fData(0), fRFOut(0), fTestMatrix(0), fFunc("Function", "x"),
69 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
70 fLastDataColumnHasWeights(kFALSE),
71 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
72{
73 fName = name ? name : gsDefName.Data();
74 fTitle = title ? title : gsDefTitle.Data();
75
76 gROOT->GetListOfFunctions()->Remove(&fFunc);
77
78 // FIXME:
79 fNumTrees = 100; //100
80 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
81 fNdSize = 1; //1
82}
83
84MRanForestCalc::~MRanForestCalc()
85{
86 fEForests.Delete();
87}
88
89// --------------------------------------------------------------------------
90//
91// Set a function which is applied to the output of the random forest
92//
93Bool_t MRanForestCalc::SetFunction(const char *func)
94{
95 return !fFunc.Compile(func);
96}
97
98// --------------------------------------------------------------------------
99//
100// ver=0: One yes/no-classification forest is trained for each bin.
101// the yes/no classification is done using the grid
102// ver=1: One classification forest is trained. The last column contains a
103// value which is turned into a classifier by rf itself using the grid
104// ver=2: One classification forest is trained. The last column already contains
105// the classifier
106// ver=3: A regression forest is trained. The last column contains the
107// classifier
108//
109Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
110{
111 gLog.Separator("MRanForestCalc - Train");
112
113 if (!matrixtrain.GetColumns())
114 {
115 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
116 return kFALSE;
117 }
118
119 const Int_t ncols = matrixtrain.GetM().GetNcols();
120 const Int_t nrows = matrixtrain.GetM().GetNrows();
121 if (ncols<=0 || nrows <=0)
122 {
123 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
124 return kFALSE;
125 }
126
127 // rules (= combination of image par) to be used for energy estimation
128 TFile fileRF(fFileName, "recreate");
129 if (!fileRF.IsOpen())
130 {
131 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
132 return kFALSE;
133 }
134
135 // The number of columns which have to be removed for the training
136 // The last data column may contain weight which also have to be removed
137 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
138
139 const MDataArray &dcol = *matrixtrain.GetColumns();
140
141 // Make a copy of the rules for accessing the train-data
142 MDataArray usedrules;
143 for (Int_t i=0; i<ncols; i++)
144 if (i<ncols-nobs) // -3 is important!!!
145 usedrules.AddEntry(dcol[i].GetRule());
146 else
147 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
148
149 // In the case of regression store the rule to be regessed in the
150 // last entry of your rules
151 MDataArray rules(usedrules);
152 rules.AddEntry(ver<3?"Classification.fVal":dcol[ncols-1].GetRule().Data());
153
154 // prepare train-matrix finally used
155 TMatrix mat(matrixtrain.GetM());
156
157 // Resize it such that the obsolete columns are removed
158 mat.ResizeTo(nrows, ncols-nobs+1);
159
160 if (fDebug)
161 gLog.SetNullOutput(kTRUE);
162
163 // In the case one independant RF is trained for each bin (e.g.
164 // energy-bin) train all of them
165 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
166 for (Int_t ie=0; ie<nbins; ie++)
167 {
168 // In the case weights should be used initialize the
169 // corresponding array
170 TArrayF weights(nrows);
171 if (fLastDataColumnHasWeights)
172 for (Int_t j=0; j<nrows; j++)
173 {
174 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
175 if (j%250==0)
176 cout << weights[j] << " ";
177 }
178
179 // Setup the matrix such that the last comlumn contains
180 // the classifier or the regeression target value
181 switch (ver)
182 {
183 case 0: // Replace last column by a classification which is 1 in
184 // the case the event belongs to this bin, 0 otherwise
185 {
186 Int_t irows=0;
187 for (Int_t j=0; j<nrows; j++)
188 {
189 const Double_t value = matrixtrain.GetM()(j,ncols-1);
190 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
191
192 mat(j, ncols-nobs) = inside ? 1 : 0;
193
194 if (inside)
195 irows++;
196 }
197 if (irows==0)
198 *fLog << warn << "WARNING - Skipping";
199 else
200 *fLog << inf << "Training RF for";
201
202 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
203
204 if (irows==0)
205 continue;
206 }
207 break;
208
209 case 1: // Use last column as classifier or for regression
210 case 2:
211 case 3:
212 for (Int_t j=0; j<nrows; j++)
213 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
214 break;
215 }
216
217 MHMatrix matrix(mat, &rules, "MatrixTrain");
218
219 MParList plist;
220 MTaskList tlist;
221 plist.AddToList(&tlist);
222 plist.AddToList(&matrix);
223
224 MRanForest rf;
225 rf.SetNumTrees(fNumTrees);
226 rf.SetNumTry(fNumTry);
227 rf.SetNdSize(fNdSize);
228 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
229 if (ver==1)
230 rf.SetGrid(grid);
231 if (fLastDataColumnHasWeights)
232 rf.SetWeights(weights);
233
234 plist.AddToList(&rf);
235
236 MRanForestGrow rfgrow;
237 tlist.AddToList(&rfgrow);
238
239 MFillH fillh("MHRanForestGini");
240 tlist.AddToList(&fillh);
241
242 MEvtLoop evtloop(fTitle);
243 evtloop.SetParList(&plist);
244 evtloop.SetDisplay(fDisplay);
245 evtloop.SetLogStream(fLog);
246
247 if (!evtloop.Eventloop())
248 return kFALSE;
249
250 if (fDebug)
251 gLog.SetNullOutput(kFALSE);
252
253 if (ver==0)
254 {
255 // Calculate bin center
256 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
257
258 // save whole forest
259 rf.SetUserVal(E);
260 rf.SetName(Form("%.10f", E));
261 }
262
263 rf.Write();
264 }
265
266 // save rules
267 usedrules.Write("rules");
268
269 return kTRUE;
270}
271
272Int_t MRanForestCalc::ReadForests(MParList &plist)
273{
274 TFile fileRF(fFileName, "read");
275 if (!fileRF.IsOpen())
276 {
277 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
278 return kFALSE;
279 }
280
281 fEForests.Delete();
282
283 TIter Next(fileRF.GetListOfKeys());
284 TObject *o=0;
285 while ((o=Next()))
286 {
287 MRanForest *forest=0;
288 fileRF.GetObject(o->GetName(), forest);
289 if (!forest)
290 continue;
291
292 forest->SetUserVal(atof(o->GetName()));
293
294 fEForests.Add(forest);
295 }
296
297 // Maybe fEForests[0].fRules could be used instead?
298 if (fData->Read("rules")<=0)
299 {
300 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
301 return kFALSE;
302 }
303
304 if (fileRF.GetListOfKeys()->FindObject("Function"))
305 fFunc.Read("Function");
306
307 return kTRUE;
308}
309
310Int_t MRanForestCalc::PreProcess(MParList *plist)
311{
312 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
313 if (!fRFOut)
314 return kFALSE;
315
316 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
317 if (!fData)
318 return kFALSE;
319
320 if (!ReadForests(*plist))
321 {
322 *fLog << err << "Reading RFs failed... aborting." << endl;
323 return kFALSE;
324 }
325
326 *fLog << inf << "RF read from " << fFileName << endl;
327
328 if (fTestMatrix)
329 return kTRUE;
330
331 fData->Print();
332
333 if (!fData->PreProcess(plist))
334 {
335 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
336 return kFALSE;
337 }
338
339 return kTRUE;
340}
341
342Double_t MRanForestCalc::Eval() const
343{
344 TVector event;
345 if (fTestMatrix)
346 *fTestMatrix >> event;
347 else
348 *fData >> event;
349
350 // --------------- Single Tree RF -------------------
351 if (fEForests.GetEntriesFast()==1)
352 {
353 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
354 return rf->CalcHadroness(event);
355 }
356
357 // --------------- Multi Tree RF -------------------
358 static TF1 f1("f1", "gaus");
359
360 Double_t sume = 0;
361 Double_t sumh = 0;
362 Double_t maxh = 0;
363 Double_t maxe = 0;
364
365 Double_t max = -1e10;
366 Double_t min = 1e10;
367
368 TIter Next(&fEForests);
369 MRanForest *rf = 0;
370
371 TGraph g;
372 while ((rf=(MRanForest*)Next()))
373 {
374 const Double_t h = rf->CalcHadroness(event);
375 const Double_t e = rf->GetUserVal();
376
377 g.SetPoint(g.GetN(), e, h);
378
379 sume += e*h;
380 sumh += h;
381
382 if (h>maxh)
383 {
384 maxh = h;
385 maxe = e;
386 }
387 if (e>max)
388 max = e;
389 if (e<min)
390 min = e;
391 }
392
393 switch (fEstimationMode)
394 {
395 case kMean:
396 return sume/sumh;
397 case kMaximum:
398 return maxe;
399 case kFit:
400 f1.SetParameter(0, maxh);
401 f1.SetParameter(1, maxe);
402 f1.SetParameter(2, 0.125);
403 g.Fit(&f1, "Q0N");
404 return f1.GetParameter(1);
405 }
406
407 return 0;
408}
409
410Int_t MRanForestCalc::Process()
411{
412 const Double_t val = Eval();
413
414 fRFOut->SetVal(fFunc.Eval(val));
415 fRFOut->SetReadyToSave();
416
417 return kTRUE;
418}
419
420void MRanForestCalc::Print(Option_t *o) const
421{
422 *fLog << all;
423 *fLog << GetDescriptor() << ":" << endl;
424 *fLog << " - Forest ";
425 switch (fEForests.GetEntries())
426 {
427 case 0: *fLog << "not yet initialized." << endl; break;
428 case 1: *fLog << "is a single tree forest." << endl; break;
429 default: *fLog << "is a multi tree forest." << endl; break;
430 }
431 /*
432 *fLog << " - Trees: " << fNumTrees << endl;
433 *fLog << " - Trys: " << fNumTry << endl;
434 *fLog << " - Node Size: " << fNdSize << endl;
435 *fLog << " - Node Size: " << fNdSize << endl;
436 */
437 *fLog << " - FileName: " << fFileName << endl;
438 *fLog << " - NameOutput: " << fNameOutput << endl;
439}
440
441// --------------------------------------------------------------------------
442//
443//
444Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
445{
446 Bool_t rc = kFALSE;
447 if (IsEnvDefined(env, prefix, "FileName", print))
448 {
449 rc = kTRUE;
450 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
451 }
452 if (IsEnvDefined(env, prefix, "Debug", print))
453 {
454 rc = kTRUE;
455 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
456 }
457 if (IsEnvDefined(env, prefix, "NameOutput", print))
458 {
459 rc = kTRUE;
460 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
461 }
462 if (IsEnvDefined(env, prefix, "EstimationMode", print))
463 {
464 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
465 txt = txt.Strip(TString::kBoth);
466 txt.ToLower();
467 if (txt==(TString)"mean")
468 fEstimationMode = kMean;
469 if (txt==(TString)"maximum")
470 fEstimationMode = kMaximum;
471 if (txt==(TString)"fit")
472 fEstimationMode = kFit;
473 rc = kTRUE;
474 }
475 return rc;
476}
Note: See TracBrowser for help on using the repository browser.