| 1 | #ifndef MARS_MJTrainRanForest
|
|---|
| 2 | #define MARS_MJTrainRanForest
|
|---|
| 3 |
|
|---|
| 4 | #ifndef MARS_MJob
|
|---|
| 5 | #include "MJob.h"
|
|---|
| 6 | #endif
|
|---|
| 7 |
|
|---|
| 8 | class MTask;
|
|---|
| 9 | class MFilter;
|
|---|
| 10 |
|
|---|
| 11 | class MJTrainRanForest : public MJob
|
|---|
| 12 | {
|
|---|
| 13 | protected:
|
|---|
| 14 | Bool_t fDebug;
|
|---|
| 15 | Bool_t fEnableWeights;
|
|---|
| 16 |
|
|---|
| 17 | TList fRules;
|
|---|
| 18 |
|
|---|
| 19 | TList fPreCuts;
|
|---|
| 20 | TList fTrainCuts;
|
|---|
| 21 | TList fTestCuts;
|
|---|
| 22 | TList fPreTasks;
|
|---|
| 23 | TList fPostTasks;
|
|---|
| 24 |
|
|---|
| 25 | UShort_t fNumTrees;
|
|---|
| 26 | UShort_t fNdSize;
|
|---|
| 27 | UShort_t fNumTry;
|
|---|
| 28 |
|
|---|
| 29 | Bool_t WriteDisplay(const char *fname) const;
|
|---|
| 30 |
|
|---|
| 31 | void AddCut(TList &l, const char *rule);
|
|---|
| 32 | void AddPar(TList &l, const char *rule, const char *name);
|
|---|
| 33 | void Add(TList &l, MTask *f);
|
|---|
| 34 |
|
|---|
| 35 | public:
|
|---|
| 36 | MJTrainRanForest() : fDebug(kFALSE), fEnableWeights(kFALSE)
|
|---|
| 37 | {
|
|---|
| 38 | fNumTrees = 100; //100
|
|---|
| 39 | fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
|
|---|
| 40 | fNdSize = 1; //1
|
|---|
| 41 | }
|
|---|
| 42 |
|
|---|
| 43 | void AddPreTask(MTask *t) { Add(fPreTasks, t); }
|
|---|
| 44 | void AddPreTask(const char *rule,
|
|---|
| 45 | const char *name="MWeight") { AddPar(fPreTasks, rule, name); }
|
|---|
| 46 |
|
|---|
| 47 | void AddPostTask(MTask *t) { Add(fPostTasks, t); }
|
|---|
| 48 | void AddPostTask(const char *rule,
|
|---|
| 49 | const char *name="MWeight") { AddPar(fPostTasks, rule, name); }
|
|---|
| 50 |
|
|---|
| 51 | void SetDebug(Bool_t b=kTRUE) { fDebug = b; }
|
|---|
| 52 |
|
|---|
| 53 | void SetWeights(const char *rule) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(rule); }
|
|---|
| 54 | void SetWeights(MTask *t) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(t); }
|
|---|
| 55 |
|
|---|
| 56 | void AddPreCut(const char *rule) { AddCut(fPreCuts, rule); }
|
|---|
| 57 | void AddPreCut(MFilter *f) { Add(fPreCuts, (MTask*)(f)); }
|
|---|
| 58 |
|
|---|
| 59 | void AddTrainCut(const char *rule) { AddCut(fTrainCuts, rule); }
|
|---|
| 60 | void AddTrainCut(MFilter *f) { Add(fTrainCuts, (MTask*)(f)); }
|
|---|
| 61 |
|
|---|
| 62 | void AddTestCut(const char *rule) { AddCut(fTestCuts, rule); }
|
|---|
| 63 | void AddTestCut(MFilter *f) { Add(fTestCuts, (MTask*)(f)); }
|
|---|
| 64 |
|
|---|
| 65 | void SetNumTrees(UShort_t n=100) { fNumTrees = n; }
|
|---|
| 66 | void SetNdSize(UShort_t n=5) { fNdSize = n; }
|
|---|
| 67 | void SetNumTry(UShort_t n=0) { fNumTry = n; }
|
|---|
| 68 |
|
|---|
| 69 | Int_t AddParameter(const char *rule);
|
|---|
| 70 |
|
|---|
| 71 | ClassDef(MJTrainRanForest, 0)//Base class for Random Forest training classes
|
|---|
| 72 | };
|
|---|
| 73 |
|
|---|
| 74 | #endif
|
|---|