Changeset 7396 for trunk/MagicSoft/Mars/mranforest
- Timestamp:
- 11/14/05 09:45:33 (19 years ago)
- Location:
- trunk/MagicSoft/Mars/mranforest
- Files:
-
- 10 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/MagicSoft/Mars/mranforest/MRFEnergyEst.cc
r7178 r7396 17 17 ! 18 18 ! Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de> 19 ! Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de> 19 20 ! 20 21 ! Copyright: MAGIC Software Development, 2000-2005 … … 31 32 #include "MRFEnergyEst.h" 32 33 33 #include <TFile.h>34 #include <TList.h>35 36 #include <TH1.h>37 #include <TH2.h>38 #include <TStyle.h>39 #include <TCanvas.h>40 #include <TMath.h>41 34 #include <TVector.h> 42 35 … … 45 38 #include "MLog.h" 46 39 #include "MLogManip.h" 40 41 #include "MData.h" 42 #include "MDataArray.h" 43 44 #include "MRanForest.h" 45 #include "MParameters.h" 47 46 48 47 #include "MParList.h" 49 48 #include "MTaskList.h" 50 49 #include "MEvtLoop.h" 51 52 #include "MRanTree.h"53 #include "MRanForest.h"54 50 #include "MRanForestGrow.h" 55 51 56 #include "MData.h"57 #include "MParameters.h"58 59 52 ClassImp(MRFEnergyEst); 60 53 61 54 using namespace std; 62 55 63 static const TString gsDefName = "MRFEnergyEst"; 64 static const TString gsDefTitle = "RF for energy estimation"; 65 66 // -------------------------------------------------------------------------- 67 // 68 // Default constructor. Set the name and title of this object 69 // 56 const TString MRFEnergyEst::gsDefName = "MRFEnergyEst"; 57 const TString MRFEnergyEst::gsDefTitle = "RF for energy estimation"; 58 70 59 MRFEnergyEst::MRFEnergyEst(const char *name, const char *title) 71 60 : fDebug(kFALSE), fData(0), fEnergyEst(0), … … 75 64 fName = name ? name : gsDefName.Data(); 76 65 fTitle = title ? title : gsDefTitle.Data(); 66 67 // FIXME: 68 fNumTrees = 100; //100 69 fNumTry = 5; //3 70 fNdSize = 0; //1 0 means: in MRanForest estimated best value will be calculated 77 71 } 78 72 … … 82 76 } 83 77 84 // -------------------------------------------------------------------------- 85 // 86 // Train the RF with the goven MHMatrix. The last column must contain the 87 // True energy. 88 // 89 Int_t MRFEnergyEst::Train(const MHMatrix &matrixtrain, const TArrayD &grid) 78 Int_t MRFEnergyEst::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver) 90 79 { 91 80 gLog.Separator("MRFEnergyEst - Train"); … … 105 94 } 106 95 107 const Int_t nbins = grid.GetSize()-1;108 if (nbins<=0)109 {110 *fLog << err << "ERROR - Energy grid not vaild... abort." << endl;111 return kFALSE;112 }113 114 96 // rules (= combination of image par) to be used for energy estimation 115 97 TFile fileRF(fFileName, "recreate"); … … 122 104 const Int_t nobs = 3; // Number of obsolete columns 123 105 106 MDataArray &dcol = *matrixtrain.GetColumns(); 107 124 108 MDataArray usedrules; 125 for(Int_t i=0; i<ncols-nobs; i++) // -3 is important!!! 126 usedrules.AddEntry((*matrixtrain.GetColumns())[i].GetRule()); 127 128 // training of RF for each energy bin 109 for (Int_t i=0; i<ncols-nobs; i++) // -3 is important!!! 110 usedrules.AddEntry(dcol[i].GetRule()); 111 112 MDataArray rules(usedrules); 113 rules.AddEntry(ver<2?"Classification":dcol[ncols-1].GetRule()); 114 115 // prepare matrix for current energy bin 116 TMatrix mat(matrixtrain.GetM()); 117 118 // last column must be removed (true energy col.) 119 mat.ResizeTo(nrows, ncols-nobs+1); 120 121 if (fDebug) 122 gLog.SetNullOutput(kTRUE); 123 124 const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1; 129 125 for (Int_t ie=0; ie<nbins; ie++) 130 126 { 131 TMatrix mat1(nrows, ncols); 132 TMatrix mat0(nrows, ncols); 133 134 // prepare matrix for current energy bin 135 Int_t irow1=0; 136 Int_t irow0=0; 137 138 const TMatrix &m = matrixtrain.GetM(); 139 for (Int_t j=0; j<nrows; j++) 127 switch (ver) 140 128 { 141 const Double_t energy = m(j,ncols-1); 142 143 if (energy>grid[ie] && energy<=grid[ie+1]) 144 TMatrixFRow(mat1, irow1++) = TMatrixFRow_const(m,j); 145 else 146 TMatrixFRow(mat0, irow0++) = TMatrixFRow_const(m,j); 129 case 0: // Replace Energy Grid by classification 130 { 131 Int_t irows=0; 132 for (Int_t j=0; j<nrows; j++) 133 { 134 const Double_t energy = matrixtrain.GetM()(j,ncols-1); 135 const Bool_t inside = energy>grid[ie] && energy<=grid[ie+1]; 136 137 mat(j, ncols-nobs) = inside ? 1 : 0; 138 139 if (inside) 140 irows++; 141 } 142 if (irows==0) 143 *fLog << warn << "WARNING - Skipping"; 144 else 145 *fLog << inf << "Training RF for"; 146 147 *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl; 148 149 if (irows==0) 150 continue; 151 } 152 break; 153 154 case 1: // Use Energy as classifier 155 case 2: 156 for (Int_t j=0; j<nrows; j++) 157 mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1); 158 break; 147 159 } 148 160 149 const Bool_t invalid = irow1==0 || irow0==0; 150 151 if (invalid) 152 *fLog << warn << "WARNING - Skipping"; 153 else 154 *fLog << inf << "Training RF for"; 155 156 *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irow0 << " " << irow1 << endl; 157 158 if (invalid) 159 continue; 160 161 if (fDebug) 162 gLog.SetNullOutput(kTRUE); 163 164 // last column must be removed (true energy col.) 165 mat1.ResizeTo(irow1, ncols-nobs); 166 mat0.ResizeTo(irow0, ncols-nobs); 167 168 // create MHMatrix as input for RF 169 MHMatrix matrix1(mat1, "MatrixHadrons"); 170 MHMatrix matrix0(mat0, "MatrixGammas"); 171 172 // training of RF 161 MHMatrix matrix(mat, &rules, "MatrixTrain"); 162 163 MParList plist; 173 164 MTaskList tlist; 174 MParList plist;175 165 plist.AddToList(&tlist); 176 plist.AddToList(&matrix0); 177 plist.AddToList(&matrix1); 166 plist.AddToList(&matrix); 167 168 MRanForest rf; 169 rf.SetNumTrees(fNumTrees); 170 rf.SetNumTry(fNumTry); 171 rf.SetNdSize(fNdSize); 172 rf.SetClassify(ver<2 ? 1 : 0); 173 if (ver==1) 174 rf.SetGrid(grid); 175 176 plist.AddToList(&rf); 178 177 179 178 MRanForestGrow rfgrow; 180 rfgrow.SetNumTrees(fNumTrees); // number of trees181 rfgrow.SetNumTry(fNumTry); // number of trials in random split selection182 rfgrow.SetNdSize(fNdSize); // limit for nodesize183 184 179 tlist.AddToList(&rfgrow); 185 180 186 181 MEvtLoop evtloop; 187 evtloop.SetDisplay(fDisplay);188 182 evtloop.SetParList(&plist); 189 183 … … 194 188 gLog.SetNullOutput(kFALSE); 195 189 196 // Calculate bin center 197 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2; 198 199 // save whole forest 200 MRanForest *forest=(MRanForest*)plist.FindObject("MRanForest"); 201 forest->SetUserVal(E); 202 forest->Write(Form("%.10f", E)); 190 if (ver==0) 191 { 192 // Calculate bin center 193 const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2; 194 195 // save whole forest 196 rf.SetUserVal(E); 197 rf.SetName(Form("%.10f", E)); 198 } 199 200 rf.Write(); 203 201 } 204 202 … … 224 222 while ((o=Next())) 225 223 { 226 MRanForest *forest ;224 MRanForest *forest=0; 227 225 fileRF.GetObject(o->GetName(), forest); 228 226 if (!forest) … … 234 232 } 235 233 234 // Maybe fEForests[0].fRules yould be used instead? 235 236 236 if (fData->Read("rules")<=0) 237 237 { … … 249 249 return kFALSE; 250 250 251 cout << "MDataArray" << endl; 252 251 253 fData = (MDataArray*)plist->FindCreateObj("MDataArray"); 252 254 if (!fData) 253 255 return kFALSE; 254 256 257 cout << "ReadForests" << endl; 258 255 259 if (!ReadForests(*plist)) 256 260 { … … 275 279 } 276 280 277 // --------------------------------------------------------------------------278 //279 //280 281 #include <TGraph.h> 281 282 #include <TF1.h> 282 283 Int_t MRFEnergyEst::Process() 283 284 { 284 static TF1 f1("f1", "gaus");285 286 285 TVector event; 287 286 if (fTestMatrix) … … 290 289 *fData >> event; 291 290 291 // --------------- Single Tree RF ------------------- 292 if (fEForests.GetEntries()==1) 293 { 294 const MRanForest *rf = (MRanForest*)fEForests[0]; 295 fEnergyEst->SetVal(rf->CalcHadroness(event)); 296 fEnergyEst->SetReadyToSave(); 297 298 return kTRUE; 299 } 300 301 // --------------- Multi Tree RF ------------------- 302 static TF1 f1("f1", "gaus"); 303 292 304 Double_t sume = 0; 293 305 Double_t sumh = 0; … … 308 320 const Double_t h = rf->CalcHadroness(event); 309 321 const Double_t e = rf->GetUserVal(); 322 310 323 g.SetPoint(g.GetN(), e, h); 324 311 325 sume += e*h; 312 326 sumh += h; 327 313 328 if (h>maxh) 314 329 { … … 337 352 fEnergyEst->SetVal(pow(10, f1.GetParameter(1))); 338 353 break; 339 340 } 354 } 355 341 356 fEnergyEst->SetReadyToSave(); 342 357 -
trunk/MagicSoft/Mars/mranforest/MRFEnergyEst.h
r7178 r7396 9 9 #include <TObjArray.h> 10 10 #endif 11 11 12 #ifndef ROOT_TArrayD 12 13 #include <TArrayD.h> 13 14 #endif 14 15 15 class MHMatrix;16 16 class MDataArray; 17 17 class MParameterD; 18 class MHMatrix; 18 19 19 20 class MRFEnergyEst : public MTask … … 26 27 kFit 27 28 }; 29 28 30 private: 31 static const TString gsDefName; //! Default Name 32 static const TString gsDefTitle; //! Default Title 33 29 34 Bool_t fDebug; // Debugging of eventloop while training on/off 30 35 … … 39 44 Int_t fNdSize; //! Training parameters 40 45 41 MHMatrix *fTestMatrix; //! Test Matrix 46 MHMatrix *fTestMatrix; //! Test Matrix used in Process (together with MMatrixLoop) 42 47 43 48 EstimationMode_t fEstimationMode; 49 50 private: 51 // MTask 52 Int_t PreProcess(MParList *plist); 53 Int_t Process(); 44 54 45 55 // MRFEnergyEst 46 56 Int_t ReadForests(MParList &plist); 47 57 48 // MTask49 Int_t PreProcess(MParList *plist);50 Int_t Process();51 52 58 // MParContainer 53 59 Int_t ReadEnv(const TEnv &env, TString prefix, Bool_t print); 60 61 // Train Interface 62 Int_t Train(const MHMatrix &n, const TArrayD &grid, Int_t ver=2); 54 63 55 64 public: … … 58 67 59 68 // Setter for estimation 60 void SetFileName(TString str) { fFileName = str; }61 void 69 void SetFileName(TString filename) { fFileName = filename; } 70 void SetEstimationMode(EstimationMode_t op) { fEstimationMode = op; } 62 71 63 72 // Setter for training 64 void SetNumTrees(Int_t n=-1){ fNumTrees = n; }65 void SetNdSize(Int_t n=-1){ fNdSize = n; }66 void SetNumTry(Int_t n=-1){ fNumTry = n; }67 void SetDebug(Bool_t b=kTRUE){ fDebug = b; }73 void SetNumTrees(UShort_t n=100) { fNumTrees = n; } 74 void SetNdSize(UShort_t n=5) { fNdSize = n; } 75 void SetNumTry(UShort_t n=0) { fNumTry = n; } 76 void SetDebug(Bool_t b=kTRUE) { fDebug = b; } 68 77 69 78 // Train Interface 70 Int_t Train(const MHMatrix &n, const TArrayD &grid); 79 Int_t TrainMultiRF(const MHMatrix &n, const TArrayD &grid) 80 { 81 return Train(n, grid, 0); 82 } 83 Int_t TrainSingleRF(const MHMatrix &n, const TArrayD &grid=TArrayD()) 84 { 85 return Train(n, grid, grid.GetSize()==0 ? 2 : 1); 86 } 71 87 72 88 // Test Interface -
trunk/MagicSoft/Mars/mranforest/MRanForest.cc
r7170 r7396 16 16 ! 17 17 ! 18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@ alwa02.physik.uni-siegen.de>18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de> 19 19 ! 20 ! Copyright: MAGIC Software Development, 2000-200 320 ! Copyright: MAGIC Software Development, 2000-2005 21 21 ! 22 22 ! … … 39 39 // split selection (which is subject to MRanTree::GrowTree()) 40 40 // 41 // Version 2:42 // + fUserVal43 //44 41 ///////////////////////////////////////////////////////////////////////////// 45 42 #include "MRanForest.h" 46 43 47 #include <T Matrix.h>48 #include <TRandom 3.h>44 #include <TVector.h> 45 #include <TRandom.h> 49 46 50 47 #include "MHMatrix.h" 51 48 #include "MRanTree.h" 49 #include "MData.h" 50 #include "MDataArray.h" 51 #include "MParList.h" 52 52 53 53 #include "MLog.h" … … 62 62 // Default constructor. 63 63 // 64 MRanForest::MRanForest(const char *name, const char *title) : f NumTrees(100), fRanTree(NULL),fUsePriors(kFALSE), fUserVal(-1)64 MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1) 65 65 { 66 66 fName = name ? name : "MRanForest"; … … 71 71 } 72 72 73 MRanForest::MRanForest(const MRanForest &rf) 74 { 75 // Copy constructor 76 fName = rf.fName; 77 fTitle = rf.fTitle; 78 79 fClassify = rf.fClassify; 80 fNumTrees = rf.fNumTrees; 81 fNumTry = rf.fNumTry; 82 fNdSize = rf.fNdSize; 83 fTreeNo = rf.fTreeNo; 84 fRanTree = NULL; 85 86 fRules=new MDataArray(); 87 fRules->Reset(); 88 89 MDataArray *newrules=rf.fRules; 90 91 for(Int_t i=0;i<newrules->GetNumEntries();i++) 92 { 93 MData &data=(*newrules)[i]; 94 fRules->AddEntry(data.GetRule()); 95 } 96 97 // trees 98 fForest=new TObjArray(); 99 fForest->SetOwner(kTRUE); 100 101 TObjArray *newforest=rf.fForest; 102 for(Int_t i=0;i<newforest->GetEntries();i++) 103 { 104 MRanTree *rantree=(MRanTree*)newforest->At(i); 105 106 MRanTree *newtree=new MRanTree(*rantree); 107 fForest->Add(newtree); 108 } 109 110 fHadTrue = rf.fHadTrue; 111 fHadEst = rf.fHadEst; 112 fDataSort = rf.fDataSort; 113 fDataRang = rf.fDataRang; 114 fClassPop = rf.fClassPop; 115 fWeight = rf.fWeight; 116 fTreeHad = rf.fTreeHad; 117 118 fNTimesOutBag = rf.fNTimesOutBag; 119 120 } 121 73 122 // -------------------------------------------------------------------------- 74 //75 123 // Destructor. 76 //77 124 MRanForest::~MRanForest() 78 125 { 79 126 delete fForest; 127 } 128 129 MRanTree *MRanForest::GetTree(Int_t i) 130 { 131 return (MRanTree*)(fForest->At(i)); 80 132 } 81 133 … … 88 140 } 89 141 90 void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam) 91 { 92 const Float_t sum=prior_gam+prior_had; 93 94 prior_gam/=sum; 95 prior_had/=sum; 96 97 fPrior[0]=prior_had; 98 fPrior[1]=prior_gam; 99 100 fUsePriors=kTRUE; 142 void MRanForest::SetNumTry(Int_t n) 143 { 144 fNumTry=TMath::Max(n,0); 145 } 146 147 void MRanForest::SetNdSize(Int_t n) 148 { 149 fNdSize=TMath::Max(n,1); 150 } 151 152 void MRanForest::SetWeights(const TArrayF &weights) 153 { 154 const int n=weights.GetSize(); 155 fWeight.Set(n); 156 fWeight=weights; 157 158 return; 159 } 160 161 void MRanForest::SetGrid(const TArrayD &grid) 162 { 163 const int n=grid.GetSize(); 164 165 for(int i=0;i<n-1;i++) 166 if(grid[i]>=grid[i+1]) 167 { 168 *fLog<<inf<<"Grid points must be in increasing order! Ignoring grid."<<endl; 169 return; 170 } 171 172 fGrid=grid; 173 174 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl; 175 //for(int i=0;i<n;i++) 176 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl; 177 178 return; 101 179 } 102 180 103 181 Int_t MRanForest::GetNumDim() const 104 182 { 105 return fGammas ? fGammas->GetM().GetNcols() : 0; 106 } 107 183 return fMatrix ? fMatrix->GetNcols() : 0; 184 } 185 186 Int_t MRanForest::GetNumData() const 187 { 188 return fMatrix ? fMatrix->GetNrows() : 0; 189 } 190 191 Int_t MRanForest::GetNclass() const 192 { 193 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray()); 194 195 return int(fClass[maxidx])+1; 196 } 197 198 void MRanForest::PrepareClasses() 199 { 200 const int numdata=fHadTrue.GetSize(); 201 202 if(fGrid.GetSize()>0) 203 { 204 // classes given by grid 205 const int ngrid=fGrid.GetSize(); 206 207 for(int j=0;j<numdata;j++) 208 { 209 // Array is supposed to be sorted prior to this call. 210 // If match is found, function returns position of element. 211 // If no match found, function gives nearest element smaller 212 // than value. 213 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]); 214 215 fClass[j] = k; 216 } 217 218 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray()); 219 int min = fClass[minidx]; 220 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min; 221 222 }else{ 223 // classes directly given 224 for (Int_t j=0;j<numdata;j++) 225 fClass[j] = int(fHadTrue[j]+0.5); 226 } 227 228 return; 229 } 230 231 /* 232 Bool_t MRanForest::PreProcess(MParList *plist) 233 { 234 if (!fRules) 235 { 236 *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl; 237 return kFALSE; 238 } 239 240 if (!fRules->PreProcess(plist)) 241 { 242 *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl; 243 return kFALSE; 244 } 245 246 return kTRUE; 247 } 248 */ 249 250 Double_t MRanForest::CalcHadroness() 251 { 252 TVector event; 253 *fRules >> event; 254 255 return CalcHadroness(event); 256 } 108 257 109 258 Double_t MRanForest::CalcHadroness(const TVector &event) … … 117 266 while ((tree=(MRanTree*)Next())) 118 267 { 119 fTreeHad[ntree]=tree->TreeHad(event); 120 hadroness+=fTreeHad[ntree]; 268 hadroness+=(fTreeHad[ntree]=tree->TreeHad(event)); 121 269 ntree++; 122 270 } … … 126 274 Bool_t MRanForest::AddTree(MRanTree *rantree=NULL) 127 275 { 128 if (rantree) 129 fRanTree=rantree; 130 if (!fRanTree) 276 fRanTree = rantree ? rantree:fRanTree; 277 278 if (!fRanTree) return kFALSE; 279 280 MRanTree *newtree=new MRanTree(*fRanTree); 281 fForest->Add(newtree); 282 283 return kTRUE; 284 } 285 286 Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist) 287 { 288 //------------------------------------------------------------------- 289 // access matrix, copy last column (target) preliminarily 290 // into fHadTrue 291 TMatrix mat_tmp = mat->GetM(); 292 int dim = mat_tmp.GetNcols(); 293 int numdata = mat_tmp.GetNrows(); 294 295 fMatrix=new TMatrix(mat_tmp); 296 297 fHadTrue.Set(numdata); 298 fHadTrue.Reset(0); 299 300 for (Int_t j=0;j<numdata;j++) 301 fHadTrue[j] = (*fMatrix)(j,dim-1); 302 303 // remove last col 304 fMatrix->ResizeTo(numdata,dim-1); 305 dim=fMatrix->GetNcols(); 306 307 //------------------------------------------------------------------- 308 // setup labels for classification/regression 309 fClass.Set(numdata); 310 fClass.Reset(0); 311 312 if(fClassify) PrepareClasses(); 313 314 //------------------------------------------------------------------- 315 // allocating and initializing arrays 316 fHadEst.Set(numdata); fHadEst.Reset(0); 317 fNTimesOutBag.Set(numdata); fNTimesOutBag.Reset(0); 318 fDataSort.Set(dim*numdata); fDataSort.Reset(0); 319 fDataRang.Set(dim*numdata); fDataRang.Reset(0); 320 321 if(fWeight.GetSize()!=numdata) 322 { 323 fWeight.Set(numdata); 324 fWeight.Reset(1.); 325 *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl; 326 } 327 328 //------------------------------------------------------------------- 329 // setup rules to be used for classification/regression 330 MDataArray *allrules=(MDataArray*)mat->GetColumns(); 331 if(allrules==NULL) 332 { 333 *fLog << err <<"Rules of matrix == null, exiting"<< endl; 131 334 return kFALSE; 132 133 fForest->Add((MRanTree*)fRanTree->Clone()); 134 135 return kTRUE; 136 } 137 138 Int_t MRanForest::GetNumData() const 139 { 140 return fHadrons && fGammas ? fHadrons->GetM().GetNrows()+fGammas->GetM().GetNrows() : 0; 141 } 142 143 Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam) 144 { 145 // pointer to training data 146 fHadrons=mhad; 147 fGammas=mgam; 148 149 // determine data entries and dimension of Hillas-parameter space 150 //fNumHad=fHadrons->GetM().GetNrows(); 151 //fNumGam=fGammas->GetM().GetNrows(); 152 153 const Int_t dim = GetNumDim(); 154 155 if (dim!=fGammas->GetM().GetNcols()) 156 return kFALSE; 157 158 const Int_t numdata = GetNumData(); 159 160 // allocating and initializing arrays 161 fHadTrue.Set(numdata); 162 fHadTrue.Reset(); 163 fHadEst.Set(numdata); 164 165 fPrior.Set(2); 166 fClassPop.Set(2); 167 fWeight.Set(numdata); 168 fNTimesOutBag.Set(numdata); 169 fNTimesOutBag.Reset(); 170 171 fDataSort.Set(dim*numdata); 172 fDataRang.Set(dim*numdata); 173 174 // calculating class populations (= no. of gammas and hadrons) 175 fClassPop.Reset(); 176 for(Int_t n=0;n<numdata;n++) 177 fClassPop[fHadTrue[n]]++; 178 179 // setting weights taking into account priors 180 if (!fUsePriors) 181 fWeight.Reset(1.); 182 else 183 { 184 for(Int_t j=0;j<2;j++) 185 fPrior[j] *= (fClassPop[j]>=1) ? (Float_t)numdata/fClassPop[j]:0; 186 187 for(Int_t n=0;n<numdata;n++) 188 fWeight[n]=fPrior[fHadTrue[n]]; 189 } 190 191 // create fDataSort 192 CreateDataSort(); 335 } 336 337 fRules=new MDataArray(); fRules->Reset(); 338 TString target_rule; 339 340 for(Int_t i=0;i<dim+1;i++) 341 { 342 MData &data=(*allrules)[i]; 343 if(i<dim) 344 fRules->AddEntry(data.GetRule()); 345 else 346 target_rule=data.GetRule(); 347 } 348 349 *fLog << inf <<endl; 350 *fLog << inf <<"Setting up RF for training on target:"<<endl<<" "<<target_rule.Data()<<endl; 351 *fLog << inf <<"Following rules are used as input to RF:"<<endl; 352 353 for(Int_t i=0;i<dim;i++) 354 { 355 MData &data=(*fRules)[i]; 356 *fLog<<inf<<" "<<i<<") "<<data.GetRule()<<endl<<flush; 357 } 358 359 *fLog << inf <<endl; 360 361 //------------------------------------------------------------------- 362 // prepare (sort) data for fast optimization algorithm 363 if(!CreateDataSort()) return kFALSE; 364 365 //------------------------------------------------------------------- 366 // access and init tree container 367 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree"); 193 368 194 369 if(!fRanTree) … … 197 372 return kFALSE; 198 373 } 199 fRanTree->SetRules(fGammas->GetColumns()); 374 375 fRanTree->SetClassify(fClassify); 376 fRanTree->SetNdSize(fNdSize); 377 378 if(fNumTry==0) 379 { 380 double ddim = double(dim); 381 382 fNumTry=int(sqrt(ddim)+0.5); 383 *fLog<<inf<<endl; 384 *fLog<<inf<<"Set no. of trials to the recommended value of round("<<sqrt(ddim)<<") = "; 385 *fLog<<inf<<fNumTry<<endl; 386 387 } 388 fRanTree->SetNumTry(fNumTry); 389 390 *fLog<<inf<<endl; 391 *fLog<<inf<<"Following settings for the tree growing are used:"<<endl; 392 *fLog<<inf<<" Number of Trees : "<<fNumTrees<<endl; 393 *fLog<<inf<<" Number of Trials: "<<fNumTry<<endl; 394 *fLog<<inf<<" Final Node size : "<<fNdSize<<endl; 395 200 396 fTreeNo=0; 201 397 202 398 return kTRUE; 203 }204 205 void MRanForest::InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag)206 {207 for (Int_t ievt=from;ievt<to;ievt++)208 {209 if (jinbag[ievt]>0)210 continue;211 fHadEst[ievt] += fRanTree->TreeHad(m, ievt-from);212 fNTimesOutBag[ievt]++;213 }214 399 } 215 400 … … 224 409 fTreeNo++; 225 410 411 //------------------------------------------------------------------- 226 412 // initialize running output 413 414 float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())]; 415 Bool_t calcResolution=(minfloat>0.001); 416 227 417 if (fTreeNo==1) 228 418 { 229 419 *fLog << inf << endl; 230 *fLog << underline; // << "1st col 2nd col" << endl; 231 *fLog << "no. of tree error in % (calulated using oob-data -> overestim. of error)" << endl; 420 *fLog << underline; 421 422 if(calcResolution) 423 *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl; 424 else 425 *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl; 426 // 12345678901234567890123456789012345678901234567890 232 427 } 233 428 234 429 const Int_t numdata = GetNumData(); 235 430 const Int_t nclass = GetNclass(); 431 432 //------------------------------------------------------------------- 236 433 // bootstrap aggregating (bagging) -> sampling with replacement: 237 // 238 // The integer k is randomly (uniformly) chosen from the set 239 // {0,1,...,fNumData-1}, which is the set of the index numbers of 240 // all events in the training sample 241 TArrayF classpopw(2); 434 435 TArrayF classpopw(nclass); 242 436 TArrayI jinbag(numdata); // Initialization includes filling with 0 243 437 TArrayF winbag(numdata); // Initialization includes filling with 0 244 438 439 float square=0; float mean=0; 440 245 441 for (Int_t n=0; n<numdata; n++) 246 442 { 443 // The integer k is randomly (uniformly) chosen from the set 444 // {0,1,...,numdata-1}, which is the set of the index numbers of 445 // all events in the training sample 446 247 447 const Int_t k = Int_t(gRandom->Rndm()*numdata); 248 448 249 classpopw[fHadTrue[k]]+=fWeight[k]; 449 if(fClassify) 450 classpopw[fClass[k]]+=fWeight[k]; 451 else 452 classpopw[0]+=fWeight[k]; 453 454 mean +=fHadTrue[k]*fWeight[k]; 455 square+=fHadTrue[k]*fHadTrue[k]*fWeight[k]; 456 250 457 winbag[k]+=fWeight[k]; 251 458 jinbag[k]=1; 252 } 253 459 460 } 461 462 //------------------------------------------------------------------- 254 463 // modifying sorted-data array for in-bag data: 255 // 464 256 465 // In bagging procedure ca. 2/3 of all elements in the original 257 466 // training sample are used to build the in-bag data … … 261 470 ModifyDataSort(datsortinbag, ninbag, jinbag); 262 471 263 const TMatrix &hadrons=fHadrons->GetM(); 264 const TMatrix &gammas =fGammas->GetM(); 265 266 // growing single tree 267 fRanTree->GrowTree(hadrons,gammas,fHadTrue,datsortinbag, 268 fDataRang,classpopw,jinbag,winbag); 269 472 fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square, 473 jinbag,winbag,nclass); 474 475 //------------------------------------------------------------------- 270 476 // error-estimates from out-of-bag data (oob data): 271 477 // … … 277 483 // determined from oob-data is underestimated, but can still be taken as upper limit. 278 484 279 const Int_t numhad = hadrons.GetNrows(); 280 InitHadEst(0, numhad, hadrons, jinbag); 281 InitHadEst(numhad, numdata, gammas, jinbag); 282 /* 283 for (Int_t ievt=0;ievt<numhad;ievt++) 284 { 285 if (jinbag[ievt]>0) 286 continue; 287 fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt); 485 for (Int_t ievt=0;ievt<numdata;ievt++) 486 { 487 if (jinbag[ievt]>0) continue; 488 489 fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt); 288 490 fNTimesOutBag[ievt]++; 289 } 290 291 for (Int_t ievt=numhad;ievt<numdata;ievt++) 292 { 293 if (jinbag[ievt]>0) 294 continue; 295 fHadEst[ievt] += fRanTree->TreeHad(gammas, ievt-numhad); 296 fNTimesOutBag[ievt]++; 297 } 298 */ 491 492 } 493 299 494 Int_t n=0; 300 Double_t ferr=0; 495 double ferr=0; 496 301 497 for (Int_t ievt=0;ievt<numdata;ievt++) 302 if (fNTimesOutBag[ievt]!=0) 498 { 499 if(fNTimesOutBag[ievt]!=0) 303 500 { 304 const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt]; 501 float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt]; 502 if(calcResolution) val/=fHadTrue[ievt]; 503 305 504 ferr += val*val; 306 505 n++; 307 506 } 308 507 } 309 508 ferr = TMath::Sqrt(ferr/n); 310 509 510 //------------------------------------------------------------------- 311 511 // give running output 312 *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl; 512 *fLog << inf << setw(5) << fTreeNo; 513 *fLog << inf << setw(20) << fRanTree->GetNumEndNodes(); 514 *fLog << inf << Form("%20.2f", ferr*100.); 515 *fLog << inf << endl; 313 516 314 517 // adding tree to forest … … 318 521 } 319 522 320 void MRanForest::CreateDataSort() 321 { 322 // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData, 323 // which does actually not exist) are sorted from lowest to highest. 523 Bool_t MRanForest::CreateDataSort() 524 { 525 // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs. 526 // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r: 527 // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.). 324 528 // 325 // 326 // fHadrons(0,0) ... fHadrons(0,nhad-1) fGammas(0,0) ... fGammas(0,ngam-1) 327 // . . . . 328 // fData(m,n) = . . . . 329 // . . . . 330 // fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1) 331 // 332 // 333 // Then fDataSort(m,n) is the event number in which fData(m,n) occurs. 334 // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest 335 // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging). 529 // There may be more then 1 event with rang r (due to bagging). 530 336 531 const Int_t numdata = GetNumData(); 532 const Int_t dim = GetNumDim(); 337 533 338 534 TArrayF v(numdata); 339 535 TArrayI isort(numdata); 340 536 341 const TMatrix &hadrons=fHadrons->GetM(); 342 const TMatrix &gammas=fGammas->GetM(); 343 344 const Int_t numgam = gammas.GetNrows(); 345 const Int_t numhad = hadrons.GetNrows(); 346 347 for (Int_t j=0;j<numhad;j++) 348 fHadTrue[j]=1; 349 350 for (Int_t j=0;j<numgam;j++) 351 fHadTrue[j+numhad]=0; 352 353 const Int_t dim = GetNumDim(); 537 354 538 for (Int_t mvar=0;mvar<dim;mvar++) 355 539 { 356 for(Int_t n=0;n<numhad;n++) 540 541 for(Int_t n=0;n<numdata;n++) 357 542 { 358 v[n]= hadrons(n,mvar);543 v[n]=(*fMatrix)(n,mvar); 359 544 isort[n]=n; 360 } 361 362 for(Int_t n=0;n<numgam;n++) 363 { 364 v[n+numhad]=gammas(n,mvar); 365 isort[n+numhad]=n; 545 546 if(TMath::IsNaN(v[n])) 547 { 548 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar; 549 *fLog << err <<" has the value NaN."<<endl; 550 return kFALSE; 551 } 366 552 } 367 553 … … 371 557 // of that v[n], which is the n-th from the lowest (assume the original 372 558 // event numbers are 0,1,...). 559 560 // control sorting 561 for(int n=1;n<numdata;n++) 562 if(v[isort[n-1]]>v[isort[n]]) 563 { 564 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar; 565 *fLog << err <<" not at correct sorting position."<<endl; 566 return kFALSE; 567 } 373 568 374 569 for(Int_t n=0;n<numdata-1;n++) … … 388 583 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1]; 389 584 } 585 return kTRUE; 390 586 } 391 587 -
trunk/MagicSoft/Mars/mranforest/MRanForest.h
r7170 r7396 4 4 #ifndef MARS_MParContainer 5 5 #include "MParContainer.h" 6 #endif7 8 #ifndef MARS_MRanTree9 #include "MRanTree.h"10 #endif11 12 #ifndef MARS_MDataArray13 #include "MDataArray.h"14 6 #endif 15 7 … … 26 18 #endif 27 19 28 #ifndef ROOT_TObjArray 29 #include <TObjArray.h> 30 #endif 20 class TMatrix; 21 class TVector; 22 class TObjArray; 31 23 32 #ifndef ROOT_TRandom 33 #include <TRandom.h> 34 #endif 35 24 class MRanTree; 25 class MDataArray; 36 26 class MHMatrix; 37 class MRanTree; 38 class TVector; 39 class TMatrix; 27 class MParList; 40 28 41 29 class MRanForest : public MParContainer 42 30 { 43 31 private: 44 Int_t fNumTrees; 45 Int_t fTreeNo; //! 32 Int_t fClassify; 46 33 47 MRanTree *fRanTree; //! 48 TObjArray *fForest; 34 Int_t fNumTrees; // Number of trees 35 Int_t fNumTry; // Number of tries 36 Int_t fNdSize; // Size of node 37 38 Int_t fTreeNo; //! Number of tree 39 40 MRanTree *fRanTree; //! Pointer to some tree 41 MDataArray *fRules; //! Pointer to corresponding rules 42 TObjArray *fForest; // Array containing forest 49 43 50 44 // training data 51 MHMatrix *fGammas; //! 52 MHMatrix *fHadrons; //! 45 TMatrix *fMatrix; //! 53 46 54 47 // true and estimated hadronness 55 TArrayI fHadTrue; //! 56 TArrayF fHadEst; //! 48 TArrayI fClass; //! 49 TArrayD fGrid; //! 50 TArrayF fHadTrue; //! 51 TArrayF fHadEst; //! 57 52 58 53 // data sorted according to parameters 59 TArrayI fDataSort; //!60 TArrayI fDataRang; //!61 TArrayI fClassPop; //!54 TArrayI fDataSort; //! 55 TArrayI fDataRang; //! 56 TArrayI fClassPop; //! 62 57 63 58 // weights 64 Bool_t fUsePriors; //! 65 TArrayF fPrior; //! 66 TArrayF fWeight; //! 67 TArrayI fNTimesOutBag;//! 59 TArrayF fWeight; //! 60 TArrayI fNTimesOutBag; //! 68 61 69 62 // estimates for classification error of growing forest 70 TArrayD fTreeHad; //63 TArrayD fTreeHad; // Hadronness values 71 64 72 Double_t fUserVal; 73 74 void InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag); 65 Double_t fUserVal; // A user value describing this tree (eg E-mc) 75 66 76 67 protected: 77 68 // create and modify (->due to bagging) fDataSort 78 voidCreateDataSort();69 Bool_t CreateDataSort(); 79 70 void ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag); 80 71 81 72 public: 82 73 MRanForest(const char *name=NULL, const char *title=NULL); 74 MRanForest(const MRanForest &rf); 75 83 76 ~MRanForest(); 84 77 85 // initialize forest86 void Set Priors(Float_t prior_had, Float_t prior_gam);78 void SetGrid(const TArrayD &grid); 79 void SetWeights(const TArrayF &weights); 87 80 void SetNumTrees(Int_t n); 88 81 89 // tree growing 90 //void SetupForest(); 91 Bool_t SetupGrow(MHMatrix *mhad,MHMatrix *mgam); 82 void SetNumTry(Int_t n); 83 void SetNdSize(Int_t n); 84 85 void SetClassify(Int_t n){ fClassify=n; } 86 void PrepareClasses(); 87 88 // tree growing 89 Bool_t SetupGrow(MHMatrix *mat,MParList *plist); 92 90 Bool_t GrowForest(); 93 void SetCurTree(MRanTree *rantree) { fRanTree=rantree; }91 void SetCurTree(MRanTree *rantree) { fRanTree=rantree; } 94 92 Bool_t AddTree(MRanTree *rantree); 95 void SetUserVal(Double_t d) { fUserVal = d; }93 void SetUserVal(Double_t d) { fUserVal = d; } 96 94 97 95 // getter methods 98 96 TObjArray *GetForest() { return fForest; } 99 97 MRanTree *GetCurTree() { return fRanTree; } 100 MRanTree *GetTree(Int_t i) { return (MRanTree*)(fForest->At(i)); } 101 MDataArray *GetRules() { return ((MRanTree*)(fForest->At(0)))->GetRules(); } 98 MRanTree *GetTree(Int_t i); 99 MDataArray *GetRules() { return fRules; } 100 102 101 103 102 Int_t GetNumTrees() const { return fNumTrees; } 104 103 Int_t GetNumData() const; 105 104 Int_t GetNumDim() const; 105 Int_t GetNclass() const; 106 106 Double_t GetTreeHad(Int_t i) const { return fTreeHad.At(i); } 107 107 Double_t GetUserVal() const { return fUserVal; } … … 109 109 // use forest to calculate hadronness of event 110 110 Double_t CalcHadroness(const TVector &event); 111 Double_t CalcHadroness(); 111 112 112 113 Bool_t AsciiWrite(ostream &out) const; 113 114 114 ClassDef(MRanForest, 2) // Storage container for tree structure115 ClassDef(MRanForest, 1) // Storage container for tree structure 115 116 }; 116 117 -
trunk/MagicSoft/Mars/mranforest/MRanForestGrow.cc
r7130 r7396 16 16 ! 17 17 ! 18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@ alwa02.physik.uni-siegen.de>18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de> 19 19 ! 20 ! Copyright: MAGIC Software Development, 2000-200 320 ! Copyright: MAGIC Software Development, 2000-2005 21 21 ! 22 22 ! … … 24 24 25 25 ///////////////////////////////////////////////////////////////////////////// 26 // //27 // MRanForestGrow //28 // //29 // Grows a random forest. //30 // //26 // 27 // MRanForestGrow 28 // 29 // Grows a random forest. 30 // 31 31 ///////////////////////////////////////////////////////////////////////////// 32 32 #include "MRanForestGrow.h" … … 38 38 39 39 #include "MParList.h" 40 41 #include "MRanTree.h"42 40 #include "MRanForest.h" 43 41 … … 46 44 using namespace std; 47 45 48 static const TStringgsDefName = "MRead";49 static const TString gsDefTitle = "Tree Classification Loop 1/2";46 const TString MRanForestGrow::gsDefName = "MRead"; 47 const TString MRanForestGrow::gsDefTitle = "Task to train a random forst"; 50 48 51 // --------------------------------------------------------------------------52 //53 // Setup histograms and the number of distances which are used for54 // avaraging in CalcDist55 //56 49 MRanForestGrow::MRanForestGrow(const char *name, const char *title) 57 50 { 58 //59 51 // set the name and title of this object 60 // 52 61 53 fName = name ? name : gsDefName.Data(); 62 54 fTitle = title ? title : gsDefTitle.Data(); 63 55 64 SetNumTrees();65 SetNumTry();66 SetNdSize();56 // SetNumTrees(); 57 // SetNumTry(); 58 // SetNdSize(); 67 59 } 68 60 69 // --------------------------------------------------------------------------70 //71 // Needs:72 // - MatrixGammas [MHMatrix]73 // - MatrixHadrons {MHMatrix]74 // - MHadroness75 // - all data containers used to build the matrixes76 //77 // The matrix object can be filles using MFillH. And must be of the same78 // number of columns (with the same meaning).79 //80 61 Int_t MRanForestGrow::PreProcess(MParList *plist) 81 62 { 82 fM Gammas = (MHMatrix*)plist->FindObject("MatrixGammas", "MHMatrix");83 if (!fM Gammas)63 fMatrix = (MHMatrix*)plist->FindObject("MatrixTrain", "MHMatrix"); 64 if (!fMatrix) 84 65 { 85 *fLog << err << dbginf << "MatrixGammas [MHMatrix] not found... aborting." << endl; 86 return kFALSE; 87 } 88 89 fMHadrons = (MHMatrix*)plist->FindObject("MatrixHadrons", "MHMatrix"); 90 if (!fMHadrons) 91 { 92 *fLog << err << dbginf << "MatrixHadrons [MHMatrix] not found... aborting." << endl; 93 return kFALSE; 94 } 95 96 if (fMGammas->GetM().GetNcols() != fMHadrons->GetM().GetNcols()) 97 { 98 *fLog << err << dbginf << "Error matrices have different numbers of columns... aborting." << endl; 99 return kFALSE; 100 } 101 102 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree"); 103 if (!fRanTree) 104 { 105 *fLog << err << dbginf << "MRanTree not found... aborting." << endl; 66 *fLog << err << dbginf << "MatrixTrain [MHMatrix] not found... aborting." << endl; 106 67 return kFALSE; 107 68 } … … 114 75 } 115 76 116 fRanTree->SetNumTry(fNumTry); 117 fRanTree->SetNdSize(fNdSize); 118 fRanForest->SetCurTree(fRanTree); 119 fRanForest->SetNumTrees(fNumTrees); 77 // fRanForest->SetNumTry(fNumTry); 78 // fRanForest->SetNdSize(fNdSize); 79 // fRanForest->SetNumTrees(fNumTrees); 120 80 121 return fRanForest->SetupGrow(fM Hadrons,fMGammas);81 return fRanForest->SetupGrow(fMatrix,plist); 122 82 } 123 83 124 // --------------------------------------------------------------------------125 //126 //127 84 Int_t MRanForestGrow::Process() 128 85 { 129 const Bool_t not_last=fRanForest->GrowForest(); 130 131 fRanTree->SetReadyToSave(); 132 133 return not_last; 86 return fRanForest->GrowForest(); 134 87 } 135 88 136 89 Int_t MRanForestGrow::PostProcess() 137 90 { 138 fRanTree->SetReadyToSave();139 91 fRanForest->SetReadyToSave(); 140 92 -
trunk/MagicSoft/Mars/mranforest/MRanForestGrow.h
r7130 r7396 9 9 class MParList; 10 10 class MRanForest; 11 class MRanTree;12 11 13 12 class MRanForestGrow : public MRead 14 13 { 15 14 private: 16 MRanTree *fRanTree; 15 static const TString gsDefName; 16 static const TString gsDefTitle; 17 18 // Int_t fNumTrees; 19 // Int_t fNumTry; 20 // Int_t fNdSize; 21 17 22 MRanForest *fRanForest; 18 MHMatrix *fMGammas; //! Gammas describing matrix 19 MHMatrix *fMHadrons; //! Hadrons (non gammas) describing matrix 20 21 Int_t fNumTrees; 22 Int_t fNumTry; 23 Int_t fNdSize; 23 MHMatrix *fMatrix; //! matrix with events 24 24 25 25 Int_t PreProcess(MParList *pList); … … 34 34 MRanForestGrow(const char *name=NULL, const char *title=NULL); 35 35 36 void SetNumTrees(Int_t n=-1) { fNumTrees=n>0?n:100; }37 void SetNumTry(Int_t n=-1) { fNumTry =n>0?n: 3; }38 void SetNdSize(Int_t n=-1) { fNdSize =n>0?n: 1; }36 // void SetNumTrees(Int_t n=-1) { fNumTrees=n>0?n:100; } 37 // void SetNumTry(Int_t n=-1) { fNumTry =n>0?n: 3; } 38 // void SetNdSize(Int_t n=-1) { fNdSize =n>0?n: 1; } 39 39 40 40 ClassDef(MRanForestGrow, 0) // Task to grow a random forest -
trunk/MagicSoft/Mars/mranforest/MRanTree.cc
r7142 r7396 16 16 ! 17 17 ! 18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@ alwa02.physik.uni-siegen.de>18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de> 19 19 ! 20 ! Copyright: MAGIC Software Development, 2000-200 320 ! Copyright: MAGIC Software Development, 2000-2005 21 21 ! 22 22 ! … … 38 38 #include <TRandom.h> 39 39 40 #include "MDataArray.h"41 42 40 #include "MLog.h" 43 41 #include "MLogManip.h" … … 47 45 using namespace std; 48 46 47 49 48 // -------------------------------------------------------------------------- 50 //51 49 // Default constructor. 52 50 // 53 MRanTree::MRanTree(const char *name, const char *title):f NdSize(0), fNumTry(3), fData(NULL)51 MRanTree::MRanTree(const char *name, const char *title):fClassify(1),fNdSize(0), fNumTry(3) 54 52 { 55 53 56 54 fName = name ? name : "MRanTree"; 57 55 fTitle = title ? title : "Storage container for structure of a single tree"; 56 } 57 58 // -------------------------------------------------------------------------- 59 // Copy constructor 60 // 61 MRanTree::MRanTree(const MRanTree &tree) 62 { 63 fName = tree.fName; 64 fTitle = tree.fTitle; 65 66 fClassify = tree.fClassify; 67 fNdSize = tree.fNdSize; 68 fNumTry = tree.fNumTry; 69 70 fNumNodes = tree.fNumNodes; 71 fNumEndNodes = tree.fNumEndNodes; 72 73 fBestVar = tree.fBestVar; 74 fTreeMap1 = tree.fTreeMap1; 75 fTreeMap2 = tree.fTreeMap2; 76 fBestSplit = tree.fBestSplit; 77 fGiniDec = tree.fGiniDec; 58 78 } 59 79 … … 75 95 } 76 96 77 void MRanTree::GrowTree( const TMatrix &mhad, const TMatrix &mgam,78 const TArrayI &hadtrue, TArrayI &datasort,79 const TArrayI &datarang, TArrayF &tclasspop, TArrayI &jinbag,80 const TArrayF &winbag)97 void MRanTree::GrowTree(TMatrix *mat, const TArrayF &hadtrue, const TArrayI &idclass, 98 TArrayI &datasort, const TArrayI &datarang, TArrayF &tclasspop, 99 float &mean, float &square, TArrayI &jinbag, const TArrayF &winbag, 100 const int nclass) 81 101 { 82 102 // arrays have to be initialized with generous size, so number of total nodes (nrnodes) 83 103 // is estimated for worst case 84 const Int_t numdim =m had.GetNcols();104 const Int_t numdim =mat->GetNcols(); 85 105 const Int_t numdata=winbag.GetSize(); 86 106 const Int_t nrnodes=2*numdata+1; … … 88 108 // number of events in bootstrap sample 89 109 Int_t ninbag=0; 90 for (Int_t n=0;n<numdata;n++) 91 if(jinbag[n]==1) ninbag++; 92 93 TArrayI bestsplit(nrnodes); 94 TArrayI bestsplitnext(nrnodes); 95 96 fBestVar.Set(nrnodes); 97 fTreeMap1.Set(nrnodes); 98 fTreeMap2.Set(nrnodes); 99 fBestSplit.Set(nrnodes); 100 101 fTreeMap1.Reset(); 102 fTreeMap2.Reset(); 103 fBestSplit.Reset(); 104 105 fGiniDec.Set(numdim); 106 fGiniDec.Reset(); 110 for (Int_t n=0;n<numdata;n++) if(jinbag[n]==1) ninbag++; 111 112 TArrayI bestsplit(nrnodes); bestsplit.Reset(0); 113 TArrayI bestsplitnext(nrnodes); bestsplitnext.Reset(0); 114 115 fBestVar.Set(nrnodes); fBestVar.Reset(0); 116 fTreeMap1.Set(nrnodes); fTreeMap1.Reset(0); 117 fTreeMap2.Set(nrnodes); fTreeMap2.Reset(0); 118 fBestSplit.Set(nrnodes); fBestSplit.Reset(0); 119 fGiniDec.Set(numdim); fGiniDec.Reset(0); 120 121 122 if(fClassify) 123 FindBestSplit=&MRanTree::FindBestSplitGini; 124 else 125 FindBestSplit=&MRanTree::FindBestSplitSigma; 107 126 108 127 // tree growing 109 BuildTree(datasort,datarang,hadtrue, bestsplit,110 bestsplitnext,tclasspop,winbag,ninbag);128 BuildTree(datasort,datarang,hadtrue,idclass,bestsplit, bestsplitnext, 129 tclasspop,mean,square,winbag,ninbag,nclass); 111 130 112 131 // post processing, determine cut (or split) values fBestSplit 113 Int_t nhad=mhad.GetNrows();114 115 132 for(Int_t k=0; k<nrnodes; k++) 116 133 { … … 122 139 const Int_t &msp =fBestVar[k]; 123 140 124 fBestSplit[k] = bsp<nhad ? mhad(bsp, msp):mgam(bsp-nhad, msp);125 fBestSplit[k] += bspn<nhad ? mhad(bspn,msp):mgam(bspn-nhad,msp);126 fBestSplit[k] /= 2 ;141 fBestSplit[k] = (*mat)(bsp, msp); 142 fBestSplit[k] += (*mat)(bspn,msp); 143 fBestSplit[k] /= 2.; 127 144 } 128 145 … … 134 151 } 135 152 136 Int_t MRanTree::FindBestSplit(const TArrayI &datasort,const TArrayI &datarang, 137 const TArrayI &hadtrue,Int_t ndstart,Int_t ndend,TArrayF &tclasspop, 138 Int_t &msplit,Float_t &decsplit,Int_t &nbest, 139 const TArrayF &winbag) 153 int MRanTree::FindBestSplitGini(const TArrayI &datasort,const TArrayI &datarang, 154 const TArrayF &hadtrue,const TArrayI &idclass, 155 Int_t ndstart,Int_t ndend, TArrayF &tclasspop, 156 float &mean, float &square, Int_t &msplit, 157 Float_t &decsplit,Int_t &nbest, const TArrayF &winbag, 158 const int nclass) 140 159 { 141 160 const Int_t nrnodes = fBestSplit.GetSize(); … … 143 162 const Int_t mdim = fGiniDec.GetSize(); 144 163 145 // weighted class populations after split 146 TArrayF wc(2); 147 TArrayF wr(2); // right node 148 149 // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...) 164 TArrayF wr(nclass); wr.Reset(0);// right node 165 166 // For the best split, msplit is the index of the variable (e.g Hillas par., 167 // zenith angle ,...) 150 168 // split on. decsplit is the decreae in impurity measured by Gini-index. 151 169 // nsplit is the case number of value of msplit split on, … … 158 176 Double_t pno=0; 159 177 Double_t pdo=0; 160 for (Int_t j=0; j<2; j++) 178 179 for (Int_t j=0; j<nclass; j++) 161 180 { 162 181 pno+=tclasspop[j]*tclasspop[j]; … … 165 184 166 185 const Double_t crit0=pno/pdo; 167 Int_t jstat=0;168 186 169 187 // start main loop through variables to find best split, … … 184 202 Double_t rld=0; 185 203 186 TArrayF wl( 2); // left node204 TArrayF wl(nclass); wl.Reset(0.);// left node //nclass 187 205 wr = tclasspop; 188 206 189 207 Double_t critvar=-1.0e20; 190 208 for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++) 209 { 210 const Int_t &nc = datasort[mn+nsp]; 211 const Int_t &k = idclass[nc]; 212 const Float_t &u = winbag[nc]; 213 214 // do classification, Gini index as split rule 215 rln+=u*(2*wl[k]+u); 216 rrn+=u*(-2*wr[k]+u); 217 218 rld+=u; 219 rrd-=u; 220 221 wl[k]+=u; 222 wr[k]-=u; 223 224 if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]]) 225 continue; 226 227 if (TMath::Min(rrd,rld)<=1.0e-5) 228 continue; 229 230 const Double_t crit=(rln/rld)+(rrn/rrd); 231 232 233 if (crit<=critvar) continue; 234 235 nbestvar=nsp; 236 critvar=crit; 237 } 238 239 if (critvar<=critmax) continue; 240 241 msplit=mvar; 242 nbest=nbestvar; 243 critmax=critvar; 244 } 245 246 decsplit=critmax-crit0; 247 248 return critmax<-1.0e10 ? 1 : 0; 249 } 250 251 int MRanTree::FindBestSplitSigma(const TArrayI &datasort,const TArrayI &datarang, 252 const TArrayF &hadtrue, const TArrayI &idclass, 253 Int_t ndstart,Int_t ndend, TArrayF &tclasspop, 254 float &mean, float &square, Int_t &msplit, 255 Float_t &decsplit,Int_t &nbest, const TArrayF &winbag, 256 const int nclass) 257 { 258 const Int_t nrnodes = fBestSplit.GetSize(); 259 const Int_t numdata = (nrnodes-1)/2; 260 const Int_t mdim = fGiniDec.GetSize(); 261 262 float wr=0;// right node 263 264 // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...) 265 // split on. decsplit is the decreae in impurity measured by Gini-index. 266 // nsplit is the case number of value of msplit split on, 267 // and nsplitnext is the case number of the next larger value of msplit. 268 269 Int_t nbestvar=0; 270 271 // compute initial values of numerator and denominator of split-index, 272 273 // resolution 274 //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0]; 275 //Double_t pdo= (tclasspop[0]-1.)*mean*mean; 276 277 // n*resolution 278 //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0]; 279 //Double_t pdo= mean*mean; 280 281 // variance 282 //Double_t pno=-(square-mean*mean/tclasspop[0]); 283 //Double_t pdo= (tclasspop[0]-1.); 284 285 // n*variance 286 Double_t pno= (square-mean*mean/tclasspop[0]); 287 Double_t pdo= 1.; 288 289 // 1./(n*variance) 290 //Double_t pno= 1.; 291 //Double_t pdo= (square-mean*mean/tclasspop[0]); 292 293 const Double_t crit0=pno/pdo; 294 295 // start main loop through variables to find best split, 296 297 Double_t critmin=1.0e40; 298 299 // random split selection, number of trials = fNumTry 300 for (Int_t mt=0; mt<fNumTry; mt++) 301 { 302 const Int_t mvar=Int_t(gRandom->Rndm()*mdim); 303 const Int_t mn = mvar*numdata; 304 305 Double_t rrn=0, rrd=0, rln=0, rld=0; 306 Double_t esumr=0, esuml=0, e2sumr=0,e2suml=0; 307 308 esumr =mean; 309 e2sumr=square; 310 esuml =0; 311 e2suml=0; 312 313 float wl=0.;// left node 314 wr = tclasspop[0]; 315 316 Double_t critvar=critmin; 191 317 for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++) 192 318 { 193 319 const Int_t &nc=datasort[mn+nsp]; 194 const Int_t &k=hadtrue[nc]; 195 320 const Float_t &f=hadtrue[nc];; 196 321 const Float_t &u=winbag[nc]; 197 322 198 rln+=u*(2*wl[k]+u); 199 rrn+=u*(-2*wr[k]+u); 200 rld+=u; 201 rrd-=u; 202 203 wl[k]+=u; 204 wr[k]-=u; 323 e2sumr-=u*f*f; 324 esumr -=u*f; 325 wr -=u; 326 327 //------------------------------------------- 328 // resolution 329 //rrn=(wr*e2sumr-esumr*esumr)*wr; 330 //rrd=(wr-1.)*esumr*esumr; 331 332 // resolution times n 333 //rrn=(wr*e2sumr-esumr*esumr)*wr; 334 //rrd=esumr*esumr; 335 336 // sigma 337 //rrn=(e2sumr-esumr*esumr/wr); 338 //rrd=(wr-1.); 339 340 // sigma times n 341 rrn=(e2sumr-esumr*esumr/wr); 342 rrd=1.; 343 344 // 1./(n*variance) 345 //rrn=1.; 346 //rrd=(e2sumr-esumr*esumr/wr); 347 //------------------------------------------- 348 349 e2suml+=u*f*f; 350 esuml +=u*f; 351 wl +=u; 352 353 //------------------------------------------- 354 // resolution 355 //rln=(wl*e2suml-esuml*esuml)*wl; 356 //rld=(wl-1.)*esuml*esuml; 357 358 // resolution times n 359 //rln=(wl*e2suml-esuml*esuml)*wl; 360 //rld=esuml*esuml; 361 362 // sigma 363 //rln=(e2suml-esuml*esuml/wl); 364 //rld=(wl-1.); 365 366 // sigma times n 367 rln=(e2suml-esuml*esuml/wl); 368 rld=1.; 369 370 // 1./(n*variance) 371 //rln=1.; 372 //rld=(e2suml-esuml*esuml/wl); 373 //------------------------------------------- 205 374 206 375 if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]]) 207 376 continue; 377 208 378 if (TMath::Min(rrd,rld)<=1.0e-5) 209 379 continue; 210 380 211 381 const Double_t crit=(rln/rld)+(rrn/rrd); 212 if (crit<=critvar) 213 382 383 if (crit>=critvar) continue; 214 384 215 385 nbestvar=nsp; … … 217 387 } 218 388 219 if (critvar<=critmax) 220 continue; 389 if (critvar>=critmin) continue; 221 390 222 391 msplit=mvar; 223 392 nbest=nbestvar; 224 critmax=critvar; 225 } 226 227 decsplit=critmax-crit0; 228 229 return critmax<-1.0e10 ? 1 : jstat; 230 } 231 232 void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart, 233 Int_t ndend,TArrayI &idmove,TArrayI &ncase,Int_t msplit, 393 critmin=critvar; 394 } 395 396 decsplit=crit0-critmin; 397 398 //return critmin>1.0e20 ? 1 : 0; 399 return decsplit<0 ? 1 : 0; 400 } 401 402 void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart, Int_t ndend, 403 TArrayI &idmove,TArrayI &ncase,Int_t msplit, 234 404 Int_t nbest,Int_t &ndendl) 235 405 { … … 240 410 const Int_t mdim = fGiniDec.GetSize(); 241 411 242 TArrayI tdatasort(numdata); 412 TArrayI tdatasort(numdata); tdatasort.Reset(0); 243 413 244 414 // compute idmove = indicator of case nos. going left 245 246 415 for (Int_t nsp=ndstart;nsp<=ndend;nsp++) 247 416 { … … 252 421 253 422 // shift case. nos. right and left for numerical variables. 254 255 423 for(Int_t msh=0;msh<mdim;msh++) 256 424 { … … 280 448 } 281 449 282 void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang, 283 const TArrayI & hadtrue, TArrayI &bestsplit,284 TArray I &bestsplitnext, TArrayF &tclasspop,285 const TArrayF &winbag, Int_t ninbag)450 void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang, const TArrayF &hadtrue, 451 const TArrayI &idclass, TArrayI &bestsplit, TArrayI &bestsplitnext, 452 TArrayF &tclasspop, float &tmean, float &tsquare, const TArrayF &winbag, 453 Int_t ninbag, const int nclass) 286 454 { 287 455 // Buildtree consists of repeated calls to two void functions, FindBestSplit and MoveData. … … 302 470 const Int_t numdata = (nrnodes-1)/2; 303 471 304 TArrayI nodepop(nrnodes); 305 TArrayI nodestart(nrnodes); 306 TArrayI parent(nrnodes); 307 308 TArrayI ncase(numdata); 309 TArrayI idmove(numdata); 310 TArrayI iv(mdim); 311 312 TArrayF classpop(nrnodes* 2);313 TArrayI nodestatus(nrnodes); 314 315 for (Int_t j=0;j< 2;j++)472 TArrayI nodepop(nrnodes); nodepop.Reset(0); 473 TArrayI nodestart(nrnodes); nodestart.Reset(0); 474 TArrayI parent(nrnodes); parent.Reset(0); 475 476 TArrayI ncase(numdata); ncase.Reset(0); 477 TArrayI idmove(numdata); idmove.Reset(0); 478 TArrayI iv(mdim); iv.Reset(0); 479 480 TArrayF classpop(nrnodes*nclass); classpop.Reset(0.);//nclass 481 TArrayI nodestatus(nrnodes); nodestatus.Reset(0); 482 483 for (Int_t j=0;j<nclass;j++) 316 484 classpop[j*nrnodes+0]=tclasspop[j]; 485 486 TArrayF mean(nrnodes); mean.Reset(0.); 487 TArrayF square(nrnodes); square.Reset(0.); 488 489 mean[0]=tmean; 490 square[0]=tsquare; 491 317 492 318 493 Int_t ncur=0; … … 330 505 const Int_t ndstart=nodestart[kbuild]; 331 506 const Int_t ndend=ndstart+nodepop[kbuild]-1; 332 for (Int_t j=0;j<2;j++) 507 508 for (Int_t j=0;j<nclass;j++) 333 509 tclasspop[j]=classpop[j*nrnodes+kbuild]; 510 511 tmean=mean[kbuild]; 512 tsquare=square[kbuild]; 334 513 335 514 Int_t msplit, nbest; 336 515 Float_t decsplit=0; 337 const Int_t jstat=FindBestSplit(datasort,datarang,hadtrue, 338 ndstart,ndend,tclasspop,msplit, 339 decsplit,nbest,winbag); 340 341 if (jstat==1) 516 517 if ((*this.*FindBestSplit)(datasort,datarang,hadtrue,idclass,ndstart, 518 ndend, tclasspop,tmean, tsquare,msplit,decsplit, 519 nbest,winbag,nclass)) 342 520 { 343 521 nodestatus[kbuild]=-1; … … 356 534 357 535 // leftnode no.= ncur+1, rightnode no. = ncur+2. 358 359 536 nodepop[ncur+1]=ndendl-ndstart+1; 360 537 nodepop[ncur+2]=ndend-ndendl; … … 363 540 364 541 // find class populations in both nodes 365 366 542 for (Int_t n=ndstart;n<=ndendl;n++) 367 543 { 368 544 const Int_t &nc=ncase[n]; 369 const Int_t &j=hadtrue[nc]; 545 const int j=idclass[nc]; 546 547 mean[ncur+1]+=hadtrue[nc]*winbag[nc]; 548 square[ncur+1]+=hadtrue[nc]*hadtrue[nc]*winbag[nc]; 549 370 550 classpop[j*nrnodes+ncur+1]+=winbag[nc]; 371 551 } … … 374 554 { 375 555 const Int_t &nc=ncase[n]; 376 const Int_t &j=hadtrue[nc]; 556 const int j=idclass[nc]; 557 558 mean[ncur+2] +=hadtrue[nc]*winbag[nc]; 559 square[ncur+2]+=hadtrue[nc]*hadtrue[nc]*winbag[nc]; 560 377 561 classpop[j*nrnodes+ncur+2]+=winbag[nc]; 378 562 } … … 385 569 if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1; 386 570 571 387 572 Double_t popt1=0; 388 573 Double_t popt2=0; 389 for (Int_t j=0;j< 2;j++)574 for (Int_t j=0;j<nclass;j++) 390 575 { 391 576 popt1+=classpop[j*nrnodes+ncur+1]; … … 393 578 } 394 579 395 for (Int_t j=0;j<2;j++)580 if(fClassify) 396 581 { 397 if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1; 398 if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1; 582 // check if only members of one class in node 583 for (Int_t j=0;j<nclass;j++) 584 { 585 if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1; 586 if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1; 587 } 399 588 } 400 589 … … 421 610 { 422 611 fNumEndNodes++; 612 423 613 Double_t pp=0; 424 for (Int_t j=0;j< 2;j++)614 for (Int_t j=0;j<nclass;j++) 425 615 { 426 616 if(classpop[j*nrnodes+kn]>pp) 427 617 { 428 618 // class + status of node kn coded into fBestVar[kn] 429 fBestVar[kn]=j- 2;619 fBestVar[kn]=j-nclass; 430 620 pp=classpop[j*nrnodes+kn]; 431 621 } 432 622 } 433 fBestSplit[kn] =classpop[1*nrnodes+kn]; 434 fBestSplit[kn]/=(classpop[0*nrnodes+kn]+classpop[1*nrnodes+kn]); 623 624 float sum=0; 625 for(int i=0;i<nclass;i++) sum+=classpop[i*nrnodes+kn]; 626 627 fBestSplit[kn]=mean[kn]/sum; 435 628 } 436 629 } 437 630 438 void MRanTree::SetRules(MDataArray *rules)439 {440 fData=rules;441 }442 443 631 Double_t MRanTree::TreeHad(const TVector &event) 444 {445 Int_t kt=0;446 // to optimize on storage space node status and node class447 // are coded into fBestVar:448 // status of node kt = TMath::Sign(1,fBestVar[kt])449 // class of node kt = fBestVar[kt]+2450 // (class defined by larger node population, actually not used)451 // hadronness assigned to node kt = fBestSplit[kt]452 453 for (Int_t k=0;k<fNumNodes;k++)454 {455 if (fBestVar[kt]<0)456 break;457 458 const Int_t m=fBestVar[kt];459 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];460 }461 462 return fBestSplit[kt];463 }464 465 Double_t MRanTree::TreeHad(const TMatrixFRow_const &event)466 632 { 467 633 Int_t kt=0; … … 485 651 } 486 652 653 Double_t MRanTree::TreeHad(const TMatrixRow &event) 654 { 655 Int_t kt=0; 656 // to optimize on storage space node status and node class 657 // are coded into fBestVar: 658 // status of node kt = TMath::Sign(1,fBestVar[kt]) 659 // class of node kt = fBestVar[kt]+2 (class defined by larger 660 // node population, actually not used) 661 // hadronness assigned to node kt = fBestSplit[kt] 662 663 for (Int_t k=0;k<fNumNodes;k++) 664 { 665 if (fBestVar[kt]<0) 666 break; 667 668 const Int_t m=fBestVar[kt]; 669 kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt]; 670 } 671 672 return fBestSplit[kt]; 673 } 674 487 675 Double_t MRanTree::TreeHad(const TMatrix &m, Int_t ievt) 488 676 { … … 494 682 } 495 683 496 Double_t MRanTree::TreeHad()497 {498 TVector event;499 *fData >> event;500 501 return TreeHad(event);502 }503 504 684 Bool_t MRanTree::AsciiWrite(ostream &out) const 505 685 { 506 out.width(5); 507 out << fNumNodes << endl; 508 686 TString str; 509 687 Int_t k; 510 for (k=0; k<fNumNodes; k++) 511 { 512 TString str=Form("%f", GetBestSplit(k)); 688 689 out.width(5);out<<fNumNodes<<endl; 690 691 for (k=0;k<fNumNodes;k++) 692 { 693 str=Form("%f",GetBestSplit(k)); 513 694 514 695 out.width(5); out << k; … … 520 701 out.width(5); out << GetNodeClass(k); 521 702 } 522 out <<endl;523 524 return k TRUE;525 } 703 out<<endl; 704 705 return k==fNumNodes; 706 } -
trunk/MagicSoft/Mars/mranforest/MRanTree.h
r7142 r7396 16 16 class TMatrix; 17 17 class TMatrixRow; 18 class TMatrixFRow_const;19 18 class TVector; 20 19 class TRandom; 21 class MDataArray;22 20 23 21 class MRanTree : public MParContainer 24 22 { 25 23 private: 24 Int_t fClassify; 26 25 Int_t fNdSize; 27 26 Int_t fNumTry; … … 29 28 Int_t fNumNodes; 30 29 Int_t fNumEndNodes; 31 MDataArray *fData;32 30 33 31 TArrayI fBestVar; … … 35 33 TArrayI fTreeMap2; 36 34 TArrayF fBestSplit; 37 38 35 TArrayF fGiniDec; 39 36 40 Int_t FindBestSplit(const TArrayI &datasort, const TArrayI &datarang, 41 const TArrayI &hadtrue, 42 Int_t ndstart, Int_t ndend, TArrayF &tclasspop, 43 Int_t &msplit, Float_t &decsplit, Int_t &nbest, 44 const TArrayF &winbag); 37 int (MRanTree::*FindBestSplit) 38 (const TArrayI &, const TArrayI &, const TArrayF &, const TArrayI &, 39 Int_t, Int_t , TArrayF &, float &, float &, Int_t &, Float_t &, 40 Int_t &, const TArrayF &, const int); //! 41 42 43 int FindBestSplitGini(const TArrayI &datasort, const TArrayI &datarang, 44 const TArrayF &hadtrue, const TArrayI &idclass, 45 Int_t ndstart, Int_t ndend, TArrayF &tclasspop, 46 float &mean, float &square, Int_t &msplit, 47 Float_t &decsplit, Int_t &nbest, const TArrayF &winbag, 48 const int nclass); 49 50 int FindBestSplitSigma(const TArrayI &datasort, const TArrayI &datarang, 51 const TArrayF &hadtrue, const TArrayI &idclass, 52 Int_t ndstart, Int_t ndend, TArrayF &tclasspop, 53 float &mean, float &square, Int_t &msplit, 54 Float_t &decsplit, Int_t &nbest, const TArrayF &winbag, 55 const int nclass); 45 56 46 57 void MoveData(TArrayI &datasort, Int_t ndstart, Int_t ndend, … … 48 59 Int_t nbest, Int_t &ndendl); 49 60 50 void BuildTree(TArrayI &datasort, const TArrayI &datarang, 51 const TArrayI &hadtrue, 52 TArrayI &bestsplit,TArrayI &bestsplitnext, 53 TArrayF &tclasspop, 54 const TArrayF &winbag, 55 Int_t ninbag); 61 void BuildTree(TArrayI &datasort, const TArrayI &datarang, const TArrayF &hadtrue, 62 const TArrayI &idclass,TArrayI &bestsplit,TArrayI &bestsplitnext, 63 TArrayF &tclasspop, float &tmean, float &tsquare, const TArrayF &winbag, 64 Int_t ninbag, const int nclass); 56 65 57 66 public: 58 67 MRanTree(const char *name=NULL, const char *title=NULL); 68 MRanTree(const MRanTree &tree); 59 69 60 70 void SetNdSize(Int_t n); 61 71 void SetNumTry(Int_t n); 62 void SetRules(MDataArray *rules);63 64 MDataArray *GetRules() { return fData;}65 72 66 73 Int_t GetNdSize() const { return fNdSize; } … … 78 85 Float_t GetGiniDec(Int_t i) const { return fGiniDec.At(i); } 79 86 87 void SetClassify(Int_t n){ fClassify=n; } 88 80 89 // functions used in tree growing process 81 void GrowTree( const TMatrix &mhad, const TMatrix &mgam,82 const TArrayI &hadtrue, TArrayI &datasort,83 const TArrayI &datarang,84 TArrayF &tclasspop, TArrayI &jinbag, const TArrayF &winbag);90 void GrowTree(TMatrix *mat, const TArrayF &hadtrue, const TArrayI &idclass, 91 TArrayI &datasort, const TArrayI &datarang,TArrayF &tclasspop, 92 float &mean, float &square, TArrayI &jinbag, const TArrayF &winbag, 93 const int nclass); 85 94 86 95 Double_t TreeHad(const TVector &event); 87 Double_t TreeHad(const TMatrix FRow_const&event);96 Double_t TreeHad(const TMatrixRow &event); 88 97 Double_t TreeHad(const TMatrix &m, Int_t ievt); 89 Double_t TreeHad();90 98 91 99 Bool_t AsciiWrite(ostream &out) const; -
trunk/MagicSoft/Mars/mranforest/Makefile
r6479 r7396 25 25 MRanForest.cc \ 26 26 MRanForestGrow.cc \ 27 MRanForestCalc.cc \28 MRanForestFill.cc \29 27 MHRanForest.cc \ 30 28 MHRanForestGini.cc \ -
trunk/MagicSoft/Mars/mranforest/RanForestLinkDef.h
r6479 r7396 8 8 #pragma link C++ class MRanForest+; 9 9 #pragma link C++ class MRanForestGrow+; 10 #pragma link C++ class MRanForestCalc+;11 #pragma link C++ class MRanForestFill+;12 10 13 11 #pragma link C++ class MHRanForest+;
Note:
See TracChangeset
for help on using the changeset viewer.