source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 7764

Last change on this file since 7764 was 7710, checked in by tbretz, 18 years ago
*** empty log message ***
File size: 12.7 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
19! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
20!
21! Copyright: MAGIC Software Development, 2000-2005
22!
23!
24\* ======================================================================== */
25
26/////////////////////////////////////////////////////////////////////////////
27//
28// MRanForestCalc
29//
30//
31////////////////////////////////////////////////////////////////////////////
32#include "MRanForestCalc.h"
33
34#include <TF1.h>
35#include <TGraph.h>
36#include <TVector.h>
37
38#include "MHMatrix.h"
39
40#include "MLog.h"
41#include "MLogManip.h"
42
43#include "MData.h"
44#include "MDataArray.h"
45
46#include "MRanForest.h"
47#include "MParameters.h"
48
49#include "MParList.h"
50#include "MTaskList.h"
51#include "MEvtLoop.h"
52#include "MRanForestGrow.h"
53#include "MFillH.h"
54
55ClassImp(MRanForestCalc);
56
57using namespace std;
58
59const TString MRanForestCalc::gsDefName = "MRanForestCalc";
60const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
61
62const TString MRanForestCalc::gsNameOutput = "RanForestOut";
63
64MRanForestCalc::MRanForestCalc(const char *name, const char *title)
65 : fData(0), fRFOut(0), fTestMatrix(0),
66 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
67 fLastDataColumnHasWeights(kFALSE),
68 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
69{
70 fName = name ? name : gsDefName.Data();
71 fTitle = title ? title : gsDefTitle.Data();
72
73 // FIXME:
74 fNumTrees = 100; //100
75 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
76 fNdSize = 1; //1
77}
78
79MRanForestCalc::~MRanForestCalc()
80{
81 fEForests.Delete();
82}
83
84// --------------------------------------------------------------------------
85//
86// ver=0: One yes/no-classification forest is trained for each bin.
87// the yes/no classification is done using the grid
88// ver=1: One classification forest is trained. The last column contains a
89// value which is turned into a classifier by rf itself using the grid
90// ver=2: One classification forest is trained. The last column already contains
91// the classifier
92// ver=3: A regression forest is trained. The last column contains the
93// classifier
94//
95Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
96{
97 gLog.Separator("MRanForestCalc - Train");
98
99 if (!matrixtrain.GetColumns())
100 {
101 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
102 return kFALSE;
103 }
104
105 const Int_t ncols = matrixtrain.GetM().GetNcols();
106 const Int_t nrows = matrixtrain.GetM().GetNrows();
107 if (ncols<=0 || nrows <=0)
108 {
109 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
110 return kFALSE;
111 }
112
113 // rules (= combination of image par) to be used for energy estimation
114 TFile fileRF(fFileName, "recreate");
115 if (!fileRF.IsOpen())
116 {
117 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
118 return kFALSE;
119 }
120
121 // The number of columns which have to be removed for the training
122 // The last data column may contain weight which also have to be removed
123 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
124
125 const MDataArray &dcol = *matrixtrain.GetColumns();
126
127 // Make a copy of the rules for accessing the train-data
128 MDataArray usedrules;
129 for (Int_t i=0; i<ncols; i++)
130 if (i<ncols-nobs) // -3 is important!!!
131 usedrules.AddEntry(dcol[i].GetRule());
132 else
133 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
134
135 // In the case of regression store the rule to be regessed in the
136 // last entry of your rules
137 MDataArray rules(usedrules);
138 rules.AddEntry(ver<3?"Classification":dcol[ncols-1].GetRule());
139
140 // prepare train-matrix finally used
141 TMatrix mat(matrixtrain.GetM());
142
143 // Resize it such that the obsolete columns are removed
144 mat.ResizeTo(nrows, ncols-nobs+1);
145
146 if (fDebug)
147 gLog.SetNullOutput(kTRUE);
148
149 // In the case one independant RF is trained for each bin (e.g.
150 // energy-bin) train all of them
151 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
152 for (Int_t ie=0; ie<nbins; ie++)
153 {
154 // In the case weights should be used initialize the
155 // corresponding array
156 TArrayF weights(nrows);
157 if (fLastDataColumnHasWeights)
158 for (Int_t j=0; j<nrows; j++)
159 {
160 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
161 if (j%100==0)
162 cout << weights[j] << " ";
163 }
164
165 // Setup the matrix such that the last comlumn contains
166 // the classifier or the regeression target value
167 switch (ver)
168 {
169 case 0: // Replace last column by a classification which is 1 in
170 // the case the event belongs to this bin, 0 otherwise
171 {
172 Int_t irows=0;
173 for (Int_t j=0; j<nrows; j++)
174 {
175 const Double_t value = matrixtrain.GetM()(j,ncols-1);
176 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
177
178 mat(j, ncols-nobs) = inside ? 1 : 0;
179
180 if (inside)
181 irows++;
182 }
183 if (irows==0)
184 *fLog << warn << "WARNING - Skipping";
185 else
186 *fLog << inf << "Training RF for";
187
188 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
189
190 if (irows==0)
191 continue;
192 }
193 break;
194
195 case 1: // Use last column as classifier or for regression
196 case 2:
197 case 3:
198 for (Int_t j=0; j<nrows; j++)
199 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
200 break;
201 }
202
203 MHMatrix matrix(mat, &rules, "MatrixTrain");
204
205 MParList plist;
206 MTaskList tlist;
207 plist.AddToList(&tlist);
208 plist.AddToList(&matrix);
209
210 MRanForest rf;
211 rf.SetNumTrees(fNumTrees);
212 rf.SetNumTry(fNumTry);
213 rf.SetNdSize(fNdSize);
214 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
215 if (ver==1)
216 rf.SetGrid(grid);
217 if (fLastDataColumnHasWeights)
218 rf.SetWeights(weights);
219
220 plist.AddToList(&rf);
221
222 MRanForestGrow rfgrow;
223 tlist.AddToList(&rfgrow);
224
225 MFillH fillh("MHRanForestGini");
226 tlist.AddToList(&fillh);
227
228 MEvtLoop evtloop;
229 evtloop.SetParList(&plist);
230 evtloop.SetDisplay(fDisplay);
231 evtloop.SetLogStream(fLog);
232
233 if (!evtloop.Eventloop())
234 return kFALSE;
235
236 if (fDebug)
237 gLog.SetNullOutput(kFALSE);
238
239 if (ver==0)
240 {
241 // Calculate bin center
242 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
243
244 // save whole forest
245 rf.SetUserVal(E);
246 rf.SetName(Form("%.10f", E));
247 }
248
249 rf.Write();
250 }
251
252 // save rules
253 usedrules.Write("rules");
254
255 return kTRUE;
256}
257
258Int_t MRanForestCalc::ReadForests(MParList &plist)
259{
260 TFile fileRF(fFileName, "read");
261 if (!fileRF.IsOpen())
262 {
263 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
264 return kFALSE;
265 }
266
267 fEForests.Delete();
268
269 TIter Next(fileRF.GetListOfKeys());
270 TObject *o=0;
271 while ((o=Next()))
272 {
273 MRanForest *forest=0;
274 fileRF.GetObject(o->GetName(), forest);
275 if (!forest)
276 continue;
277
278 forest->SetUserVal(atof(o->GetName()));
279
280 fEForests.Add(forest);
281 }
282
283 // Maybe fEForests[0].fRules could be used instead?
284 if (fData->Read("rules")<=0)
285 {
286 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
287 return kFALSE;
288 }
289
290 return kTRUE;
291}
292
293Int_t MRanForestCalc::PreProcess(MParList *plist)
294{
295 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
296 if (!fRFOut)
297 return kFALSE;
298
299 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
300 if (!fData)
301 return kFALSE;
302
303 if (!ReadForests(*plist))
304 {
305 *fLog << err << "Reading RFs failed... aborting." << endl;
306 return kFALSE;
307 }
308
309 *fLog << inf << "RF read from " << fFileName << endl;
310
311 if (fTestMatrix)
312 return kTRUE;
313
314 fData->Print();
315
316 if (!fData->PreProcess(plist))
317 {
318 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
319 return kFALSE;
320 }
321
322 return kTRUE;
323}
324
325Int_t MRanForestCalc::Process()
326{
327 TVector event;
328 if (fTestMatrix)
329 *fTestMatrix >> event;
330 else
331 *fData >> event;
332
333 // --------------- Single Tree RF -------------------
334 if (fEForests.GetEntriesFast()==1)
335 {
336 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
337 fRFOut->SetVal(rf->CalcHadroness(event));
338 fRFOut->SetReadyToSave();
339
340 return kTRUE;
341 }
342
343 // --------------- Multi Tree RF -------------------
344 static TF1 f1("f1", "gaus");
345
346 Double_t sume = 0;
347 Double_t sumh = 0;
348 Double_t maxh = 0;
349 Double_t maxe = 0;
350
351 Double_t max = -1e10;
352 Double_t min = 1e10;
353
354 TIter Next(&fEForests);
355 MRanForest *rf = 0;
356
357 TGraph g;
358 while ((rf=(MRanForest*)Next()))
359 {
360 const Double_t h = rf->CalcHadroness(event);
361 const Double_t e = rf->GetUserVal();
362
363 g.SetPoint(g.GetN(), e, h);
364
365 sume += e*h;
366 sumh += h;
367
368 if (h>maxh)
369 {
370 maxh = h;
371 maxe = e;
372 }
373 if (e>max)
374 max = e;
375 if (e<min)
376 min = e;
377 }
378
379 switch (fEstimationMode)
380 {
381 case kMean:
382 fRFOut->SetVal(pow(10, sume/sumh));
383 break;
384 case kMaximum:
385 fRFOut->SetVal(pow(10, maxe));
386 break;
387 case kFit:
388 f1.SetParameter(0, maxh);
389 f1.SetParameter(1, maxe);
390 f1.SetParameter(2, 0.125);
391 g.Fit(&f1, "Q0N");
392 fRFOut->SetVal(pow(10, f1.GetParameter(1)));
393 break;
394 }
395
396 fRFOut->SetReadyToSave();
397
398 return kTRUE;
399}
400
401void MRanForestCalc::Print(Option_t *o) const
402{
403 *fLog << all;
404 *fLog << GetDescriptor() << ":" << endl;
405 *fLog << " - Forest ";
406 switch (fEForests.GetEntries())
407 {
408 case 0: *fLog << "not yet initialized." << endl; break;
409 case 1: *fLog << "is a single tree forest." << endl; break;
410 default: *fLog << "is a multi tree forest." << endl; break;
411 }
412 /*
413 *fLog << " - Trees: " << fNumTrees << endl;
414 *fLog << " - Trys: " << fNumTry << endl;
415 *fLog << " - Node Size: " << fNdSize << endl;
416 *fLog << " - Node Size: " << fNdSize << endl;
417 */
418 *fLog << " - FileName: " << fFileName << endl;
419 *fLog << " - NameOutput: " << fNameOutput << endl;
420}
421
422// --------------------------------------------------------------------------
423//
424//
425Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
426{
427 Bool_t rc = kFALSE;
428 if (IsEnvDefined(env, prefix, "FileName", print))
429 {
430 rc = kTRUE;
431 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
432 }
433 if (IsEnvDefined(env, prefix, "Debug", print))
434 {
435 rc = kTRUE;
436 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
437 }
438 if (IsEnvDefined(env, prefix, "NameOutput", print))
439 {
440 rc = kTRUE;
441 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
442 }
443 if (IsEnvDefined(env, prefix, "EstimationMode", print))
444 {
445 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
446 txt = txt.Strip(TString::kBoth);
447 txt.ToLower();
448 if (txt==(TString)"mean")
449 fEstimationMode = kMean;
450 if (txt==(TString)"maximum")
451 fEstimationMode = kMaximum;
452 if (txt==(TString)"fit")
453 fEstimationMode = kFit;
454 rc = kTRUE;
455 }
456 return rc;
457}
Note: See TracBrowser for help on using the repository browser.