source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 8809

Last change on this file since 8809 was 8706, checked in by tbretz, 17 years ago
*** empty log message ***
File size: 13.8 KB
Line 
1/* ======================================================================== *\
2! $Name: not supported by cvs2svn $:$Id: MRanForestCalc.cc,v 1.30 2007-08-24 12:58:49 tbretz Exp $
3! --------------------------------------------------------------------------
4!
5! *
6! * This file is part of MARS, the MAGIC Analysis and Reconstruction
7! * Software. It is distributed to you in the hope that it can be a useful
8! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
9! * It is distributed WITHOUT ANY WARRANTY.
10! *
11! * Permission to use, copy, modify and distribute this software and its
12! * documentation for any purpose is hereby granted without fee,
13! * provided that the above copyright notice appear in all copies and
14! * that both that copyright notice and this permission notice appear
15! * in supporting documentation. It is provided "as is" without express
16! * or implied warranty.
17! *
18!
19!
20! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
21! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
22!
23! Copyright: MAGIC Software Development, 2000-2006
24!
25!
26\* ======================================================================== */
27
28/////////////////////////////////////////////////////////////////////////////
29//
30// MRanForestCalc
31//
32//
33////////////////////////////////////////////////////////////////////////////
34#include "MRanForestCalc.h"
35
36#include <TF1.h>
37#include <TFile.h>
38#include <TGraph.h>
39#include <TVector.h>
40
41#include "MHMatrix.h"
42
43#include "MLog.h"
44#include "MLogManip.h"
45
46#include "MData.h"
47#include "MDataArray.h"
48
49#include "MRanForest.h"
50#include "MParameters.h"
51
52#include "MParList.h"
53#include "MTaskList.h"
54#include "MEvtLoop.h"
55#include "MRanForestGrow.h"
56#include "MFillH.h"
57
58ClassImp(MRanForestCalc);
59
60using namespace std;
61
62const TString MRanForestCalc::gsDefName = "MRanForestCalc";
63const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
64
65const TString MRanForestCalc::gsNameOutput = "RanForestOut";
66const TString MRanForestCalc::gsNameEvalFunc = "EvalFunction";
67
68MRanForestCalc::MRanForestCalc(const char *name, const char *title)
69 : fData(0), fRFOut(0), fTestMatrix(0), fFunc("x"),
70 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
71 fLastDataColumnHasWeights(kFALSE),
72 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
73{
74 fName = name ? name : gsDefName.Data();
75 fTitle = title ? title : gsDefTitle.Data();
76
77 // FIXME:
78 fNumTrees = 100; //100
79 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
80 fNdSize = 1; //1
81}
82
83MRanForestCalc::~MRanForestCalc()
84{
85 fEForests.Delete();
86}
87
88// --------------------------------------------------------------------------
89//
90// Set a function which is applied to the output of the random forest
91//
92Bool_t MRanForestCalc::SetFunction(const char *func)
93{
94 return !fFunc.SetRule(func);
95}
96
97// --------------------------------------------------------------------------
98//
99// ver=0: One yes/no-classification forest is trained for each bin.
100// the yes/no classification is done using the grid
101// ver=1: One classification forest is trained. The last column contains a
102// value which is turned into a classifier by rf itself using the grid
103// ver=2: One classification forest is trained. The last column already contains
104// the classifier
105// ver=3: A regression forest is trained. The last column contains the
106// classifier
107//
108Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
109{
110 gLog.Separator("MRanForestCalc - Train");
111
112 if (!matrixtrain.GetColumns())
113 {
114 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
115 return kFALSE;
116 }
117
118 const Int_t ncols = matrixtrain.GetM().GetNcols();
119 const Int_t nrows = matrixtrain.GetM().GetNrows();
120 if (ncols<=0 || nrows <=0)
121 {
122 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
123 return kFALSE;
124 }
125
126 // rules (= combination of image par) to be used for energy estimation
127 TFile fileRF(fFileName, "recreate");
128 if (!fileRF.IsOpen())
129 {
130 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
131 return kFALSE;
132 }
133
134 // The number of columns which have to be removed for the training
135 // The last data column may contain weight which also have to be removed
136 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
137
138 const MDataArray &dcol = *matrixtrain.GetColumns();
139
140 // Make a copy of the rules for accessing the train-data
141 MDataArray usedrules;
142 for (Int_t i=0; i<ncols; i++)
143 if (i<ncols-nobs) // -3 is important!!!
144 usedrules.AddEntry(dcol[i].GetRule());
145 else
146 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
147
148 // In the case of regression store the rule to be regessed in the
149 // last entry of your rules
150 MDataArray rules(usedrules);
151 rules.AddEntry(ver<3?"Classification.fVal":dcol[ncols-1].GetRule().Data());
152
153 // prepare train-matrix finally used
154 TMatrix mat(matrixtrain.GetM());
155
156 // Resize it such that the obsolete columns are removed
157 mat.ResizeTo(nrows, ncols-nobs+1);
158
159 if (fDebug)
160 gLog.SetNullOutput(kTRUE);
161
162 // In the case one independant RF is trained for each bin (e.g.
163 // energy-bin) train all of them
164 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
165 for (Int_t ie=0; ie<nbins; ie++)
166 {
167 // In the case weights should be used initialize the
168 // corresponding array
169 Double_t sum = 0;
170
171 TArrayF weights(nrows);
172 if (fLastDataColumnHasWeights)
173 {
174 for (Int_t j=0; j<nrows; j++)
175 {
176 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
177 sum += weights[j];
178 }
179 }
180
181 *fLog << inf << "MRanForestCalc::Train: Sum of weights " << sum << endl;
182
183 // Setup the matrix such that the last comlumn contains
184 // the classifier or the regeression target value
185 switch (ver)
186 {
187 case 0: // Replace last column by a classification which is 1 in
188 // the case the event belongs to this bin, 0 otherwise
189 {
190 Int_t irows=0;
191 for (Int_t j=0; j<nrows; j++)
192 {
193 const Double_t value = matrixtrain.GetM()(j,ncols-1);
194 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
195
196 mat(j, ncols-nobs) = inside ? 1 : 0;
197
198 if (inside)
199 irows++;
200 }
201 if (irows==0)
202 *fLog << warn << "WARNING - Skipping";
203 else
204 *fLog << inf << "Training RF for";
205
206 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
207
208 if (irows==0)
209 continue;
210 }
211 break;
212
213 case 1: // Use last column as classifier or for regression
214 case 2:
215 case 3:
216 for (Int_t j=0; j<nrows; j++)
217 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
218 break;
219 }
220
221 MHMatrix matrix(mat, &rules, "MatrixTrain");
222
223 MParList plist;
224 MTaskList tlist;
225 plist.AddToList(&tlist);
226 plist.AddToList(&matrix);
227
228 MRanForest rf;
229 rf.SetNumTrees(fNumTrees);
230 rf.SetNumTry(fNumTry);
231 rf.SetNdSize(fNdSize);
232 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
233 if (ver==1)
234 rf.SetGrid(grid);
235 if (fLastDataColumnHasWeights)
236 rf.SetWeights(weights);
237
238 plist.AddToList(&rf);
239
240 MRanForestGrow rfgrow;
241 tlist.AddToList(&rfgrow);
242
243 MFillH fillh("MHRanForestGini");
244 tlist.AddToList(&fillh);
245
246 MEvtLoop evtloop(fTitle);
247 evtloop.SetParList(&plist);
248 evtloop.SetDisplay(fDisplay);
249 evtloop.SetLogStream(fLog);
250
251 if (!evtloop.Eventloop())
252 return kFALSE;
253
254 if (fDebug)
255 gLog.SetNullOutput(kFALSE);
256
257 if (ver==0)
258 {
259 // Calculate bin center
260 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
261
262 // save whole forest
263 rf.SetUserVal(E);
264 rf.SetName(Form("%.10f", E));
265 }
266
267 rf.Write();
268 }
269
270 // save rules
271 usedrules.Write("rules");
272
273 fFunc.Write(gsNameEvalFunc);
274
275 return kTRUE;
276}
277
278Int_t MRanForestCalc::ReadForests(MParList &plist)
279{
280 TFile fileRF(fFileName, "read");
281 if (!fileRF.IsOpen())
282 {
283 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
284 return kFALSE;
285 }
286
287 fEForests.Delete();
288
289 TIter Next(fileRF.GetListOfKeys());
290 TObject *o=0;
291 while ((o=Next()))
292 {
293 MRanForest *forest=0;
294 fileRF.GetObject(o->GetName(), forest);
295 if (!forest)
296 continue;
297
298 forest->SetUserVal(atof(o->GetName()));
299
300 fEForests.Add(forest);
301 }
302
303 // Maybe fEForests[0].fRules could be used instead?
304 if (fData->Read("rules")<=0)
305 {
306 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
307 return kFALSE;
308 }
309
310 if (fileRF.GetListOfKeys()->FindObject(gsNameEvalFunc))
311 {
312 if (fFunc.Read(gsNameEvalFunc)<=0)
313 {
314 *fLog << err << "ERROR - Reading '" << gsNameEvalFunc << "' from file " << fFileName << endl;
315 return kFALSE;
316 }
317
318 *fLog << inf << "Evaluation function found in file: " << fFunc.GetRule() << endl;
319 }
320
321 return kTRUE;
322}
323
324Int_t MRanForestCalc::PreProcess(MParList *plist)
325{
326 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
327 if (!fRFOut)
328 return kFALSE;
329
330 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
331 if (!fData)
332 return kFALSE;
333
334 if (!ReadForests(*plist))
335 {
336 *fLog << err << "Reading RFs failed... aborting." << endl;
337 return kFALSE;
338 }
339
340 *fLog << inf << "RF read from " << fFileName << endl;
341
342 if (!fFunc.PreProcess(plist))
343 {
344 *fLog << err << "PreProcessing of evaluation function failed... aborting." << endl;
345 return kFALSE;
346 }
347
348 if (fTestMatrix)
349 return kTRUE;
350
351 fData->Print();
352
353 if (!fData->PreProcess(plist))
354 {
355 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
356 return kFALSE;
357 }
358
359 return kTRUE;
360}
361
362Double_t MRanForestCalc::Eval() const
363{
364 TVector event;
365 if (fTestMatrix)
366 *fTestMatrix >> event;
367 else
368 *fData >> event;
369
370 // --------------- Single Tree RF -------------------
371 if (fEForests.GetEntriesFast()==1)
372 {
373 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
374 return rf->CalcHadroness(event);
375 }
376
377 // --------------- Multi Tree RF -------------------
378 static TF1 f1("f1", "gaus");
379
380 Double_t sume = 0;
381 Double_t sumh = 0;
382 Double_t maxh = 0;
383 Double_t maxe = 0;
384
385 Double_t max = -1e10;
386 Double_t min = 1e10;
387
388 TIter Next(&fEForests);
389 MRanForest *rf = 0;
390
391 TGraph g;
392 while ((rf=(MRanForest*)Next()))
393 {
394 const Double_t h = rf->CalcHadroness(event);
395 const Double_t e = rf->GetUserVal();
396
397 g.SetPoint(g.GetN(), e, h);
398
399 sume += e*h;
400 sumh += h;
401
402 if (h>maxh)
403 {
404 maxh = h;
405 maxe = e;
406 }
407 if (e>max)
408 max = e;
409 if (e<min)
410 min = e;
411 }
412
413 switch (fEstimationMode)
414 {
415 case kMean:
416 return sume/sumh;
417 case kMaximum:
418 return maxe;
419 case kFit:
420 f1.SetParameter(0, maxh);
421 f1.SetParameter(1, maxe);
422 f1.SetParameter(2, 0.125);
423 g.Fit(&f1, "Q0N");
424 return f1.GetParameter(1);
425 }
426
427 return 0;
428}
429
430Int_t MRanForestCalc::Process()
431{
432 const Double_t val = Eval();
433
434 fRFOut->SetVal(fFunc.Eval(val));
435 fRFOut->SetReadyToSave();
436
437 return kTRUE;
438}
439
440void MRanForestCalc::Print(Option_t *o) const
441{
442 *fLog << all;
443 *fLog << GetDescriptor() << ":" << endl;
444 *fLog << " - Forest ";
445 switch (fEForests.GetEntries())
446 {
447 case 0: *fLog << "not yet initialized." << endl; break;
448 case 1: *fLog << "is a single tree forest." << endl; break;
449 default: *fLog << "is a multi tree forest." << endl; break;
450 }
451 /*
452 *fLog << " - Trees: " << fNumTrees << endl;
453 *fLog << " - Trys: " << fNumTry << endl;
454 *fLog << " - Node Size: " << fNdSize << endl;
455 *fLog << " - Node Size: " << fNdSize << endl;
456 */
457 *fLog << " - FileName: " << fFileName << endl;
458 *fLog << " - NameOutput: " << fNameOutput << endl;
459}
460
461// --------------------------------------------------------------------------
462//
463//
464Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
465{
466 Bool_t rc = kFALSE;
467 if (IsEnvDefined(env, prefix, "FileName", print))
468 {
469 rc = kTRUE;
470 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
471 }
472 if (IsEnvDefined(env, prefix, "Debug", print))
473 {
474 rc = kTRUE;
475 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
476 }
477 if (IsEnvDefined(env, prefix, "NameOutput", print))
478 {
479 rc = kTRUE;
480 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
481 }
482 if (IsEnvDefined(env, prefix, "EstimationMode", print))
483 {
484 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
485 txt = txt.Strip(TString::kBoth);
486 txt.ToLower();
487 if (txt==(TString)"mean")
488 fEstimationMode = kMean;
489 if (txt==(TString)"maximum")
490 fEstimationMode = kMaximum;
491 if (txt==(TString)"fit")
492 fEstimationMode = kFit;
493 rc = kTRUE;
494 }
495 return rc;
496}
Note: See TracBrowser for help on using the repository browser.