source: branches/AddingGoogleTestEnvironment/mjtrain/MJTrainSeparation.h@ 18179

Last change on this file since 18179 was 7724, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 4.4 KB
Line 
1#ifndef MARS_MJTrainSeparation
2#define MARS_MJTrainSeparation
3
4#ifndef MARS_MJTrainRanForest
5#include "MJTrainRanForest.h"
6#endif
7
8#ifndef MARS_MDataSet
9#include "MDataSet.h"
10#endif
11
12class MH3;
13
14class MJTrainSeparation : public MJTrainRanForest
15{
16public:
17 enum Type_t { kTrainOn, kTrainOff, kTestOn, kTestOff };
18
19private:
20 MDataSet fDataSetTest;
21 MDataSet fDataSetTrain;
22
23 UInt_t fNum[4];
24
25 TList fPreTasksSet[4];
26 TList fPostTasksSet[4];
27
28 Bool_t fAutoTrain;
29 Bool_t fUseRegression;
30
31 Bool_t fEnableWeights[4];
32
33 Float_t fFluxTrain;
34 Float_t fFluxTest;
35
36 // Result
37 void DisplayResult(MH3 &h31, MH3 &h32, Float_t ontime);
38
39 // Auto training
40 Bool_t GetEventsProduced(MDataSet &set, Double_t &num, Double_t &min, Double_t &max) const;
41 Double_t GetDataRate(MDataSet &set, Double_t &num) const;
42 Double_t GetNumMC(MDataSet &set) const;
43 Float_t AutoTrain(MDataSet &set, Type_t typon, Type_t typoff, Float_t flux);
44
45public:
46 MJTrainSeparation() :
47 fAutoTrain(kFALSE), fUseRegression(kFALSE),
48 fFluxTrain(2e-7), fFluxTest(2e-7)
49 { for (int i=0; i<4; i++) { fEnableWeights[i]=kFALSE; fNum[i] = (UInt_t)-1; } }
50
51 void SetDataSetTrain(const MDataSet &ds, UInt_t non=(UInt_t)-1, UInt_t noff=(UInt_t)-1)
52 {
53 ds.Copy(fDataSetTrain);
54
55 fDataSetTrain.SetNumAnalysis(1);
56
57 fNum[kTrainOn] = non;
58 fNum[kTrainOff] = noff;
59 }
60 void SetDataSetTest(const MDataSet &ds, UInt_t non=(UInt_t)-1, UInt_t noff=(UInt_t)-1)
61 {
62 ds.Copy(fDataSetTest);
63
64 fDataSetTest.SetNumAnalysis(1);
65
66 fNum[kTestOn] = non;
67 fNum[kTestOff] = noff;
68 }
69
70 // Deprecated, used for test purpose
71 void AddPreTask(Type_t typ, MTask *t) { Add(fPreTasksSet[typ], t); }
72 void AddPreTask(Type_t typ, const char *rule, const char *name="MWeight") { AddPar(fPreTasksSet[typ], rule, name); }
73
74 void AddPostTask(Type_t typ, MTask *t) { Add(fPostTasksSet[typ], t); }
75 void AddPostTask(Type_t typ, const char *rule, const char *name="MWeight") { AddPar(fPostTasksSet[typ], rule, name); }
76
77 void SetWeights(Type_t typ, const char *rule) { if (fEnableWeights[typ]) return; fEnableWeights[typ]=kTRUE; AddPostTask(typ, rule); }
78 void SetWeights(Type_t typ, MTask *t) { if (fEnableWeights[typ]) return; fEnableWeights[typ]=kTRUE; AddPostTask(typ, t); }
79
80 // Standard user interface
81 void AddPreTaskOn(MTask *t) { AddPreTask(kTrainOn, t); AddPreTask(kTestOn, t); }
82 void AddPreTaskOn(const char *rule, const char *name="MWeight") { AddPreTask(kTrainOn, rule, name); AddPreTask(kTestOn, rule, name); }
83 void AddPreTaskOff(MTask *t) { AddPreTask(kTrainOff, t); AddPreTask(kTestOff, t); }
84 void AddPreTaskOff(const char *rule, const char *name="MWeight") { AddPreTask(kTrainOff, rule, name); AddPreTask(kTestOff, rule, name); }
85
86 void AddPostTaskOn(MTask *t) { AddPostTask(kTrainOn, t); AddPostTask(kTestOn, t); }
87 void AddPostTaskOn(const char *rule, const char *name="MWeight") { AddPostTask(kTrainOn, rule, name); AddPostTask(kTestOn, rule, name); }
88 void AddPostTaskOff(MTask *t) { AddPostTask(kTrainOff, t); AddPostTask(kTestOff, t); }
89 void AddPostTaskOff(const char *rule, const char *name="MWeight") { AddPostTask(kTrainOff, rule, name); AddPostTask(kTestOff, rule, name); }
90
91 void SetWeightsOn(const char *rule) { SetWeights(kTrainOn, rule); SetWeights(kTestOn, rule); }
92 void SetWeightsOn(MTask *t) { SetWeights(kTrainOn, t); SetWeights(kTestOn, t); }
93 void SetWeightsOff(const char *rule) { SetWeights(kTrainOff, rule); SetWeights(kTestOff, rule); }
94 void SetWeightsOff(MTask *t) { SetWeights(kTrainOff, t); SetWeights(kTestOff, t); }
95
96 void SetFluxTrain(Float_t f) { fFluxTrain = f; }
97 void SetFluxTest(Float_t f) { fFluxTest = f; }
98 void SetFlux(Float_t f) { SetFluxTrain(f); SetFluxTest(f); }
99
100 void EnableAutoTrain(Bool_t b=kTRUE) { fAutoTrain = b; }
101 void EnableRegression(Bool_t b=kTRUE) { fUseRegression = b; }
102 void EnableClassification(Bool_t b=kTRUE) { fUseRegression = !b; }
103
104 Bool_t Train(const char *out);
105
106 ClassDef(MJTrainSeparation, 0)//Class to train Random Forest gamma-/background-separation
107};
108
109#endif
Note: See TracBrowser for help on using the repository browser.