source: branches/MarsMoreSimulationTruth/mjtrain/MJTrainCuts.h@ 20115

Last change on this file since 20115 was 9870, checked in by tbretz, 14 years ago
Added new class MJTrainCuts to help analysing the performance of the random forest and finding classical cuts.
File size: 4.1 KB
Line 
1#ifndef MARS_MJTrainCuts
2#define MARS_MJTrainCuts
3
4#ifndef MARS_MJTrainRanForest
5#include "MJTrainRanForest.h"
6#endif
7
8#ifndef ROOT_TObjArray
9#include <TObjArray.h>
10#endif
11
12#ifndef MARS_MDataSet
13#include "MDataSet.h"
14#endif
15
16class MH3;
17class MBinning;
18
19class MJTrainCuts : public MJTrainRanForest
20{
21public:
22 enum Type_t { kTrainOn, kTrainOff, kTestOn, kTestOff };
23
24private:
25 MDataSet fDataSetOn;
26 MDataSet fDataSetOff;
27
28 UInt_t fNum[4];
29/*
30 TList fPreTasksSet[4];
31 TList fPostTasksSet[4];
32
33 Bool_t fEnableWeights[4];
34*/
35 Bool_t fUseRegression;
36
37 TList fHists;
38 TObjArray fBinnings;
39
40 // Result
41 void DisplayResult(MH3 &h31, MH3 &h32, Float_t ontime);
42
43public:
44 MJTrainCuts() : fUseRegression(kFALSE)
45 {
46 for (int i=0; i<4; i++)
47 {
48 //fEnableWeights[i]=kFALSE;
49 fNum[i] = (UInt_t)-1;
50 }
51
52 fHists.SetOwner();
53 fBinnings.SetOwner();
54 }
55
56 void SetDataSetOn(const MDataSet &ds, UInt_t ntrain=(UInt_t)-1, UInt_t ntest=(UInt_t)-1)
57 {
58 ds.Copy(fDataSetOn);
59
60 fDataSetOn.SetNumAnalysis(1);
61
62 fNum[kTestOn] = ntrain;
63 fNum[kTrainOn] = ntest;
64 }
65 void SetDataSetOff(const MDataSet &ds, UInt_t ntrain=(UInt_t)-1, UInt_t ntest=(UInt_t)-1)
66 {
67 ds.Copy(fDataSetOff);
68
69 fDataSetOff.SetNumAnalysis(2);
70
71 fNum[kTestOff] = ntrain;
72 fNum[kTrainOff] = ntest;
73 }
74
75 // Add Histogram
76 void AddHist(UInt_t nx);
77 void AddHist(UInt_t nx, UInt_t ny);
78 void AddHist(UInt_t nx, UInt_t ny, UInt_t nz);
79
80 void AddBinning(UInt_t n, const MBinning &bins);
81 //void AddBinning(const MBinning &bins);
82
83 // Standard user interface
84/*
85 void SetWeights(Type_t typ, const char *rule) { if (fEnableWeights[typ]) return; fEnableWeights[typ]=kTRUE; AddPostTask(typ, rule); }
86 void SetWeights(Type_t typ, MTask *t) { if (fEnableWeights[typ]) return; fEnableWeights[typ]=kTRUE; AddPostTask(typ, t); }
87
88 void AddPreTaskOn(MTask *t) { AddPreTask(kTrainOn, t); AddPreTask(kTestOn, t); }
89 void AddPreTaskOn(const char *rule, const char *name="MWeight") { AddPreTask(kTrainOn, rule, name); AddPreTask(kTestOn, rule, name); }
90 void AddPreTaskOff(MTask *t) { AddPreTask(kTrainOff, t); AddPreTask(kTestOff, t); }
91 void AddPreTaskOff(const char *rule, const char *name="MWeight") { AddPreTask(kTrainOff, rule, name); AddPreTask(kTestOff, rule, name); }
92
93 void AddPostTaskOn(MTask *t) { AddPostTask(kTrainOn, t); AddPostTask(kTestOn, t); }
94 void AddPostTaskOn(const char *rule, const char *name="MWeight") { AddPostTask(kTrainOn, rule, name); AddPostTask(kTestOn, rule, name); }
95 void AddPostTaskOff(MTask *t) { AddPostTask(kTrainOff, t); AddPostTask(kTestOff, t); }
96 void AddPostTaskOff(const char *rule, const char *name="MWeight") { AddPostTask(kTrainOff, rule, name); AddPostTask(kTestOff, rule, name); }
97
98 void AddPreTask(Type_t typ, MTask *t) { Add(fPreTasksSet[typ], t); }
99 void AddPreTask(Type_t typ, const char *rule, const char *name="MWeight") { AddPar(fPreTasksSet[typ], rule, name); }
100 void AddPostTask(Type_t typ, MTask *t) { Add(fPostTasksSet[typ], t); }
101 void AddPostTask(Type_t typ, const char *rule, const char *name="MWeight") { AddPar(fPostTasksSet[typ], rule, name); }
102
103 void SetWeightsOn(const char *rule) { SetWeights(kTrainOn, rule); SetWeights(kTestOn, rule); }
104 void SetWeightsOn(MTask *t) { SetWeights(kTrainOn, t); SetWeights(kTestOn, t); }
105 void SetWeightsOff(const char *rule) { SetWeights(kTrainOff, rule); SetWeights(kTestOff, rule); }
106 void SetWeightsOff(MTask *t) { SetWeights(kTrainOff, t); SetWeights(kTestOff, t); }
107*/
108 void EnableRegression(Bool_t b=kTRUE) { fUseRegression = b; }
109 void EnableClassification(Bool_t b=kTRUE) { fUseRegression = !b; }
110
111 // Main function to start processing
112 Bool_t Process(const char *out);
113
114 ClassDef(MJTrainCuts, 0)//Class to help finding cuts using the random forest
115};
116
117#endif
Note: See TracBrowser for help on using the repository browser.