Changeset 7396 for trunk/MagicSoft/Mars/mranforest/MRanForest.cc
- Timestamp:
- 11/14/05 09:45:33 (19 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/MagicSoft/Mars/mranforest/MRanForest.cc
r7170 r7396 16 16 ! 17 17 ! 18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@ alwa02.physik.uni-siegen.de>18 ! Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de> 19 19 ! 20 ! Copyright: MAGIC Software Development, 2000-200 320 ! Copyright: MAGIC Software Development, 2000-2005 21 21 ! 22 22 ! … … 39 39 // split selection (which is subject to MRanTree::GrowTree()) 40 40 // 41 // Version 2:42 // + fUserVal43 //44 41 ///////////////////////////////////////////////////////////////////////////// 45 42 #include "MRanForest.h" 46 43 47 #include <T Matrix.h>48 #include <TRandom 3.h>44 #include <TVector.h> 45 #include <TRandom.h> 49 46 50 47 #include "MHMatrix.h" 51 48 #include "MRanTree.h" 49 #include "MData.h" 50 #include "MDataArray.h" 51 #include "MParList.h" 52 52 53 53 #include "MLog.h" … … 62 62 // Default constructor. 63 63 // 64 MRanForest::MRanForest(const char *name, const char *title) : f NumTrees(100), fRanTree(NULL),fUsePriors(kFALSE), fUserVal(-1)64 MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1) 65 65 { 66 66 fName = name ? name : "MRanForest"; … … 71 71 } 72 72 73 MRanForest::MRanForest(const MRanForest &rf) 74 { 75 // Copy constructor 76 fName = rf.fName; 77 fTitle = rf.fTitle; 78 79 fClassify = rf.fClassify; 80 fNumTrees = rf.fNumTrees; 81 fNumTry = rf.fNumTry; 82 fNdSize = rf.fNdSize; 83 fTreeNo = rf.fTreeNo; 84 fRanTree = NULL; 85 86 fRules=new MDataArray(); 87 fRules->Reset(); 88 89 MDataArray *newrules=rf.fRules; 90 91 for(Int_t i=0;i<newrules->GetNumEntries();i++) 92 { 93 MData &data=(*newrules)[i]; 94 fRules->AddEntry(data.GetRule()); 95 } 96 97 // trees 98 fForest=new TObjArray(); 99 fForest->SetOwner(kTRUE); 100 101 TObjArray *newforest=rf.fForest; 102 for(Int_t i=0;i<newforest->GetEntries();i++) 103 { 104 MRanTree *rantree=(MRanTree*)newforest->At(i); 105 106 MRanTree *newtree=new MRanTree(*rantree); 107 fForest->Add(newtree); 108 } 109 110 fHadTrue = rf.fHadTrue; 111 fHadEst = rf.fHadEst; 112 fDataSort = rf.fDataSort; 113 fDataRang = rf.fDataRang; 114 fClassPop = rf.fClassPop; 115 fWeight = rf.fWeight; 116 fTreeHad = rf.fTreeHad; 117 118 fNTimesOutBag = rf.fNTimesOutBag; 119 120 } 121 73 122 // -------------------------------------------------------------------------- 74 //75 123 // Destructor. 76 //77 124 MRanForest::~MRanForest() 78 125 { 79 126 delete fForest; 127 } 128 129 MRanTree *MRanForest::GetTree(Int_t i) 130 { 131 return (MRanTree*)(fForest->At(i)); 80 132 } 81 133 … … 88 140 } 89 141 90 void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam) 91 { 92 const Float_t sum=prior_gam+prior_had; 93 94 prior_gam/=sum; 95 prior_had/=sum; 96 97 fPrior[0]=prior_had; 98 fPrior[1]=prior_gam; 99 100 fUsePriors=kTRUE; 142 void MRanForest::SetNumTry(Int_t n) 143 { 144 fNumTry=TMath::Max(n,0); 145 } 146 147 void MRanForest::SetNdSize(Int_t n) 148 { 149 fNdSize=TMath::Max(n,1); 150 } 151 152 void MRanForest::SetWeights(const TArrayF &weights) 153 { 154 const int n=weights.GetSize(); 155 fWeight.Set(n); 156 fWeight=weights; 157 158 return; 159 } 160 161 void MRanForest::SetGrid(const TArrayD &grid) 162 { 163 const int n=grid.GetSize(); 164 165 for(int i=0;i<n-1;i++) 166 if(grid[i]>=grid[i+1]) 167 { 168 *fLog<<inf<<"Grid points must be in increasing order! Ignoring grid."<<endl; 169 return; 170 } 171 172 fGrid=grid; 173 174 //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl; 175 //for(int i=0;i<n;i++) 176 // *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl; 177 178 return; 101 179 } 102 180 103 181 Int_t MRanForest::GetNumDim() const 104 182 { 105 return fGammas ? fGammas->GetM().GetNcols() : 0; 106 } 107 183 return fMatrix ? fMatrix->GetNcols() : 0; 184 } 185 186 Int_t MRanForest::GetNumData() const 187 { 188 return fMatrix ? fMatrix->GetNrows() : 0; 189 } 190 191 Int_t MRanForest::GetNclass() const 192 { 193 int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray()); 194 195 return int(fClass[maxidx])+1; 196 } 197 198 void MRanForest::PrepareClasses() 199 { 200 const int numdata=fHadTrue.GetSize(); 201 202 if(fGrid.GetSize()>0) 203 { 204 // classes given by grid 205 const int ngrid=fGrid.GetSize(); 206 207 for(int j=0;j<numdata;j++) 208 { 209 // Array is supposed to be sorted prior to this call. 210 // If match is found, function returns position of element. 211 // If no match found, function gives nearest element smaller 212 // than value. 213 int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]); 214 215 fClass[j] = k; 216 } 217 218 int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray()); 219 int min = fClass[minidx]; 220 if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min; 221 222 }else{ 223 // classes directly given 224 for (Int_t j=0;j<numdata;j++) 225 fClass[j] = int(fHadTrue[j]+0.5); 226 } 227 228 return; 229 } 230 231 /* 232 Bool_t MRanForest::PreProcess(MParList *plist) 233 { 234 if (!fRules) 235 { 236 *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl; 237 return kFALSE; 238 } 239 240 if (!fRules->PreProcess(plist)) 241 { 242 *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl; 243 return kFALSE; 244 } 245 246 return kTRUE; 247 } 248 */ 249 250 Double_t MRanForest::CalcHadroness() 251 { 252 TVector event; 253 *fRules >> event; 254 255 return CalcHadroness(event); 256 } 108 257 109 258 Double_t MRanForest::CalcHadroness(const TVector &event) … … 117 266 while ((tree=(MRanTree*)Next())) 118 267 { 119 fTreeHad[ntree]=tree->TreeHad(event); 120 hadroness+=fTreeHad[ntree]; 268 hadroness+=(fTreeHad[ntree]=tree->TreeHad(event)); 121 269 ntree++; 122 270 } … … 126 274 Bool_t MRanForest::AddTree(MRanTree *rantree=NULL) 127 275 { 128 if (rantree) 129 fRanTree=rantree; 130 if (!fRanTree) 276 fRanTree = rantree ? rantree:fRanTree; 277 278 if (!fRanTree) return kFALSE; 279 280 MRanTree *newtree=new MRanTree(*fRanTree); 281 fForest->Add(newtree); 282 283 return kTRUE; 284 } 285 286 Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist) 287 { 288 //------------------------------------------------------------------- 289 // access matrix, copy last column (target) preliminarily 290 // into fHadTrue 291 TMatrix mat_tmp = mat->GetM(); 292 int dim = mat_tmp.GetNcols(); 293 int numdata = mat_tmp.GetNrows(); 294 295 fMatrix=new TMatrix(mat_tmp); 296 297 fHadTrue.Set(numdata); 298 fHadTrue.Reset(0); 299 300 for (Int_t j=0;j<numdata;j++) 301 fHadTrue[j] = (*fMatrix)(j,dim-1); 302 303 // remove last col 304 fMatrix->ResizeTo(numdata,dim-1); 305 dim=fMatrix->GetNcols(); 306 307 //------------------------------------------------------------------- 308 // setup labels for classification/regression 309 fClass.Set(numdata); 310 fClass.Reset(0); 311 312 if(fClassify) PrepareClasses(); 313 314 //------------------------------------------------------------------- 315 // allocating and initializing arrays 316 fHadEst.Set(numdata); fHadEst.Reset(0); 317 fNTimesOutBag.Set(numdata); fNTimesOutBag.Reset(0); 318 fDataSort.Set(dim*numdata); fDataSort.Reset(0); 319 fDataRang.Set(dim*numdata); fDataRang.Reset(0); 320 321 if(fWeight.GetSize()!=numdata) 322 { 323 fWeight.Set(numdata); 324 fWeight.Reset(1.); 325 *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl; 326 } 327 328 //------------------------------------------------------------------- 329 // setup rules to be used for classification/regression 330 MDataArray *allrules=(MDataArray*)mat->GetColumns(); 331 if(allrules==NULL) 332 { 333 *fLog << err <<"Rules of matrix == null, exiting"<< endl; 131 334 return kFALSE; 132 133 fForest->Add((MRanTree*)fRanTree->Clone()); 134 135 return kTRUE; 136 } 137 138 Int_t MRanForest::GetNumData() const 139 { 140 return fHadrons && fGammas ? fHadrons->GetM().GetNrows()+fGammas->GetM().GetNrows() : 0; 141 } 142 143 Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam) 144 { 145 // pointer to training data 146 fHadrons=mhad; 147 fGammas=mgam; 148 149 // determine data entries and dimension of Hillas-parameter space 150 //fNumHad=fHadrons->GetM().GetNrows(); 151 //fNumGam=fGammas->GetM().GetNrows(); 152 153 const Int_t dim = GetNumDim(); 154 155 if (dim!=fGammas->GetM().GetNcols()) 156 return kFALSE; 157 158 const Int_t numdata = GetNumData(); 159 160 // allocating and initializing arrays 161 fHadTrue.Set(numdata); 162 fHadTrue.Reset(); 163 fHadEst.Set(numdata); 164 165 fPrior.Set(2); 166 fClassPop.Set(2); 167 fWeight.Set(numdata); 168 fNTimesOutBag.Set(numdata); 169 fNTimesOutBag.Reset(); 170 171 fDataSort.Set(dim*numdata); 172 fDataRang.Set(dim*numdata); 173 174 // calculating class populations (= no. of gammas and hadrons) 175 fClassPop.Reset(); 176 for(Int_t n=0;n<numdata;n++) 177 fClassPop[fHadTrue[n]]++; 178 179 // setting weights taking into account priors 180 if (!fUsePriors) 181 fWeight.Reset(1.); 182 else 183 { 184 for(Int_t j=0;j<2;j++) 185 fPrior[j] *= (fClassPop[j]>=1) ? (Float_t)numdata/fClassPop[j]:0; 186 187 for(Int_t n=0;n<numdata;n++) 188 fWeight[n]=fPrior[fHadTrue[n]]; 189 } 190 191 // create fDataSort 192 CreateDataSort(); 335 } 336 337 fRules=new MDataArray(); fRules->Reset(); 338 TString target_rule; 339 340 for(Int_t i=0;i<dim+1;i++) 341 { 342 MData &data=(*allrules)[i]; 343 if(i<dim) 344 fRules->AddEntry(data.GetRule()); 345 else 346 target_rule=data.GetRule(); 347 } 348 349 *fLog << inf <<endl; 350 *fLog << inf <<"Setting up RF for training on target:"<<endl<<" "<<target_rule.Data()<<endl; 351 *fLog << inf <<"Following rules are used as input to RF:"<<endl; 352 353 for(Int_t i=0;i<dim;i++) 354 { 355 MData &data=(*fRules)[i]; 356 *fLog<<inf<<" "<<i<<") "<<data.GetRule()<<endl<<flush; 357 } 358 359 *fLog << inf <<endl; 360 361 //------------------------------------------------------------------- 362 // prepare (sort) data for fast optimization algorithm 363 if(!CreateDataSort()) return kFALSE; 364 365 //------------------------------------------------------------------- 366 // access and init tree container 367 fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree"); 193 368 194 369 if(!fRanTree) … … 197 372 return kFALSE; 198 373 } 199 fRanTree->SetRules(fGammas->GetColumns()); 374 375 fRanTree->SetClassify(fClassify); 376 fRanTree->SetNdSize(fNdSize); 377 378 if(fNumTry==0) 379 { 380 double ddim = double(dim); 381 382 fNumTry=int(sqrt(ddim)+0.5); 383 *fLog<<inf<<endl; 384 *fLog<<inf<<"Set no. of trials to the recommended value of round("<<sqrt(ddim)<<") = "; 385 *fLog<<inf<<fNumTry<<endl; 386 387 } 388 fRanTree->SetNumTry(fNumTry); 389 390 *fLog<<inf<<endl; 391 *fLog<<inf<<"Following settings for the tree growing are used:"<<endl; 392 *fLog<<inf<<" Number of Trees : "<<fNumTrees<<endl; 393 *fLog<<inf<<" Number of Trials: "<<fNumTry<<endl; 394 *fLog<<inf<<" Final Node size : "<<fNdSize<<endl; 395 200 396 fTreeNo=0; 201 397 202 398 return kTRUE; 203 }204 205 void MRanForest::InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag)206 {207 for (Int_t ievt=from;ievt<to;ievt++)208 {209 if (jinbag[ievt]>0)210 continue;211 fHadEst[ievt] += fRanTree->TreeHad(m, ievt-from);212 fNTimesOutBag[ievt]++;213 }214 399 } 215 400 … … 224 409 fTreeNo++; 225 410 411 //------------------------------------------------------------------- 226 412 // initialize running output 413 414 float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())]; 415 Bool_t calcResolution=(minfloat>0.001); 416 227 417 if (fTreeNo==1) 228 418 { 229 419 *fLog << inf << endl; 230 *fLog << underline; // << "1st col 2nd col" << endl; 231 *fLog << "no. of tree error in % (calulated using oob-data -> overestim. of error)" << endl; 420 *fLog << underline; 421 422 if(calcResolution) 423 *fLog << "no. of tree no. of nodes resolution in % (from oob-data -> overest. of error)" << endl; 424 else 425 *fLog << "no. of tree no. of nodes rms in % (from oob-data -> overest. of error)" << endl; 426 // 12345678901234567890123456789012345678901234567890 232 427 } 233 428 234 429 const Int_t numdata = GetNumData(); 235 430 const Int_t nclass = GetNclass(); 431 432 //------------------------------------------------------------------- 236 433 // bootstrap aggregating (bagging) -> sampling with replacement: 237 // 238 // The integer k is randomly (uniformly) chosen from the set 239 // {0,1,...,fNumData-1}, which is the set of the index numbers of 240 // all events in the training sample 241 TArrayF classpopw(2); 434 435 TArrayF classpopw(nclass); 242 436 TArrayI jinbag(numdata); // Initialization includes filling with 0 243 437 TArrayF winbag(numdata); // Initialization includes filling with 0 244 438 439 float square=0; float mean=0; 440 245 441 for (Int_t n=0; n<numdata; n++) 246 442 { 443 // The integer k is randomly (uniformly) chosen from the set 444 // {0,1,...,numdata-1}, which is the set of the index numbers of 445 // all events in the training sample 446 247 447 const Int_t k = Int_t(gRandom->Rndm()*numdata); 248 448 249 classpopw[fHadTrue[k]]+=fWeight[k]; 449 if(fClassify) 450 classpopw[fClass[k]]+=fWeight[k]; 451 else 452 classpopw[0]+=fWeight[k]; 453 454 mean +=fHadTrue[k]*fWeight[k]; 455 square+=fHadTrue[k]*fHadTrue[k]*fWeight[k]; 456 250 457 winbag[k]+=fWeight[k]; 251 458 jinbag[k]=1; 252 } 253 459 460 } 461 462 //------------------------------------------------------------------- 254 463 // modifying sorted-data array for in-bag data: 255 // 464 256 465 // In bagging procedure ca. 2/3 of all elements in the original 257 466 // training sample are used to build the in-bag data … … 261 470 ModifyDataSort(datsortinbag, ninbag, jinbag); 262 471 263 const TMatrix &hadrons=fHadrons->GetM(); 264 const TMatrix &gammas =fGammas->GetM(); 265 266 // growing single tree 267 fRanTree->GrowTree(hadrons,gammas,fHadTrue,datsortinbag, 268 fDataRang,classpopw,jinbag,winbag); 269 472 fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square, 473 jinbag,winbag,nclass); 474 475 //------------------------------------------------------------------- 270 476 // error-estimates from out-of-bag data (oob data): 271 477 // … … 277 483 // determined from oob-data is underestimated, but can still be taken as upper limit. 278 484 279 const Int_t numhad = hadrons.GetNrows(); 280 InitHadEst(0, numhad, hadrons, jinbag); 281 InitHadEst(numhad, numdata, gammas, jinbag); 282 /* 283 for (Int_t ievt=0;ievt<numhad;ievt++) 284 { 285 if (jinbag[ievt]>0) 286 continue; 287 fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt); 485 for (Int_t ievt=0;ievt<numdata;ievt++) 486 { 487 if (jinbag[ievt]>0) continue; 488 489 fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt); 288 490 fNTimesOutBag[ievt]++; 289 } 290 291 for (Int_t ievt=numhad;ievt<numdata;ievt++) 292 { 293 if (jinbag[ievt]>0) 294 continue; 295 fHadEst[ievt] += fRanTree->TreeHad(gammas, ievt-numhad); 296 fNTimesOutBag[ievt]++; 297 } 298 */ 491 492 } 493 299 494 Int_t n=0; 300 Double_t ferr=0; 495 double ferr=0; 496 301 497 for (Int_t ievt=0;ievt<numdata;ievt++) 302 if (fNTimesOutBag[ievt]!=0) 498 { 499 if(fNTimesOutBag[ievt]!=0) 303 500 { 304 const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt]; 501 float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt]; 502 if(calcResolution) val/=fHadTrue[ievt]; 503 305 504 ferr += val*val; 306 505 n++; 307 506 } 308 507 } 309 508 ferr = TMath::Sqrt(ferr/n); 310 509 510 //------------------------------------------------------------------- 311 511 // give running output 312 *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl; 512 *fLog << inf << setw(5) << fTreeNo; 513 *fLog << inf << setw(20) << fRanTree->GetNumEndNodes(); 514 *fLog << inf << Form("%20.2f", ferr*100.); 515 *fLog << inf << endl; 313 516 314 517 // adding tree to forest … … 318 521 } 319 522 320 void MRanForest::CreateDataSort() 321 { 322 // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData, 323 // which does actually not exist) are sorted from lowest to highest. 523 Bool_t MRanForest::CreateDataSort() 524 { 525 // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs. 526 // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r: 527 // fMatrix(m,n) is the r-th highest value of all fMatrix(m,.). 324 528 // 325 // 326 // fHadrons(0,0) ... fHadrons(0,nhad-1) fGammas(0,0) ... fGammas(0,ngam-1) 327 // . . . . 328 // fData(m,n) = . . . . 329 // . . . . 330 // fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1) 331 // 332 // 333 // Then fDataSort(m,n) is the event number in which fData(m,n) occurs. 334 // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest 335 // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging). 529 // There may be more then 1 event with rang r (due to bagging). 530 336 531 const Int_t numdata = GetNumData(); 532 const Int_t dim = GetNumDim(); 337 533 338 534 TArrayF v(numdata); 339 535 TArrayI isort(numdata); 340 536 341 const TMatrix &hadrons=fHadrons->GetM(); 342 const TMatrix &gammas=fGammas->GetM(); 343 344 const Int_t numgam = gammas.GetNrows(); 345 const Int_t numhad = hadrons.GetNrows(); 346 347 for (Int_t j=0;j<numhad;j++) 348 fHadTrue[j]=1; 349 350 for (Int_t j=0;j<numgam;j++) 351 fHadTrue[j+numhad]=0; 352 353 const Int_t dim = GetNumDim(); 537 354 538 for (Int_t mvar=0;mvar<dim;mvar++) 355 539 { 356 for(Int_t n=0;n<numhad;n++) 540 541 for(Int_t n=0;n<numdata;n++) 357 542 { 358 v[n]= hadrons(n,mvar);543 v[n]=(*fMatrix)(n,mvar); 359 544 isort[n]=n; 360 } 361 362 for(Int_t n=0;n<numgam;n++) 363 { 364 v[n+numhad]=gammas(n,mvar); 365 isort[n+numhad]=n; 545 546 if(TMath::IsNaN(v[n])) 547 { 548 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar; 549 *fLog << err <<" has the value NaN."<<endl; 550 return kFALSE; 551 } 366 552 } 367 553 … … 371 557 // of that v[n], which is the n-th from the lowest (assume the original 372 558 // event numbers are 0,1,...). 559 560 // control sorting 561 for(int n=1;n<numdata;n++) 562 if(v[isort[n-1]]>v[isort[n]]) 563 { 564 *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar; 565 *fLog << err <<" not at correct sorting position."<<endl; 566 return kFALSE; 567 } 373 568 374 569 for(Int_t n=0;n<numdata-1;n++) … … 388 583 fDataSort[(mvar+1)*numdata-1]=isort[numdata-1]; 389 584 } 585 return kTRUE; 390 586 } 391 587
Note:
See TracChangeset
for help on using the changeset viewer.