source: branches/AddingGoogleTestEnvironment/mranforest/MRanForestCalc.cc@ 20094

Last change on this file since 20094 was 10166, checked in by tbretz, 14 years ago
Removed the old obsolete cvs header line.
File size: 13.7 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
19! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
20!
21! Copyright: MAGIC Software Development, 2000-2008
22!
23!
24\* ======================================================================== */
25
26/////////////////////////////////////////////////////////////////////////////
27//
28// MRanForestCalc
29//
30//
31////////////////////////////////////////////////////////////////////////////
32#include "MRanForestCalc.h"
33
34#include <stdlib.h> // Needed for atof in some cases
35
36#include <TMath.h>
37
38#include <TF1.h>
39#include <TFile.h>
40#include <TGraph.h>
41#include <TVector.h>
42
43#include "MHMatrix.h"
44
45#include "MLog.h"
46#include "MLogManip.h"
47
48#include "MData.h"
49#include "MDataArray.h"
50
51#include "MRanForest.h"
52#include "MParameters.h"
53
54#include "MParList.h"
55#include "MTaskList.h"
56#include "MEvtLoop.h"
57#include "MRanForestGrow.h"
58#include "MFillH.h"
59
60ClassImp(MRanForestCalc);
61
62using namespace std;
63
64const TString MRanForestCalc::gsDefName = "MRanForestCalc";
65const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
66
67const TString MRanForestCalc::gsNameOutput = "RanForestOut";
68const TString MRanForestCalc::gsNameEvalFunc = "EvalFunction";
69
70MRanForestCalc::MRanForestCalc(const char *name, const char *title)
71 : fData(0), fRFOut(0), fTestMatrix(0), fFunc("x"),
72 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
73 fLastDataColumnHasWeights(kFALSE),
74 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
75{
76 fName = name ? name : gsDefName.Data();
77 fTitle = title ? title : gsDefTitle.Data();
78
79 // FIXME:
80 fNumTrees = 100; //100
81 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
82 fNdSize = 1; //1
83}
84
85MRanForestCalc::~MRanForestCalc()
86{
87 fEForests.Delete();
88}
89
90// --------------------------------------------------------------------------
91//
92// Set a function which is applied to the output of the random forest
93//
94Bool_t MRanForestCalc::SetFunction(const char *func)
95{
96 return !fFunc.SetRule(func);
97}
98
99// --------------------------------------------------------------------------
100//
101// ver=0: One yes/no-classification forest is trained for each bin.
102// the yes/no classification is done using the grid
103// ver=1: One classification forest is trained. The last column contains a
104// value which is turned into a classifier by rf itself using the grid
105// ver=2: One classification forest is trained. The last column already contains
106// the classifier
107// ver=3: A regression forest is trained. The last column contains the
108// classifier
109//
110Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
111{
112 gLog.Separator("MRanForestCalc - Train");
113
114 if (!matrixtrain.GetColumns())
115 {
116 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
117 return kFALSE;
118 }
119
120 const Int_t ncols = matrixtrain.GetM().GetNcols();
121 const Int_t nrows = matrixtrain.GetM().GetNrows();
122 if (ncols<=0 || nrows <=0)
123 {
124 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
125 return kFALSE;
126 }
127
128 // rules (= combination of image par) to be used for energy estimation
129 TFile fileRF(fFileName, "recreate");
130 if (!fileRF.IsOpen())
131 {
132 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
133 return kFALSE;
134 }
135
136 // The number of columns which have to be removed for the training
137 // The last data column may contain weight which also have to be removed
138 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
139
140 const MDataArray &dcol = *matrixtrain.GetColumns();
141
142 // Make a copy of the rules for accessing the train-data
143 MDataArray usedrules;
144 for (Int_t i=0; i<ncols; i++)
145 if (i<ncols-nobs) // -3 is important!!!
146 usedrules.AddEntry(dcol[i].GetRule());
147 else
148 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
149
150 // In the case of regression store the rule to be regessed in the
151 // last entry of your rules
152 MDataArray rules(usedrules);
153 rules.AddEntry(ver<3?"Classification.fVal":dcol[ncols-1].GetRule().Data());
154
155 // prepare train-matrix finally used
156 TMatrix mat(matrixtrain.GetM());
157
158 // Resize it such that the obsolete columns are removed
159 mat.ResizeTo(nrows, ncols-nobs+1);
160
161 if (fDebug)
162 gLog.SetNullOutput(kTRUE);
163
164 // In the case one independant RF is trained for each bin (e.g.
165 // energy-bin) train all of them
166 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
167 for (Int_t ie=0; ie<nbins; ie++)
168 {
169 // In the case weights should be used initialize the
170 // corresponding array
171 Double_t sum = 0;
172
173 TArrayF weights(nrows);
174 if (fLastDataColumnHasWeights)
175 {
176 for (Int_t j=0; j<nrows; j++)
177 {
178 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
179 sum += weights[j];
180 }
181 }
182
183 *fLog << inf << "MRanForestCalc::Train: Sum of weights " << sum << endl;
184
185 // Setup the matrix such that the last comlumn contains
186 // the classifier or the regeression target value
187 switch (ver)
188 {
189 case 0: // Replace last column by a classification which is 1 in
190 // the case the event belongs to this bin, 0 otherwise
191 {
192 Int_t irows=0;
193 for (Int_t j=0; j<nrows; j++)
194 {
195 const Double_t value = matrixtrain.GetM()(j,ncols-1);
196 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
197
198 mat(j, ncols-nobs) = inside ? 1 : 0;
199
200 if (inside)
201 irows++;
202 }
203 if (irows==0)
204 *fLog << warn << "WARNING - Skipping";
205 else
206 *fLog << inf << "Training RF for";
207
208 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
209
210 if (irows==0)
211 continue;
212 }
213 break;
214
215 case 1: // Use last column as classifier or for regression
216 case 2:
217 case 3:
218 for (Int_t j=0; j<nrows; j++)
219 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
220 break;
221 }
222
223 MHMatrix matrix(mat, &rules, "MatrixTrain");
224
225 MParList plist;
226 MTaskList tlist;
227 plist.AddToList(&tlist);
228 plist.AddToList(&matrix);
229
230 MRanForest rf;
231 rf.SetNumTrees(fNumTrees);
232 rf.SetNumTry(fNumTry);
233 rf.SetNdSize(fNdSize);
234 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
235 if (ver==1)
236 rf.SetGrid(grid);
237 if (fLastDataColumnHasWeights)
238 rf.SetWeights(weights);
239
240 plist.AddToList(&rf);
241
242 MRanForestGrow rfgrow;
243 tlist.AddToList(&rfgrow);
244
245 MFillH fillh("MHRanForestGini");
246 tlist.AddToList(&fillh);
247
248 MEvtLoop evtloop(fTitle);
249 evtloop.SetParList(&plist);
250 evtloop.SetDisplay(fDisplay);
251 evtloop.SetLogStream(fLog);
252
253 if (!evtloop.Eventloop())
254 return kFALSE;
255
256 if (fDebug)
257 gLog.SetNullOutput(kFALSE);
258
259 if (ver==0)
260 {
261 // Calculate bin center
262 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
263
264 // save whole forest
265 rf.SetUserVal(E);
266 rf.SetName(Form("%.10f", E));
267 }
268
269 rf.Write();
270 }
271
272 // save rules
273 usedrules.Write("rules");
274
275 fFunc.Write(gsNameEvalFunc);
276
277 return kTRUE;
278}
279
280Int_t MRanForestCalc::ReadForests(MParList &plist)
281{
282 TFile fileRF(fFileName, "read");
283 if (!fileRF.IsOpen())
284 {
285 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
286 return kFALSE;
287 }
288
289 fEForests.Delete();
290
291 TIter Next(fileRF.GetListOfKeys());
292 TObject *o=0;
293 while ((o=Next()))
294 {
295 MRanForest *forest=0;
296 fileRF.GetObject(o->GetName(), forest);
297 if (!forest)
298 continue;
299
300 forest->SetUserVal(atof(o->GetName()));
301
302 fEForests.Add(forest);
303 }
304
305 // Maybe fEForests[0].fRules could be used instead?
306 if (fData->Read("rules")<=0)
307 {
308 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
309 return kFALSE;
310 }
311
312 if (fileRF.GetListOfKeys()->FindObject(gsNameEvalFunc))
313 {
314 if (fFunc.Read(gsNameEvalFunc)<=0)
315 {
316 *fLog << err << "ERROR - Reading '" << gsNameEvalFunc << "' from file " << fFileName << endl;
317 return kFALSE;
318 }
319
320 *fLog << inf << "Evaluation function found in file: " << fFunc.GetRule() << endl;
321 }
322
323 return kTRUE;
324}
325
326Int_t MRanForestCalc::PreProcess(MParList *plist)
327{
328 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
329 if (!fRFOut)
330 return kFALSE;
331
332 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
333 if (!fData)
334 return kFALSE;
335
336 if (!ReadForests(*plist))
337 {
338 *fLog << err << "Reading RFs failed... aborting." << endl;
339 return kFALSE;
340 }
341
342 *fLog << inf << "RF read from " << fFileName << endl;
343
344 if (!fFunc.PreProcess(plist))
345 {
346 *fLog << err << "PreProcessing of evaluation function failed... aborting." << endl;
347 return kFALSE;
348 }
349
350 if (fTestMatrix)
351 return kTRUE;
352
353 fData->Print();
354
355 if (!fData->PreProcess(plist))
356 {
357 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
358 return kFALSE;
359 }
360
361 return kTRUE;
362}
363
364Double_t MRanForestCalc::Eval() const
365{
366 TVector event;
367 if (fTestMatrix)
368 *fTestMatrix >> event;
369 else
370 *fData >> event;
371
372 // --------------- Single Tree RF -------------------
373 if (fEForests.GetEntriesFast()==1)
374 {
375 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
376 return rf->CalcHadroness(event);
377 }
378
379 // --------------- Multi Tree RF -------------------
380 static TF1 f1("f1", "gaus");
381
382 Double_t sume = 0;
383 Double_t sumh = 0;
384 Double_t maxh = 0;
385 Double_t maxe = 0;
386
387 Double_t max = -1e10;
388 Double_t min = 1e10;
389
390 TIter Next(&fEForests);
391 MRanForest *rf = 0;
392
393 TGraph g;
394 while ((rf=(MRanForest*)Next()))
395 {
396 const Double_t h = rf->CalcHadroness(event);
397 const Double_t e = rf->GetUserVal();
398
399 g.SetPoint(g.GetN(), e, h);
400
401 sume += e*h;
402 sumh += h;
403
404 if (h>maxh)
405 {
406 maxh = h;
407 maxe = e;
408 }
409 if (e>max)
410 max = e;
411 if (e<min)
412 min = e;
413 }
414
415 switch (fEstimationMode)
416 {
417 case kMean:
418 return sume/sumh;
419 case kMaximum:
420 return maxe;
421 case kFit:
422 f1.SetParameter(0, maxh);
423 f1.SetParameter(1, maxe);
424 f1.SetParameter(2, 0.125);
425 g.Fit(&f1, "Q0N");
426 return f1.GetParameter(1);
427 }
428
429 return 0;
430}
431
432Int_t MRanForestCalc::Process()
433{
434 const Double_t val = Eval();
435
436 fRFOut->SetVal(fFunc.Eval(val));
437 fRFOut->SetReadyToSave();
438
439 return kTRUE;
440}
441
442void MRanForestCalc::Print(Option_t *o) const
443{
444 *fLog << all;
445 *fLog << GetDescriptor() << ":" << endl;
446 *fLog << " - Forest ";
447 switch (fEForests.GetEntries())
448 {
449 case 0: *fLog << "not yet initialized." << endl; break;
450 case 1: *fLog << "is a single tree forest." << endl; break;
451 default: *fLog << "is a multi tree forest." << endl; break;
452 }
453 /*
454 *fLog << " - Trees: " << fNumTrees << endl;
455 *fLog << " - Trys: " << fNumTry << endl;
456 *fLog << " - Node Size: " << fNdSize << endl;
457 *fLog << " - Node Size: " << fNdSize << endl;
458 */
459 *fLog << " - FileName: " << fFileName << endl;
460 *fLog << " - NameOutput: " << fNameOutput << endl;
461}
462
463// --------------------------------------------------------------------------
464//
465//
466Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
467{
468 Bool_t rc = kFALSE;
469 if (IsEnvDefined(env, prefix, "FileName", print))
470 {
471 rc = kTRUE;
472 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
473 }
474 if (IsEnvDefined(env, prefix, "Debug", print))
475 {
476 rc = kTRUE;
477 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
478 }
479 if (IsEnvDefined(env, prefix, "NameOutput", print))
480 {
481 rc = kTRUE;
482 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
483 }
484 if (IsEnvDefined(env, prefix, "EstimationMode", print))
485 {
486 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
487 txt = txt.Strip(TString::kBoth);
488 txt.ToLower();
489 if (txt==(TString)"mean")
490 fEstimationMode = kMean;
491 if (txt==(TString)"maximum")
492 fEstimationMode = kMaximum;
493 if (txt==(TString)"fit")
494 fEstimationMode = kFit;
495 rc = kTRUE;
496 }
497 return rc;
498}
Note: See TracBrowser for help on using the repository browser.