source: trunk/MagicSoft/Mars/mranforest/MRanForestCalc.cc@ 7450

Last change on this file since 7450 was 7425, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 11.6 KB
Line 
1/* ======================================================================== *\
2!
3! *
4! * This file is part of MARS, the MAGIC Analysis and Reconstruction
5! * Software. It is distributed to you in the hope that it can be a useful
6! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
7! * It is distributed WITHOUT ANY WARRANTY.
8! *
9! * Permission to use, copy, modify and distribute this software and its
10! * documentation for any purpose is hereby granted without fee,
11! * provided that the above copyright notice appear in all copies and
12! * that both that copyright notice and this permission notice appear
13! * in supporting documentation. It is provided "as is" without express
14! * or implied warranty.
15! *
16!
17!
18! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
19! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
20!
21! Copyright: MAGIC Software Development, 2000-2005
22!
23!
24\* ======================================================================== */
25
26/////////////////////////////////////////////////////////////////////////////
27//
28// MRanForestCalc
29//
30//
31////////////////////////////////////////////////////////////////////////////
32#include "MRanForestCalc.h"
33
34#include <TVector.h>
35
36#include "MHMatrix.h"
37
38#include "MLog.h"
39#include "MLogManip.h"
40
41#include "MData.h"
42#include "MDataArray.h"
43
44#include "MRanForest.h"
45#include "MParameters.h"
46
47#include "MParList.h"
48#include "MTaskList.h"
49#include "MEvtLoop.h"
50#include "MRanForestGrow.h"
51#include "MFillH.h"
52
53ClassImp(MRanForestCalc);
54
55using namespace std;
56
57const TString MRanForestCalc::gsDefName = "MRanForestCalc";
58const TString MRanForestCalc::gsDefTitle = "RF for energy estimation";
59
60const TString MRanForestCalc::gsNameOutput = "RanForestOut";
61
62MRanForestCalc::MRanForestCalc(const char *name, const char *title)
63 : fData(0), fRFOut(0), fTestMatrix(0),
64 fNumTrees(-1), fNumTry(-1), fNdSize(-1), fNumObsoleteVariables(1),
65 fNameOutput(gsNameOutput), fDebug(kFALSE), fEstimationMode(kMean)
66{
67 fName = name ? name : gsDefName.Data();
68 fTitle = title ? title : gsDefTitle.Data();
69
70 // FIXME:
71 fNumTrees = 100; //100
72 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
73 fNdSize = 1; //1
74}
75
76MRanForestCalc::~MRanForestCalc()
77{
78 fEForests.Delete();
79}
80
81// --------------------------------------------------------------------------
82//
83// ver=0: One yes/no-classification forest is trained for each bin.
84// the yes/no classification is done using the grid
85// ver=1: One classification forest is trained. The last column contains a
86// value which is turned into a classifier by rf itself using the grid
87// ver=2: One classification forest is trained. The last column already contains
88// the classifier
89// ver=3: A regression forest is trained. The last column contains the
90// classifier
91//
92Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
93{
94 gLog.Separator("MRanForestCalc - Train");
95
96 if (!matrixtrain.GetColumns())
97 {
98 *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
99 return kFALSE;
100 }
101
102 const Int_t ncols = matrixtrain.GetM().GetNcols();
103 const Int_t nrows = matrixtrain.GetM().GetNrows();
104 if (ncols<=0 || nrows <=0)
105 {
106 *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
107 return kFALSE;
108 }
109
110 // rules (= combination of image par) to be used for energy estimation
111 TFile fileRF(fFileName, "recreate");
112 if (!fileRF.IsOpen())
113 {
114 *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
115 return kFALSE;
116 }
117
118 const Int_t nobs = fNumObsoleteVariables; // Number of obsolete columns
119
120 const MDataArray &dcol = *matrixtrain.GetColumns();
121
122 MDataArray usedrules;
123 for (Int_t i=0; i<ncols; i++)
124 if (i<ncols-nobs) // -3 is important!!!
125 usedrules.AddEntry(dcol[i].GetRule());
126 else
127 *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;
128
129 MDataArray rules(usedrules);
130 rules.AddEntry(ver<3?"Classification":dcol[ncols-1].GetRule());
131
132 // prepare matrix for current energy bin
133 TMatrix mat(matrixtrain.GetM());
134
135 // last column must be removed (true energy col.)
136 mat.ResizeTo(nrows, ncols-nobs+1);
137
138 if (fDebug)
139 gLog.SetNullOutput(kTRUE);
140
141 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
142 for (Int_t ie=0; ie<nbins; ie++)
143 {
144 switch (ver)
145 {
146 case 0: // Replace last column by a classification
147 {
148 Int_t irows=0;
149 for (Int_t j=0; j<nrows; j++)
150 {
151 const Double_t energy = matrixtrain.GetM()(j,ncols-1);
152 const Bool_t inside = energy>grid[ie] && energy<=grid[ie+1];
153
154 mat(j, ncols-nobs) = inside ? 1 : 0;
155
156 if (inside)
157 irows++;
158 }
159 if (irows==0)
160 *fLog << warn << "WARNING - Skipping";
161 else
162 *fLog << inf << "Training RF for";
163
164 *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
165
166 if (irows==0)
167 continue;
168 }
169 break;
170
171 case 1: // Use last column as classifier or for regression
172 case 2:
173 case 3:
174 for (Int_t j=0; j<nrows; j++)
175 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
176 break;
177 }
178
179 MHMatrix matrix(mat, &rules, "MatrixTrain");
180
181 MParList plist;
182 MTaskList tlist;
183 plist.AddToList(&tlist);
184 plist.AddToList(&matrix);
185
186 MRanForest rf;
187 rf.SetNumTrees(fNumTrees);
188 rf.SetNumTry(fNumTry);
189 rf.SetNdSize(fNdSize);
190 rf.SetClassify(ver<3 ? kTRUE : kFALSE);
191 if (ver==1)
192 rf.SetGrid(grid);
193
194 plist.AddToList(&rf);
195
196 MRanForestGrow rfgrow;
197 tlist.AddToList(&rfgrow);
198
199 MFillH fillh("MHRanForestGini");
200 tlist.AddToList(&fillh);
201
202 MEvtLoop evtloop;
203 evtloop.SetParList(&plist);
204 evtloop.SetDisplay(fDisplay);
205 evtloop.SetLogStream(fLog);
206
207 if (!evtloop.Eventloop())
208 return kFALSE;
209
210 if (fDebug)
211 gLog.SetNullOutput(kFALSE);
212
213 if (ver==0)
214 {
215 // Calculate bin center
216 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
217
218 // save whole forest
219 rf.SetUserVal(E);
220 rf.SetName(Form("%.10f", E));
221 }
222
223 rf.Write();
224 }
225
226 // save rules
227 usedrules.Write("rules");
228
229 return kTRUE;
230}
231
232Int_t MRanForestCalc::ReadForests(MParList &plist)
233{
234 TFile fileRF(fFileName, "read");
235 if (!fileRF.IsOpen())
236 {
237 *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
238 return kFALSE;
239 }
240
241 fEForests.Delete();
242
243 TIter Next(fileRF.GetListOfKeys());
244 TObject *o=0;
245 while ((o=Next()))
246 {
247 MRanForest *forest=0;
248 fileRF.GetObject(o->GetName(), forest);
249 if (!forest)
250 continue;
251
252 forest->SetUserVal(atof(o->GetName()));
253
254 fEForests.Add(forest);
255 }
256
257 // Maybe fEForests[0].fRules yould be used instead?
258
259 if (fData->Read("rules")<=0)
260 {
261 *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
262 return kFALSE;
263 }
264
265 return kTRUE;
266}
267
268Int_t MRanForestCalc::PreProcess(MParList *plist)
269{
270 fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
271 if (!fRFOut)
272 return kFALSE;
273
274 fData = (MDataArray*)plist->FindCreateObj("MDataArray");
275 if (!fData)
276 return kFALSE;
277
278 if (!ReadForests(*plist))
279 {
280 *fLog << err << "Reading RFs failed... aborting." << endl;
281 return kFALSE;
282 }
283
284 *fLog << inf << "RF read from " << fFileName << endl;
285
286 if (fTestMatrix)
287 return kTRUE;
288
289 fData->Print();
290
291 if (!fData->PreProcess(plist))
292 {
293 *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
294 return kFALSE;
295 }
296
297 return kTRUE;
298}
299
300#include <TGraph.h>
301#include <TF1.h>
302Int_t MRanForestCalc::Process()
303{
304 TVector event;
305 if (fTestMatrix)
306 *fTestMatrix >> event;
307 else
308 *fData >> event;
309
310 // --------------- Single Tree RF -------------------
311 if (fEForests.GetEntriesFast()==1)
312 {
313 MRanForest *rf = static_cast<MRanForest*>(fEForests.UncheckedAt(0));
314 fRFOut->SetVal(rf->CalcHadroness(event));
315 fRFOut->SetReadyToSave();
316
317 return kTRUE;
318 }
319
320 // --------------- Multi Tree RF -------------------
321 static TF1 f1("f1", "gaus");
322
323 Double_t sume = 0;
324 Double_t sumh = 0;
325 Double_t maxh = 0;
326 Double_t maxe = 0;
327
328 Double_t max = -1e10;
329 Double_t min = 1e10;
330
331 TIter Next(&fEForests);
332 MRanForest *rf = 0;
333
334 TGraph g;
335 while ((rf=(MRanForest*)Next()))
336 {
337 const Double_t h = rf->CalcHadroness(event);
338 const Double_t e = rf->GetUserVal();
339
340 g.SetPoint(g.GetN(), e, h);
341
342 sume += e*h;
343 sumh += h;
344
345 if (h>maxh)
346 {
347 maxh = h;
348 maxe = e;
349 }
350 if (e>max)
351 max = e;
352 if (e<min)
353 min = e;
354 }
355
356 switch (fEstimationMode)
357 {
358 case kMean:
359 fRFOut->SetVal(pow(10, sume/sumh));
360 break;
361 case kMaximum:
362 fRFOut->SetVal(pow(10, maxe));
363 break;
364 case kFit:
365 f1.SetParameter(0, maxh);
366 f1.SetParameter(1, maxe);
367 f1.SetParameter(2, 0.125);
368 g.Fit(&f1, "Q0N");
369 fRFOut->SetVal(pow(10, f1.GetParameter(1)));
370 break;
371 }
372
373 fRFOut->SetReadyToSave();
374
375 return kTRUE;
376}
377
378void MRanForestCalc::Print(Option_t *o) const
379{
380 *fLog << all;
381 *fLog << GetDescriptor() << ":" << endl;
382 *fLog << " - Forest ";
383 switch (fEForests.GetEntries())
384 {
385 case 0: *fLog << "not yet initialized." << endl; break;
386 case 1: *fLog << "is a single tree forest." << endl; break;
387 default: *fLog << "is a multi tree forest." << endl; break;
388 }
389 /*
390 *fLog << " - Trees: " << fNumTrees << endl;
391 *fLog << " - Trys: " << fNumTry << endl;
392 *fLog << " - Node Size: " << fNdSize << endl;
393 *fLog << " - Node Size: " << fNdSize << endl;
394 */
395 *fLog << " - FileName: " << fFileName << endl;
396 *fLog << " - NameOutput: " << fNameOutput << endl;
397}
398
399// --------------------------------------------------------------------------
400//
401//
402Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
403{
404 Bool_t rc = kFALSE;
405 if (IsEnvDefined(env, prefix, "FileName", print))
406 {
407 rc = kTRUE;
408 SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
409 }
410 if (IsEnvDefined(env, prefix, "Debug", print))
411 {
412 rc = kTRUE;
413 SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
414 }
415 if (IsEnvDefined(env, prefix, "NameOutput", print))
416 {
417 rc = kTRUE;
418 SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
419 }
420 if (IsEnvDefined(env, prefix, "EstimationMode", print))
421 {
422 TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
423 txt = txt.Strip(TString::kBoth);
424 txt.ToLower();
425 if (txt==(TString)"mean")
426 fEstimationMode = kMean;
427 if (txt==(TString)"maximum")
428 fEstimationMode = kMaximum;
429 if (txt==(TString)"fit")
430 fEstimationMode = kFit;
431 rc = kTRUE;
432 }
433 return rc;
434}
Note: See TracBrowser for help on using the repository browser.