#ifndef MARS_MJTrainRanForest #define MARS_MJTrainRanForest #ifndef MARS_MJob #include "MJob.h" #endif class MTask; class MFilter; class MJTrainRanForest : public MJob { protected: Bool_t fDebug; Bool_t fEnableWeights; TList fRules; TList fPreCuts; TList fTrainCuts; TList fTestCuts; TList fPreTasks; TList fPostTasks; UShort_t fNumTrees; UShort_t fNdSize; UShort_t fNumTry; Bool_t WriteDisplay(const char *fname) const; void AddCut(TList &l, const char *rule); void AddPar(TList &l, const char *rule, const char *name); void Add(TList &l, MTask *f); public: MJTrainRanForest() : fDebug(kFALSE), fEnableWeights(kFALSE) { fNumTrees = 100; //100 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated fNdSize = 1; //1 } void AddPreTask(MTask *t) { Add(fPreTasks, t); } void AddPreTask(const char *rule, const char *name="MWeight") { AddPar(fPreTasks, rule, name); } void AddPostTask(MTask *t) { Add(fPostTasks, t); } void AddPostTask(const char *rule, const char *name="MWeight") { AddPar(fPostTasks, rule, name); } void SetDebug(Bool_t b=kTRUE) { fDebug = b; } void SetWeights(const char *rule) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(rule); } void SetWeights(MTask *t) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(t); } void AddPreCut(const char *rule) { AddCut(fPreCuts, rule); } void AddPreCut(MFilter *f) { Add(fPreCuts, (MTask*)(f)); } void AddTrainCut(const char *rule) { AddCut(fTrainCuts, rule); } void AddTrainCut(MFilter *f) { Add(fTrainCuts, (MTask*)(f)); } void AddTestCut(const char *rule) { AddCut(fTestCuts, rule); } void AddTestCut(MFilter *f) { Add(fTestCuts, (MTask*)(f)); } void SetNumTrees(UShort_t n=100) { fNumTrees = n; } void SetNdSize(UShort_t n=5) { fNdSize = n; } void SetNumTry(UShort_t n=0) { fNumTry = n; } Int_t AddParameter(const char *rule); ClassDef(MJTrainRanForest, 0)//Base class for Random Forest training classes }; #endif