source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 7814

Last change on this file since 7814 was 7784, checked in by tbretz, 18 years ago
*** empty log message ***
File size: 12.7 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
19! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
20!
21! Copyright: MAGIC Software Development, 2000-2005
22!
23!
24\* ======================================================================== */
25
26/////////////////////////////////////////////////////////////////////////////
27//
28// MRanForestCalc
29//
30//
31////////////////////////////////////////////////////////////////////////////
32#include "MRanForestCalc.h"
33
34#include <TF1.h>
35#include <TFile.h>
36#include <TGraph.h>
37#include <TVector.h>
38
39#include "MHMatrix.h"
40
41#include "MLog.h"
42#include "MLogManip.h"
43
44#include "MData.h"
45#include "MDataArray.h"
46
47#include "MRanForest.h"
48#include "MParameters.h"
49
50#include "MParList.h"
51#include "MTaskList.h"
52#include "MEvtLoop.h"
53#include "MRanForestGrow.h"
54#include "MFillH.h"
55
// ROOT ClassImp macro invoked for the MRanForestCalc class
56ClassImp(MRanForestCalc);
57
58using namespace std;
59
// Default name and title; the constructor falls back to these when it is
// called with a null name or title
60const TString MRanForestCalc::gsDefName = "MRanForestCalc";
61const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
62
// Default name of the output container; the constructor uses it as the
// initial value of fNameOutput
63const TString MRanForestCalc::gsNameOutput = "RanForestOut";
64
65MRanForestCalc::MRanForestCalc(const char *name, const char *title)
66 : fData(0), fRFOut(0), fTestMatrix(0),
67 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
68 fLastDataColumnHasWeights(kFALSE),
69 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
70{
71 fName = name ? name : gsDefName.Data();
72 fTitle = title ? title : gsDefTitle.Data();
73
74 // FIXME:
75 fNumTrees = 100; //100
76 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
77 fNdSize = 1; //1
78}
79
80MRanForestCalc::~MRanForestCalc()
81{
82 fEForests.Delete();
83}
84
85// --------------------------------------------------------------------------
86//
87// ver=0: One yes/no-classification forest is trained for each bin.
88// the yes/no classification is done using the grid
89// ver=1: One classification forest is trained. The last column contains a
90// value which is turned into a classifier by rf itself using the grid
91// ver=2: One classification forest is trained. The last column already contains
92// the classifier
93// ver=3: A regression forest is trained. The last column contains the
94// classifier
95//
96Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
97{
98 gLog.Separator("MRanForestCalc - Train");
99
100 if (!matrixtrain.GetColumns())
101 {
102 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
103 return kFALSE;
104 }
105
106 const Int_t ncols = matrixtrain.GetM().GetNcols();
107 const Int_t nrows = matrixtrain.GetM().GetNrows();
108 if (ncols<=0 || nrows <=0)
109 {
110 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
111 return kFALSE;
112 }
113
114 // rules (= combination of image par) to be used for energy estimation
115 TFile fileRF(fFileName, "recreate");
116 if (!fileRF.IsOpen())
117 {
118 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
119 return kFALSE;
120 }
121
122 // The number of columns which have to be removed for the training
123 // The last data column may contain weight which also have to be removed
124 const Int_t nobs = fNumObsoleteVariables + (fLastDataColumnHasWeights?1:0); // Number of obsolete columns
125
126 const MDataArray &dcol = *matrixtrain.GetColumns();
127
128 // Make a copy of the rules for accessing the train-data
129 MDataArray usedrules;
130 for (Int_t i=0; i<ncols; i++)
131 if (i<ncols-nobs) // -3 is important!!!
132 usedrules.AddEntry(dcol[i].GetRule());
133 else
134 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
135
136 // In the case of regression store the rule to be regessed in the
137 // last entry of your rules
138 MDataArray rules(usedrules);
139 rules.AddEntry(ver<3?"Classification":dcol[ncols-1].GetRule().Data());
140
141 // prepare train-matrix finally used
142 TMatrix mat(matrixtrain.GetM());
143
144 // Resize it such that the obsolete columns are removed
145 mat.ResizeTo(nrows, ncols-nobs+1);
146
147 if (fDebug)
148 gLog.SetNullOutput(kTRUE);
149
150 // In the case one independant RF is trained for each bin (e.g.
151 // energy-bin) train all of them
152 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
153 for (Int_t ie=0; ie<nbins; ie++)
154 {
155 // In the case weights should be used initialize the
156 // corresponding array
157 TArrayF weights(nrows);
158 if (fLastDataColumnHasWeights)
159 for (Int_t j=0; j<nrows; j++)
160 {
161 weights[j] = matrixtrain.GetM()(j, ncols-nobs);
162 if (j%100==0)
163 cout << weights[j] << " ";
164 }
165
166 // Setup the matrix such that the last comlumn contains
167 // the classifier or the regeression target value
168 switch (ver)
169 {
170 case 0: // Replace last column by a classification which is 1 in
171 // the case the event belongs to this bin, 0 otherwise
172 {
173 Int_t irows=0;
174 for (Int_t j=0; j<nrows; j++)
175 {
176 const Double_t value = matrixtrain.GetM()(j,ncols-1);
177 const Bool_t inside = value>grid[ie] && value<=grid[ie+1];
178
179 mat(j, ncols-nobs) = inside ? 1 : 0;
180
181 if (inside)
182 irows++;
183 }
184 if (irows==0)
185 *fLog << warn << "WARNING - Skipping";
186 else
187 *fLog << inf << "Training RF for";
188
189 *fLog << " bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
190
191 if (irows==0)
192 continue;
193 }
194 break;
195
196 case 1: // Use last column as classifier or for regression
197 case 2:
198 case 3:
199 for (Int_t j=0; j<nrows; j++)
200 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
201 break;
202 }
203
204 MHMatrix matrix(mat, &rules, "MatrixTrain");
205
206 MParList plist;
207 MTaskList tlist;
208 plist.AddToList(&tlist);
209 plist.AddToList(&matrix);
210
211 MRanForest rf;
212 rf.SetNumTrees(fNumTrees);
213 rf.SetNumTry(fNumTry);
214 rf.SetNdSize(fNdSize);
215 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
216 if (ver==1)
217 rf.SetGrid(grid);
218 if (fLastDataColumnHasWeights)
219 rf.SetWeights(weights);
220
221 plist.AddToList(&rf);
222
223 MRanForestGrow rfgrow;
224 tlist.AddToList(&rfgrow);
225
226 MFillH fillh("MHRanForestGini");
227 tlist.AddToList(&fillh);
228
229 MEvtLoop evtloop;
230 evtloop.SetParList(&plist);
231 evtloop.SetDisplay(fDisplay);
232 evtloop.SetLogStream(fLog);
233
234 if (!evtloop.Eventloop())
235 return kFALSE;
236
237 if (fDebug)
238 gLog.SetNullOutput(kFALSE);
239
240 if (ver==0)
241 {
242 // Calculate bin center
243 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
244
245 // save whole forest
246 rf.SetUserVal(E);
247 rf.SetName(Form("%.10f", E));
248 }
249
250 rf.Write();
251 }
252
253 // save rules
254 usedrules.Write("rules");
255
256 return kTRUE;
257}
258
259Int_t MRanForestCalc::ReadForests(MParList &plist)
260{
261 TFile fileRF(fFileName, "read");
262 if (!fileRF.IsOpen())
263 {
264 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
265 return kFALSE;
266 }
267
268 fEForests.Delete();
269
270 TIter Next(fileRF.GetListOfKeys());
271 TObject *o=0;
272 while ((o=Next()))
273 {
274 MRanForest *forest=0;
275 fileRF.GetObject(o->GetName(), forest);
276 if (!forest)
277 continue;
278
279 forest->SetUserVal(atof(o->GetName()));
280
281 fEForests.Add(forest);
282 }
283
284 // Maybe fEForests[0].fRules could be used instead?
285 if (fData->Read("rules")<=0)
286 {
287 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
288 return kFALSE;
289 }
290
291 return kTRUE;
292}
293
294Int_t MRanForestCalc::PreProcess(MParList *plist)
295{
296 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
297 if (!fRFOut)
298 return kFALSE;
299
300 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
301 if (!fData)
302 return kFALSE;
303
304 if (!ReadForests(*plist))
305 {
306 *fLog << err << "Reading RFs failed... aborting." << endl;
307 return kFALSE;
308 }
309
310 *fLog << inf << "RF read from " << fFileName << endl;
311
312 if (fTestMatrix)
313 return kTRUE;
314
315 fData->Print();
316
317 if (!fData->PreProcess(plist))
318 {
319 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
320 return kFALSE;
321 }
322
323 return kTRUE;
324}
325
326Int_t MRanForestCalc::Process()
327{
328 TVector event;
329 if (fTestMatrix)
330 *fTestMatrix >> event;
331 else
332 *fData >> event;
333
334 // --------------- Single Tree RF -------------------
335 if (fEForests.GetEntriesFast()==1)
336 {
337 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
338 fRFOut->SetVal(rf->CalcHadroness(event));
339 fRFOut->SetReadyToSave();
340
341 return kTRUE;
342 }
343
344 // --------------- Multi Tree RF -------------------
345 static TF1 f1("f1", "gaus");
346
347 Double_t sume = 0;
348 Double_t sumh = 0;
349 Double_t maxh = 0;
350 Double_t maxe = 0;
351
352 Double_t max = -1e10;
353 Double_t min = 1e10;
354
355 TIter Next(&fEForests);
356 MRanForest *rf = 0;
357
358 TGraph g;
359 while ((rf=(MRanForest*)Next()))
360 {
361 const Double_t h = rf->CalcHadroness(event);
362 const Double_t e = rf->GetUserVal();
363
364 g.SetPoint(g.GetN(), e, h);
365
366 sume += e*h;
367 sumh += h;
368
369 if (h>maxh)
370 {
371 maxh = h;
372 maxe = e;
373 }
374 if (e>max)
375 max = e;
376 if (e<min)
377 min = e;
378 }
379
380 switch (fEstimationMode)
381 {
382 case kMean:
383 fRFOut->SetVal(pow(10, sume/sumh));
384 break;
385 case kMaximum:
386 fRFOut->SetVal(pow(10, maxe));
387 break;
388 case kFit:
389 f1.SetParameter(0, maxh);
390 f1.SetParameter(1, maxe);
391 f1.SetParameter(2, 0.125);
392 g.Fit(&f1, "Q0N");
393 fRFOut->SetVal(pow(10, f1.GetParameter(1)));
394 break;
395 }
396
397 fRFOut->SetReadyToSave();
398
399 return kTRUE;
400}
401
402void MRanForestCalc::Print(Option_t *o) const
403{
404 *fLog << all;
405 *fLog << GetDescriptor() << ":" << endl;
406 *fLog << " - Forest ";
407 switch (fEForests.GetEntries())
408 {
409 case 0: *fLog << "not yet initialized." << endl; break;
410 case 1: *fLog << "is a single tree forest." << endl; break;
411 default: *fLog << "is a multi tree forest." << endl; break;
412 }
413 /*
414 *fLog << " - Trees: " << fNumTrees << endl;
415 *fLog << " - Trys: " << fNumTry << endl;
416 *fLog << " - Node Size: " << fNdSize << endl;
417 *fLog << " - Node Size: " << fNdSize << endl;
418 */
419 *fLog << " - FileName: " << fFileName << endl;
420 *fLog << " - NameOutput: " << fNameOutput << endl;
421}
422
423// --------------------------------------------------------------------------
424//
425//
426Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
427{
428 Bool_t rc = kFALSE;
429 if (IsEnvDefined(env, prefix, "FileName", print))
430 {
431 rc = kTRUE;
432 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
433 }
434 if (IsEnvDefined(env, prefix, "Debug", print))
435 {
436 rc = kTRUE;
437 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
438 }
439 if (IsEnvDefined(env, prefix, "NameOutput", print))
440 {
441 rc = kTRUE;
442 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
443 }
444 if (IsEnvDefined(env, prefix, "EstimationMode", print))
445 {
446 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
447 txt = txt.Strip(TString::kBoth);
448 txt.ToLower();
449 if (txt==(TString)"mean")
450 fEstimationMode = kMean;
451 if (txt==(TString)"maximum")
452 fEstimationMode = kMaximum;
453 if (txt==(TString)"fit")
454 fEstimationMode = kFit;
455 rc = kTRUE;
456 }
457 return rc;
458}
Note: See TracBrowser for help on using the repository browser.