source: trunk/MagicSoft/Mars/mjtrain/MJTrainRanForest.h@ 8051

Last change on this file since 8051 was 7697, checked in by tbretz, 19 years ago
*** empty log message ***
File size: 2.2 KB
Line 
1#ifndef MARS_MJTrainRanForest
2#define MARS_MJTrainRanForest
3
4#ifndef MARS_MJob
5#include "MJob.h"
6#endif
7
8class MTask;
9class MFilter;
10
11class MJTrainRanForest : public MJob
12{
13protected:
14 Bool_t fDebug;
15 Bool_t fEnableWeights;
16
17 TList fRules;
18
19 TList fPreCuts;
20 TList fTrainCuts;
21 TList fTestCuts;
22 TList fPreTasks;
23 TList fPostTasks;
24
25 UShort_t fNumTrees;
26 UShort_t fNdSize;
27 UShort_t fNumTry;
28
29 Bool_t WriteDisplay(const char *fname) const;
30
31 void AddCut(TList &l, const char *rule);
32 void AddPar(TList &l, const char *rule, const char *name);
33 void Add(TList &l, MTask *f);
34
35public:
36 MJTrainRanForest() : fDebug(kFALSE), fEnableWeights(kFALSE)
37 {
38 fNumTrees = 100; //100
39 fNumTry = 0; //3 0 means: in MRanForest estimated best value will be calculated
40 fNdSize = 1; //1
41 }
42
43 void AddPreTask(MTask *t) { Add(fPreTasks, t); }
44 void AddPreTask(const char *rule,
45 const char *name="MWeight") { AddPar(fPreTasks, rule, name); }
46
47 void AddPostTask(MTask *t) { Add(fPostTasks, t); }
48 void AddPostTask(const char *rule,
49 const char *name="MWeight") { AddPar(fPostTasks, rule, name); }
50
51 void SetDebug(Bool_t b=kTRUE) { fDebug = b; }
52
53 void SetWeights(const char *rule) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(rule); }
54 void SetWeights(MTask *t) { if (fEnableWeights) return; fEnableWeights=kTRUE; AddPostTask(t); }
55
56 void AddPreCut(const char *rule) { AddCut(fPreCuts, rule); }
57 void AddPreCut(MFilter *f) { Add(fPreCuts, (MTask*)(f)); }
58
59 void AddTrainCut(const char *rule) { AddCut(fTrainCuts, rule); }
60 void AddTrainCut(MFilter *f) { Add(fTrainCuts, (MTask*)(f)); }
61
62 void AddTestCut(const char *rule) { AddCut(fTestCuts, rule); }
63 void AddTestCut(MFilter *f) { Add(fTestCuts, (MTask*)(f)); }
64
65 void SetNumTrees(UShort_t n=100) { fNumTrees = n; }
66 void SetNdSize(UShort_t n=5) { fNdSize = n; }
67 void SetNumTry(UShort_t n=0) { fNumTry = n; }
68
69 Int_t AddParameter(const char *rule);
70
71 ClassDef(MJTrainRanForest, 0)//Base class for Random Forest training classes
72};
73
74#endif
Note: See TracBrowser for help on using the repository browser.