source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 8203

Last change on this file since 8203 was 8203, checked in by tbretz, 18 years ago
*** empty log message ***
File size: 12.9 KB
Line 
1/* ======================================================================== *\
2! $Name: not supported by cvs2svn $:$Id: MRanForestCalc.cc,v 1.25 2006-11-02 08:57:00 tbretz Exp $
3! --------------------------------------------------------------------------
4!
5! *
6! * This file is part of MARS, the MAGIC Analysis and Reconstruction
7! * Software. It is distributed to you in the hope that it can be a useful
8! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
9! * It is distributed WITHOUT ANY WARRANTY.
10! *
11! * Permission to use, copy, modify and distribute this software and its
12! * documentation for any purpose is hereby granted without fee,
13! * provided that the above copyright notice appear in all copies and
14! * that both that copyright notice and this permission notice appear
15! * in supporting documentation. It is provided "as is" without express
16! * or implied warranty.
17! *
18!
19!
20! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
21! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
22!
23! Copyright: MAGIC Software Development, 2000-2006
24!
25!
26\* ======================================================================== */
27
28/////////////////////////////////////////////////////////////////////////////
29//
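// MRanForestCalc
//
// Task to train random forests (Train), which are written together with
// the rules of their input variables to the file FileName, and to apply
// them: in PreProcess the forests and rules are read back from that file
// and in Process the forest response for each event is stored in an
// MParameterD container (default name "RanForestOut").
//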
33////////////////////////////////////////////////////////////////////////////
34#include "MRanForestCalc.h"
35
36#include <TF1.h>
37#include <TFile.h>
38#include <TGraph.h>
39#include <TVector.h>
40
41#include "MHMatrix.h"
42
43#include "MLog.h"
44#include "MLogManip.h"
45
46#include "MData.h"
47#include "MDataArray.h"
48
49#include "MRanForest.h"
50#include "MParameters.h"
51
52#include "MParList.h"
53#include "MTaskList.h"
54#include "MEvtLoop.h"
55#include "MRanForestGrow.h"
56#include "MFillH.h"
57
58ClassImp(MRanForestCalc);
59
60using namespace std;
61
62const TString MRanForestCalc::gsDefName = "MRanForestCalc";
63const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
64
65const TString MRanForestCalc::gsNameOutput = "RanForestOut";
66
67MRanForestCalc::MRanForestCalc(const char *name, const char *title)
68 : fData(0), fRFOut(0), fTestMatrix(0),
69 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
70 fLastDataColumnHasWeights(kFALSE),
71 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
72{
73 fName = name ? name : gsDefName.Data();
74 fTitle = title ? title : gsDefTitle.Data();
75
76 // FIXME:
77 fNumTrees = 100; //100
78 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
79 fNdSize = 1; //1
80}
81
82MRanForestCalc::~MRanForestCalc()
83{
84 fEForests.Delete();
85}
86
87// --------------------------------------------------------------------------
88//
89// ver=0: One yes/no-classification forest is trained for each bin.
90// the yes/no classification is done using the grid
91// ver=1: One classification forest is trained. The last column contains a
92// value which is turned into a classifier by rf itself using the grid
93// ver=2: One classification forest is trained. The last column already contains
94// the classifier
95// ver=3: A regression forest is trained. The last column contains the
96// classifier
97//
98Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
99{
100 gLog.Separator("MRanForestCalc - Train");
101
102 if (!matrixtrain.GetColumns())
103 {
104 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
105 return kFALSE;
106 }
107
108 const Int_t ncols = matrixtrain.GetM().GetNcols();
109 const Int_t nrows = matrixtrain.GetM().GetNrows();
110 if (ncols<=0 || nrows <=0)
111 {
112 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
113 return kFALSE;
114 }
115
116 // rules (= combination of image par) to be used for energy estimation
117 TFile fileRF(fFileName, "recreate");
118 if (!fileRF.IsOpen())
119 {
120 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
121 return kFALSE;
122 }
123
124 // The number of columns which have to be removed for the training
125 // The last data column may contain weight which also have to be removed
126 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
127
128 const MDataArray &dcol = *matrixtrain.GetColumns();
129
130 // Make a copy of the rules for accessing the train-data
131 MDataArray usedrules;
132 for (Int_t i=0; i<ncols; i++)
133 if (i<ncols-nobs) // -3 is important!!!
134 usedrules.AddEntry(dcol[i].GetRule());
135 else
136 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
137
138 // In the case of regression store the rule to be regessed in the
139 // last entry of your rules
140 MDataArray rules(usedrules);
141 rules.AddEntry(ver<3?"Classification.fVal":dcol[ncols-1].GetRule().Data());
142
143 // prepare train-matrix finally used
144 TMatrix mat(matrixtrain.GetM());
145
146 // Resize it such that the obsolete columns are removed
147 mat.ResizeTo(nrows, ncols-nobs+1);
148
149 if (fDebug)
150 gLog.SetNullOutput(kTRUE);
151
152 // In the case one independant RF is trained for each bin (e.g.
153 // energy-bin) train all of them
154 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
155 for (Int_t ie=0; ie<nbins; ie++)
156 {
157 // In the case weights should be used initialize the
158 // corresponding array
159 TArrayF weights(nrows);
160 if (fLastDataColumnHasWeights)
161 for (Int_t j=0; j<nrows; j++)
162 {
163 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
164 if (j%250==0)
165 cout << weights[j] << " ";
166 }
167
168 // Setup the matrix such that the last comlumn contains
169 // the classifier or the regeression target value
170 switch (ver)
171 {
172 case 0: // Replace last column by a classification which is 1 in
173 // the case the event belongs to this bin, 0 otherwise
174 {
175 Int_t irows=0;
176 for (Int_t j=0; j<nrows; j++)
177 {
178 const Double_t value = matrixtrain.GetM()(j,ncols-1);
179 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
180
181 mat(j, ncols-nobs) = inside ? 1 : 0;
182
183 if (inside)
184 irows++;
185 }
186 if (irows==0)
187 *fLog << warn << "WARNING - Skipping";
188 else
189 *fLog << inf << "Training RF for";
190
191 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
192
193 if (irows==0)
194 continue;
195 }
196 break;
197
198 case 1: // Use last column as classifier or for regression
199 case 2:
200 case 3:
201 for (Int_t j=0; j<nrows; j++)
202 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
203 break;
204 }
205
206 MHMatrix matrix(mat, &rules, "MatrixTrain");
207
208 MParList plist;
209 MTaskList tlist;
210 plist.AddToList(&tlist);
211 plist.AddToList(&matrix);
212
213 MRanForest rf;
214 rf.SetNumTrees(fNumTrees);
215 rf.SetNumTry(fNumTry);
216 rf.SetNdSize(fNdSize);
217 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
218 if (ver==1)
219 rf.SetGrid(grid);
220 if (fLastDataColumnHasWeights)
221 rf.SetWeights(weights);
222
223 plist.AddToList(&rf);
224
225 MRanForestGrow rfgrow;
226 tlist.AddToList(&rfgrow);
227
228 MFillH fillh("MHRanForestGini");
229 tlist.AddToList(&fillh);
230
231 MEvtLoop evtloop;
232 evtloop.SetParList(&plist);
233 evtloop.SetDisplay(fDisplay);
234 evtloop.SetLogStream(fLog);
235
236 if (!evtloop.Eventloop())
237 return kFALSE;
238
239 if (fDebug)
240 gLog.SetNullOutput(kFALSE);
241
242 if (ver==0)
243 {
244 // Calculate bin center
245 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
246
247 // save whole forest
248 rf.SetUserVal(E);
249 rf.SetName(Form("%.10f", E));
250 }
251
252 rf.Write();
253 }
254
255 // save rules
256 usedrules.Write("rules");
257
258 return kTRUE;
259}
260
261Int_t MRanForestCalc::ReadForests(MParList &plist)
262{
263 TFile fileRF(fFileName, "read");
264 if (!fileRF.IsOpen())
265 {
266 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
267 return kFALSE;
268 }
269
270 fEForests.Delete();
271
272 TIter Next(fileRF.GetListOfKeys());
273 TObject *o=0;
274 while ((o=Next()))
275 {
276 MRanForest *forest=0;
277 fileRF.GetObject(o->GetName(), forest);
278 if (!forest)
279 continue;
280
281 forest->SetUserVal(atof(o->GetName()));
282
283 fEForests.Add(forest);
284 }
285
286 // Maybe fEForests[0].fRules could be used instead?
287 if (fData->Read("rules")<=0)
288 {
289 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
290 return kFALSE;
291 }
292
293 return kTRUE;
294}
295
296Int_t MRanForestCalc::PreProcess(MParList *plist)
297{
298 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
299 if (!fRFOut)
300 return kFALSE;
301
302 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
303 if (!fData)
304 return kFALSE;
305
306 if (!ReadForests(*plist))
307 {
308 *fLog << err << "Reading RFs failed... aborting." << endl;
309 return kFALSE;
310 }
311
312 *fLog << inf << "RF read from " << fFileName << endl;
313
314 if (fTestMatrix)
315 return kTRUE;
316
317 fData->Print();
318
319 if (!fData->PreProcess(plist))
320 {
321 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
322 return kFALSE;
323 }
324
325 return kTRUE;
326}
327
328Int_t MRanForestCalc::Process()
329{
330 TVector event;
331 if (fTestMatrix)
332 *fTestMatrix >> event;
333 else
334 *fData >> event;
335
336 // --------------- Single Tree RF -------------------
337 if (fEForests.GetEntriesFast()==1)
338 {
339 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
340 fRFOut->SetVal(rf->CalcHadroness(event));
341 fRFOut->SetReadyToSave();
342
343 return kTRUE;
344 }
345
346 // --------------- Multi Tree RF -------------------
347 static TF1 f1("f1", "gaus");
348
349 Double_t sume = 0;
350 Double_t sumh = 0;
351 Double_t maxh = 0;
352 Double_t maxe = 0;
353
354 Double_t max = -1e10;
355 Double_t min = 1e10;
356
357 TIter Next(&fEForests);
358 MRanForest *rf = 0;
359
360 TGraph g;
361 while ((rf=(MRanForest*)Next()))
362 {
363 const Double_t h = rf->CalcHadroness(event);
364 const Double_t e = rf->GetUserVal();
365
366 g.SetPoint(g.GetN(), e, h);
367
368 sume += e*h;
369 sumh += h;
370
371 if (h>maxh)
372 {
373 maxh = h;
374 maxe = e;
375 }
376 if (e>max)
377 max = e;
378 if (e<min)
379 min = e;
380 }
381
382 switch (fEstimationMode)
383 {
384 case kMean:
385 fRFOut->SetVal(pow(10, sume/sumh));
386 break;
387 case kMaximum:
388 fRFOut->SetVal(pow(10, maxe));
389 break;
390 case kFit:
391 f1.SetParameter(0, maxh);
392 f1.SetParameter(1, maxe);
393 f1.SetParameter(2, 0.125);
394 g.Fit(&f1, "Q0N");
395 fRFOut->SetVal(pow(10, f1.GetParameter(1)));
396 break;
397 }
398
399 fRFOut->SetReadyToSave();
400
401 return kTRUE;
402}
403
404void MRanForestCalc::Print(Option_t *o) const
405{
406 *fLog << all;
407 *fLog << GetDescriptor() << ":" << endl;
408 *fLog << " - Forest ";
409 switch (fEForests.GetEntries())
410 {
411 case 0: *fLog << "not yet initialized." << endl; break;
412 case 1: *fLog << "is a single tree forest." << endl; break;
413 default: *fLog << "is a multi tree forest." << endl; break;
414 }
415 /*
416 *fLog << " - Trees: " << fNumTrees << endl;
417 *fLog << " - Trys: " << fNumTry << endl;
418 *fLog << " - Node Size: " << fNdSize << endl;
419 *fLog << " - Node Size: " << fNdSize << endl;
420 */
421 *fLog << " - FileName: " << fFileName << endl;
422 *fLog << " - NameOutput: " << fNameOutput << endl;
423}
424
425// --------------------------------------------------------------------------
426//
427//
428Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
429{
430 Bool_t rc = kFALSE;
431 if (IsEnvDefined(env, prefix, "FileName", print))
432 {
433 rc = kTRUE;
434 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
435 }
436 if (IsEnvDefined(env, prefix, "Debug", print))
437 {
438 rc = kTRUE;
439 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
440 }
441 if (IsEnvDefined(env, prefix, "NameOutput", print))
442 {
443 rc = kTRUE;
444 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
445 }
446 if (IsEnvDefined(env, prefix, "EstimationMode", print))
447 {
448 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
449 txt = txt.Strip(TString::kBoth);
450 txt.ToLower();
451 if (txt==(TString)"mean")
452 fEstimationMode = kMean;
453 if (txt==(TString)"maximum")
454 fEstimationMode = kMaximum;
455 if (txt==(TString)"fit")
456 fEstimationMode = kFit;
457 rc = kTRUE;
458 }
459 return rc;
460}
Note: See TracBrowser for help on using the repository browser.