Changeset 7396 for trunk/MagicSoft/Mars


Ignore:
Timestamp:
11/14/05 09:45:33 (19 years ago)
Author:
tbretz
Message:
*** empty log message ***
Location:
trunk/MagicSoft/Mars
Files:
1 added
12 edited

Legend:

Unmodified
Added
Removed
  • trunk/MagicSoft/Mars/Changelog

    r7395 r7396  
    1818
    1919                                                 -*-*- END OF LINE -*-*-
     20
     21 2005/11/14 Thomas Bretz
     22
     23   * MRFEnergyEst.[h,cc]:
     24     - changed to allow for new regression method
     25
     26   * MRanForest.[h,cc]:
     27     - taken TH's implementation of the RF regression
     28       + updated includes
     29       + added fUserVal
     30       + removed ReadRF, WriteRF -- obsolete
     31
     32   * MRanForestGrow.[h,cc]:
     33     - took out setting up the growing of the forest from this task
     34       (currently it is done by MRanForest directly)
     35     - adapted to changes in other classes as TH did.
     36
     37   * MRanTree.[h,cc]:
     38     - changes taken from TH
     39       + added copy-constructor
     40       + upadted to allow regression
     41
     42   * Makefile, RanForestLinkDef.h:
     43     - took out MRanForestCalc and MRanForestFill temporarily
     44
     45
    2046
    2147 2005/11/12 Daniela Dorner
  • trunk/MagicSoft/Mars/NEWS

    r7392 r7396  
    22
    33 *** Version  <cvs>
     4
     5   - general: Updated the random forest classes to support also the
     6     regression method implemented by Thomas H.
    47
    58   - ganymed: Implemented two new options which allow
  • trunk/MagicSoft/Mars/mranforest/MRFEnergyEst.cc

    r7178 r7396  
    1717!
    1818!   Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
     19!   Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
    1920!
    2021!   Copyright: MAGIC Software Development, 2000-2005
     
    3132#include "MRFEnergyEst.h"
    3233
    33 #include <TFile.h>
    34 #include <TList.h>
    35 
    36 #include <TH1.h>
    37 #include <TH2.h>
    38 #include <TStyle.h>
    39 #include <TCanvas.h>
    40 #include <TMath.h>
    4134#include <TVector.h>
    4235
     
    4538#include "MLog.h"
    4639#include "MLogManip.h"
     40
     41#include "MData.h"
     42#include "MDataArray.h"
     43
     44#include "MRanForest.h"
     45#include "MParameters.h"
    4746
    4847#include "MParList.h"
    4948#include "MTaskList.h"
    5049#include "MEvtLoop.h"
    51 
    52 #include "MRanTree.h"
    53 #include "MRanForest.h"
    5450#include "MRanForestGrow.h"
    5551
    56 #include "MData.h"
    57 #include "MParameters.h"
    58 
    5952ClassImp(MRFEnergyEst);
    6053
    6154using namespace std;
    6255
    63 static const TString gsDefName  = "MRFEnergyEst";
    64 static const TString gsDefTitle = "RF for energy estimation";
    65 
    66 // --------------------------------------------------------------------------
    67 //
    68 //  Default constructor. Set the name and title of this object
    69 //
     56const TString MRFEnergyEst::gsDefName  = "MRFEnergyEst";
     57const TString MRFEnergyEst::gsDefTitle = "RF for energy estimation";
     58
    7059MRFEnergyEst::MRFEnergyEst(const char *name, const char *title)
    7160    : fDebug(kFALSE), fData(0), fEnergyEst(0),
     
    7564    fName  = name  ? name  : gsDefName.Data();
    7665    fTitle = title ? title : gsDefTitle.Data();
     66
     67    // FIXME:
     68    fNumTrees = 100; //100
     69    fNumTry   = 5;   //3
     70    fNdSize   = 0;   //1   0 means: in MRanForest estimated best value will be calculated
    7771}
    7872
     
    8276}
    8377
    84 // --------------------------------------------------------------------------
    85 //
    86 // Train the RF with the goven MHMatrix. The last column must contain the
    87 // True energy.
    88 //
    89 Int_t MRFEnergyEst::Train(const MHMatrix &matrixtrain, const TArrayD &grid)
     78Int_t MRFEnergyEst::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
    9079{
    9180    gLog.Separator("MRFEnergyEst - Train");
     
    10594    }
    10695
    107     const Int_t nbins = grid.GetSize()-1;
    108     if (nbins<=0)
    109     {
    110         *fLog << err << "ERROR - Energy grid not vaild... abort." << endl;
    111         return kFALSE;
    112     }
    113 
    11496    // rules (= combination of image par) to be used for energy estimation
    11597    TFile fileRF(fFileName, "recreate");
     
    122104    const Int_t nobs = 3; // Number of obsolete columns
    123105
     106    MDataArray &dcol = *matrixtrain.GetColumns();
     107
    124108    MDataArray usedrules;
    125     for(Int_t i=0; i<ncols-nobs; i++) // -3 is important!!!
    126         usedrules.AddEntry((*matrixtrain.GetColumns())[i].GetRule());
    127 
    128     // training of RF for each energy bin
     109    for (Int_t i=0; i<ncols-nobs; i++) // -3 is important!!!
     110        usedrules.AddEntry(dcol[i].GetRule());
     111
     112    MDataArray rules(usedrules);
     113    rules.AddEntry(ver<2?"Classification":dcol[ncols-1].GetRule());
     114
     115    // prepare matrix for current energy bin
     116    TMatrix mat(matrixtrain.GetM());
     117
     118    // last column must be removed (true energy col.)
     119    mat.ResizeTo(nrows, ncols-nobs+1);
     120
     121    if (fDebug)
     122        gLog.SetNullOutput(kTRUE);
     123
     124    const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
    129125    for (Int_t ie=0; ie<nbins; ie++)
    130126    {
    131         TMatrix mat1(nrows, ncols);
    132         TMatrix mat0(nrows, ncols);
    133 
    134         // prepare matrix for current energy bin
    135         Int_t irow1=0;
    136         Int_t irow0=0;
    137 
    138         const TMatrix &m = matrixtrain.GetM();
    139         for (Int_t j=0; j<nrows; j++)
     127        switch (ver)
    140128        {
    141             const Double_t energy = m(j,ncols-1);
    142 
    143             if (energy>grid[ie] && energy<=grid[ie+1])
    144                 TMatrixFRow(mat1, irow1++) = TMatrixFRow_const(m,j);
    145             else
    146                 TMatrixFRow(mat0, irow0++) = TMatrixFRow_const(m,j);
     129        case 0: // Replace Energy Grid by classification
     130            {
     131                Int_t irows=0;
     132                for (Int_t j=0; j<nrows; j++)
     133                {
     134                    const Double_t energy = matrixtrain.GetM()(j,ncols-1);
     135                    const Bool_t   inside = energy>grid[ie] && energy<=grid[ie+1];
     136
     137                    mat(j, ncols-nobs) = inside ? 1 : 0;
     138
     139                    if (inside)
     140                        irows++;
     141                }
     142                if (irows==0)
     143                    *fLog << warn << "WARNING - Skipping";
     144                else
     145                    *fLog << inf << "Training RF for";
     146
     147                *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;
     148
     149                if (irows==0)
     150                    continue;
     151            }
     152            break;
     153
     154        case 1: // Use Energy as classifier
     155        case 2:
     156            for (Int_t j=0; j<nrows; j++)
     157                mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
     158            break;
    147159        }
    148160
    149         const Bool_t invalid = irow1==0 || irow0==0;
    150 
    151         if (invalid)
    152             *fLog << warn << "WARNING - Skipping";
    153         else
    154             *fLog << inf << "Training RF for";
    155 
    156         *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irow0 << " " << irow1 << endl;
    157 
    158         if (invalid)
    159             continue;
    160 
    161         if (fDebug)
    162             gLog.SetNullOutput(kTRUE);
    163 
    164         // last column must be removed (true energy col.)
    165         mat1.ResizeTo(irow1, ncols-nobs);
    166         mat0.ResizeTo(irow0, ncols-nobs);
    167 
    168         // create MHMatrix as input for RF
    169         MHMatrix matrix1(mat1, "MatrixHadrons");
    170         MHMatrix matrix0(mat0, "MatrixGammas");
    171 
    172         // training of RF
     161        MHMatrix matrix(mat, &rules, "MatrixTrain");
     162
     163        MParList plist;
    173164        MTaskList tlist;
    174         MParList plist;
    175165        plist.AddToList(&tlist);
    176         plist.AddToList(&matrix0);
    177         plist.AddToList(&matrix1);
     166        plist.AddToList(&matrix);
     167
     168        MRanForest rf;
     169        rf.SetNumTrees(fNumTrees);
     170        rf.SetNumTry(fNumTry);
     171        rf.SetNdSize(fNdSize);
     172        rf.SetClassify(ver<2 ? 1 : 0);
     173        if (ver==1)
     174            rf.SetGrid(grid);
     175
     176        plist.AddToList(&rf);
    178177
    179178        MRanForestGrow rfgrow;
    180         rfgrow.SetNumTrees(fNumTrees); // number of trees
    181         rfgrow.SetNumTry(fNumTry);     // number of trials in random split selection
    182         rfgrow.SetNdSize(fNdSize);     // limit for nodesize
    183 
    184179        tlist.AddToList(&rfgrow);
    185    
     180
    186181        MEvtLoop evtloop;
    187         evtloop.SetDisplay(fDisplay);
    188182        evtloop.SetParList(&plist);
    189183
     
    194188            gLog.SetNullOutput(kFALSE);
    195189
    196         // Calculate bin center
    197         const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
    198 
    199         // save whole forest
    200         MRanForest *forest=(MRanForest*)plist.FindObject("MRanForest");
    201         forest->SetUserVal(E);
    202         forest->Write(Form("%.10f", E));
     190        if (ver==0)
     191        {
     192            // Calculate bin center
     193            const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;
     194
     195            // save whole forest
     196            rf.SetUserVal(E);
     197            rf.SetName(Form("%.10f", E));
     198        }
     199
     200        rf.Write();
    203201    }
    204202
     
    224222    while ((o=Next()))
    225223    {
    226         MRanForest *forest;
     224        MRanForest *forest=0;
    227225        fileRF.GetObject(o->GetName(), forest);
    228226        if (!forest)
     
    234232    }
    235233
     234    // Maybe fEForests[0].fRules yould be used instead?
     235
    236236    if (fData->Read("rules")<=0)
    237237    {
     
    249249        return kFALSE;
    250250
     251    cout << "MDataArray" << endl;
     252
    251253    fData = (MDataArray*)plist->FindCreateObj("MDataArray");
    252254    if (!fData)
    253255        return kFALSE;
    254256
     257    cout << "ReadForests" << endl;
     258
    255259    if (!ReadForests(*plist))
    256260    {
     
    275279}
    276280
    277 // --------------------------------------------------------------------------
    278 //
    279 //
    280281#include <TGraph.h>
    281282#include <TF1.h>
    282283Int_t MRFEnergyEst::Process()
    283284{
    284     static TF1 f1("f1", "gaus");
    285 
    286285    TVector event;
    287286    if (fTestMatrix)
     
    290289        *fData >> event;
    291290
     291    // --------------- Single Tree RF -------------------
     292    if (fEForests.GetEntries()==1)
     293    {
     294        const MRanForest *rf = (MRanForest*)fEForests[0];
     295        fEnergyEst->SetVal(rf->CalcHadroness(event));
     296        fEnergyEst->SetReadyToSave();
     297
     298        return kTRUE;
     299    }
     300
     301    // --------------- Multi Tree RF -------------------
     302    static TF1 f1("f1", "gaus");
     303
    292304    Double_t sume = 0;
    293305    Double_t sumh = 0;
     
    308320        const Double_t h = rf->CalcHadroness(event);
    309321        const Double_t e = rf->GetUserVal();
     322
    310323        g.SetPoint(g.GetN(), e, h);
     324
    311325        sume += e*h;
    312326        sumh += h;
     327
    313328        if (h>maxh)
    314329        {
     
    337352        fEnergyEst->SetVal(pow(10, f1.GetParameter(1)));
    338353        break;
    339 
    340     }
     354    }
     355
    341356    fEnergyEst->SetReadyToSave();
    342357
  • trunk/MagicSoft/Mars/mranforest/MRFEnergyEst.h

    r7178 r7396  
    99#include <TObjArray.h>
    1010#endif
     11
    1112#ifndef ROOT_TArrayD
    1213#include <TArrayD.h>
    1314#endif
    1415
    15 class MHMatrix;
    1616class MDataArray;
    1717class MParameterD;
     18class MHMatrix;
    1819
    1920class MRFEnergyEst : public MTask
     
    2627        kFit
    2728    };
     29
    2830private:
     31    static const TString gsDefName;   //! Default Name
     32    static const TString gsDefTitle;  //! Default Title
     33
    2934    Bool_t       fDebug;      // Debugging of eventloop while training on/off
    3035
     
    3944    Int_t        fNdSize;     //! Training parameters
    4045
    41     MHMatrix    *fTestMatrix; //! Test Matrix
     46    MHMatrix    *fTestMatrix; //! Test Matrix used in Process (together with MMatrixLoop)
    4247
    4348    EstimationMode_t fEstimationMode;
     49
     50private:
     51    // MTask
     52    Int_t PreProcess(MParList *plist);
     53    Int_t Process();
    4454
    4555    // MRFEnergyEst
    4656    Int_t ReadForests(MParList &plist);
    4757
    48     // MTask
    49     Int_t PreProcess(MParList *plist);
    50     Int_t Process();
    51 
    5258    // MParContainer
    5359    Int_t ReadEnv(const TEnv &env, TString prefix, Bool_t print);
     60
     61    // Train Interface
     62    Int_t Train(const MHMatrix &n, const TArrayD &grid, Int_t ver=2);
    5463
    5564public:
     
    5867
    5968    // Setter for estimation
    60     void  SetFileName(TString str)     { fFileName = str; }
    61     void  SetEstimationMode(EstimationMode_t op) { fEstimationMode = op; }
     69    void SetFileName(TString filename)          { fFileName = filename; }
     70    void SetEstimationMode(EstimationMode_t op) { fEstimationMode = op; }
    6271
    6372    // Setter for training
    64     void  SetNumTrees(Int_t n=-1)      { fNumTrees = n; }
    65     void  SetNdSize(Int_t n=-1)        { fNdSize   = n; }
    66     void  SetNumTry(Int_t n=-1)        { fNumTry   = n; }
    67     void  SetDebug(Bool_t b=kTRUE)     { fDebug = b; }
     73    void SetNumTrees(UShort_t n=100) { fNumTrees = n; }
     74    void SetNdSize(UShort_t n=5)     { fNdSize   = n; }
     75    void SetNumTry(UShort_t n=0)     { fNumTry   = n; }
     76    void SetDebug(Bool_t b=kTRUE)    { fDebug = b; }
    6877
    6978    // Train Interface
    70     Int_t Train(const MHMatrix &n, const TArrayD &grid);
     79    Int_t TrainMultiRF(const MHMatrix &n, const TArrayD &grid)
     80    {
     81        return Train(n, grid, 0);
     82    }
     83    Int_t TrainSingleRF(const MHMatrix &n, const TArrayD &grid=TArrayD())
     84    {
     85        return Train(n, grid, grid.GetSize()==0 ? 2 : 1);
     86    }
    7187
    7288    // Test Interface
  • trunk/MagicSoft/Mars/mranforest/MRanForest.cc

    r7170 r7396  
    1616!
    1717!
    18 !   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
     18!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
    1919!
    20 !   Copyright: MAGIC Software Development, 2000-2003
     20!   Copyright: MAGIC Software Development, 2000-2005
    2121!
    2222!
     
    3939// split selection (which is subject to MRanTree::GrowTree())
    4040//
    41 // Version 2:
    42 //  + fUserVal
    43 //
    4441/////////////////////////////////////////////////////////////////////////////
    4542#include "MRanForest.h"
    4643
    47 #include <TMatrix.h>
    48 #include <TRandom3.h>
     44#include <TVector.h>
     45#include <TRandom.h>
    4946
    5047#include "MHMatrix.h"
    5148#include "MRanTree.h"
     49#include "MData.h"
     50#include "MDataArray.h"
     51#include "MParList.h"
    5252
    5353#include "MLog.h"
     
    6262// Default constructor.
    6363//
    64 MRanForest::MRanForest(const char *name, const char *title) : fNumTrees(100), fRanTree(NULL),fUsePriors(kFALSE), fUserVal(-1)
     64MRanForest::MRanForest(const char *name, const char *title) : fClassify(1), fNumTrees(100), fNumTry(0), fNdSize(1), fRanTree(NULL), fUserVal(-1)
    6565{
    6666    fName  = name  ? name  : "MRanForest";
     
    7171}
    7272
     73MRanForest::MRanForest(const MRanForest &rf)
     74{
     75    // Copy constructor
     76    fName  = rf.fName;
     77    fTitle = rf.fTitle;
     78
     79    fClassify = rf.fClassify;
     80    fNumTrees = rf.fNumTrees;
     81    fNumTry   = rf.fNumTry;
     82    fNdSize   = rf.fNdSize;
     83    fTreeNo   = rf.fTreeNo;
     84    fRanTree  = NULL;
     85
     86    fRules=new MDataArray();
     87    fRules->Reset();
     88
     89    MDataArray *newrules=rf.fRules;
     90
     91    for(Int_t i=0;i<newrules->GetNumEntries();i++)
     92    {
     93        MData &data=(*newrules)[i];
     94        fRules->AddEntry(data.GetRule());
     95    }
     96
     97    // trees
     98    fForest=new TObjArray();
     99    fForest->SetOwner(kTRUE);
     100
     101    TObjArray *newforest=rf.fForest;
     102    for(Int_t i=0;i<newforest->GetEntries();i++)
     103    {
     104        MRanTree *rantree=(MRanTree*)newforest->At(i);
     105
     106        MRanTree *newtree=new MRanTree(*rantree);
     107        fForest->Add(newtree);
     108    }
     109
     110    fHadTrue  = rf.fHadTrue;
     111    fHadEst   = rf.fHadEst;
     112    fDataSort = rf.fDataSort;
     113    fDataRang = rf.fDataRang;
     114    fClassPop = rf.fClassPop;
     115    fWeight   = rf.fWeight;
     116    fTreeHad  = rf.fTreeHad;
     117
     118    fNTimesOutBag = rf.fNTimesOutBag;
     119
     120}
     121
    73122// --------------------------------------------------------------------------
    74 //
    75123// Destructor.
    76 //
    77124MRanForest::~MRanForest()
    78125{
    79126    delete fForest;
     127}
     128
     129MRanTree *MRanForest::GetTree(Int_t i)
     130{
     131    return (MRanTree*)(fForest->At(i));
    80132}
    81133
     
    88140}
    89141
    90 void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam)
    91 {
    92     const Float_t sum=prior_gam+prior_had;
    93 
    94     prior_gam/=sum;
    95     prior_had/=sum;
    96 
    97     fPrior[0]=prior_had;
    98     fPrior[1]=prior_gam;
    99 
    100     fUsePriors=kTRUE;
     142void MRanForest::SetNumTry(Int_t n)
     143{
     144    fNumTry=TMath::Max(n,0);
     145}
     146
     147void MRanForest::SetNdSize(Int_t n)
     148{
     149    fNdSize=TMath::Max(n,1);
     150}
     151
     152void MRanForest::SetWeights(const TArrayF &weights)
     153{
     154    const int n=weights.GetSize();
     155    fWeight.Set(n);
     156    fWeight=weights;
     157
     158    return;
     159}
     160
     161void MRanForest::SetGrid(const TArrayD &grid)
     162{
     163    const int n=grid.GetSize();
     164
     165    for(int i=0;i<n-1;i++)
     166        if(grid[i]>=grid[i+1])
     167        {
     168            *fLog<<inf<<"Grid points must be in increasing order! Ignoring grid."<<endl;
     169            return;
     170        }
     171
     172    fGrid=grid;
     173
     174    //*fLog<<inf<<"Following "<<n<<" grid points are used:"<<endl;
     175    //for(int i=0;i<n;i++)
     176    //    *fLog<<inf<<" "<<i<<") "<<fGrid[i]<<endl;
     177
     178    return;
    101179}
    102180
    103181Int_t MRanForest::GetNumDim() const
    104182{
    105     return fGammas ? fGammas->GetM().GetNcols() : 0;
    106 }
    107 
     183    return fMatrix ? fMatrix->GetNcols() : 0;
     184}
     185
     186Int_t MRanForest::GetNumData() const
     187{
     188    return fMatrix ? fMatrix->GetNrows() : 0;
     189}
     190
     191Int_t MRanForest::GetNclass() const
     192{
     193    int maxidx = TMath::LocMax(fClass.GetSize(),fClass.GetArray());
     194
     195    return int(fClass[maxidx])+1;
     196}
     197
     198void MRanForest::PrepareClasses()
     199{
     200    const int numdata=fHadTrue.GetSize();
     201
     202    if(fGrid.GetSize()>0)
     203    {
     204        // classes given by grid
     205        const int ngrid=fGrid.GetSize();
     206
     207        for(int j=0;j<numdata;j++)
     208        {
     209            // Array is supposed  to be sorted prior to this call.
     210            // If match is found, function returns position of element.
     211            // If no match found, function gives nearest element smaller
     212            // than value.
     213            int k=TMath::BinarySearch(ngrid, fGrid.GetArray(), fHadTrue[j]);
     214
     215            fClass[j]   = k;
     216        }
     217
     218        int minidx = TMath::LocMin(fClass.GetSize(),fClass.GetArray());
     219        int min = fClass[minidx];
     220        if(min!=0) for(int j=0;j<numdata;j++)fClass[j]-=min;
     221
     222    }else{
     223        // classes directly given
     224        for (Int_t j=0;j<numdata;j++)
     225            fClass[j] = int(fHadTrue[j]+0.5);
     226    }
     227
     228    return;
     229}
     230
     231/*
     232Bool_t MRanForest::PreProcess(MParList *plist)
     233{
     234    if (!fRules)
     235    {
     236        *fLog << err << dbginf << "MDataArray with rules not initialized... aborting." << endl;
     237        return kFALSE;
     238    }
     239
     240    if (!fRules->PreProcess(plist))
     241    {
     242        *fLog << err << dbginf << "PreProcessing of MDataArray failed... aborting." << endl;
     243        return kFALSE;
     244    }
     245
     246    return kTRUE;
     247}
     248*/
     249
     250Double_t MRanForest::CalcHadroness()
     251{
     252    TVector event;
     253    *fRules >> event;
     254
     255    return CalcHadroness(event);
     256}
    108257
    109258Double_t MRanForest::CalcHadroness(const TVector &event)
     
    117266    while ((tree=(MRanTree*)Next()))
    118267    {
    119         fTreeHad[ntree]=tree->TreeHad(event);
    120         hadroness+=fTreeHad[ntree];
     268        hadroness+=(fTreeHad[ntree]=tree->TreeHad(event));
    121269        ntree++;
    122270    }
     
    126274Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
    127275{
    128     if (rantree)
    129         fRanTree=rantree;
    130     if (!fRanTree)
     276    fRanTree = rantree ? rantree:fRanTree;
     277
     278    if (!fRanTree) return kFALSE;
     279
     280    MRanTree *newtree=new MRanTree(*fRanTree);
     281    fForest->Add(newtree);
     282
     283    return kTRUE;
     284}
     285
     286Bool_t MRanForest::SetupGrow(MHMatrix *mat,MParList *plist)
     287{
     288    //-------------------------------------------------------------------
     289    // access matrix, copy last column (target) preliminarily
     290    // into fHadTrue
     291    TMatrix mat_tmp = mat->GetM();
     292    int dim         = mat_tmp.GetNcols();
     293    int numdata     = mat_tmp.GetNrows();
     294
     295    fMatrix=new TMatrix(mat_tmp);
     296
     297    fHadTrue.Set(numdata);
     298    fHadTrue.Reset(0);
     299
     300    for (Int_t j=0;j<numdata;j++)
     301        fHadTrue[j] = (*fMatrix)(j,dim-1);
     302
     303    // remove last col
     304    fMatrix->ResizeTo(numdata,dim-1);
     305    dim=fMatrix->GetNcols();
     306
     307    //-------------------------------------------------------------------
     308    // setup labels for classification/regression
     309    fClass.Set(numdata);
     310    fClass.Reset(0);
     311
     312    if(fClassify) PrepareClasses();
     313
     314    //-------------------------------------------------------------------
     315    // allocating and initializing arrays
     316    fHadEst.Set(numdata);       fHadEst.Reset(0);
     317    fNTimesOutBag.Set(numdata); fNTimesOutBag.Reset(0);
     318    fDataSort.Set(dim*numdata); fDataSort.Reset(0);
     319    fDataRang.Set(dim*numdata); fDataRang.Reset(0);
     320
     321    if(fWeight.GetSize()!=numdata)
     322    {
     323        fWeight.Set(numdata);
     324        fWeight.Reset(1.);
     325        *fLog << inf <<"Setting weights to 1 (no weighting)"<< endl;
     326    }
     327
     328    //-------------------------------------------------------------------
     329    // setup rules to be used for classification/regression
     330    MDataArray *allrules=(MDataArray*)mat->GetColumns();
     331    if(allrules==NULL)
     332    {
     333        *fLog << err <<"Rules of matrix == null, exiting"<< endl;
    131334        return kFALSE;
    132 
    133     fForest->Add((MRanTree*)fRanTree->Clone());
    134 
    135     return kTRUE;
    136 }
    137 
    138 Int_t MRanForest::GetNumData() const
    139 {
    140     return fHadrons && fGammas ? fHadrons->GetM().GetNrows()+fGammas->GetM().GetNrows() : 0;
    141 }
    142 
    143 Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam)
    144 {
    145     // pointer to training data
    146     fHadrons=mhad;
    147     fGammas=mgam;
    148 
    149     // determine data entries and dimension of Hillas-parameter space
    150     //fNumHad=fHadrons->GetM().GetNrows();
    151     //fNumGam=fGammas->GetM().GetNrows();
    152 
    153     const Int_t dim = GetNumDim();
    154 
    155     if (dim!=fGammas->GetM().GetNcols())
    156         return kFALSE;
    157 
    158     const Int_t numdata = GetNumData();
    159 
    160     // allocating and initializing arrays
    161     fHadTrue.Set(numdata);
    162     fHadTrue.Reset();
    163     fHadEst.Set(numdata);
    164 
    165     fPrior.Set(2);
    166     fClassPop.Set(2);
    167     fWeight.Set(numdata);
    168     fNTimesOutBag.Set(numdata);
    169     fNTimesOutBag.Reset();
    170 
    171     fDataSort.Set(dim*numdata);
    172     fDataRang.Set(dim*numdata);
    173 
    174     // calculating class populations (= no. of gammas and hadrons)
    175     fClassPop.Reset();
    176     for(Int_t n=0;n<numdata;n++)
    177         fClassPop[fHadTrue[n]]++;
    178 
    179     // setting weights taking into account priors
    180     if (!fUsePriors)
    181         fWeight.Reset(1.);
    182     else
    183     {
    184         for(Int_t j=0;j<2;j++)
    185             fPrior[j] *= (fClassPop[j]>=1) ? (Float_t)numdata/fClassPop[j]:0;
    186 
    187         for(Int_t n=0;n<numdata;n++)
    188             fWeight[n]=fPrior[fHadTrue[n]];
    189     }
    190 
    191     // create fDataSort
    192     CreateDataSort();
     335    }
     336
     337    fRules=new MDataArray(); fRules->Reset();
     338    TString target_rule;
     339
     340    for(Int_t i=0;i<dim+1;i++)
     341    {
     342        MData &data=(*allrules)[i];
     343        if(i<dim)
     344            fRules->AddEntry(data.GetRule());
     345        else
     346            target_rule=data.GetRule();
     347    }
     348
     349    *fLog << inf <<endl;
     350    *fLog << inf <<"Setting up RF for training on target:"<<endl<<" "<<target_rule.Data()<<endl;
     351    *fLog << inf <<"Following rules are used as input to RF:"<<endl;
     352
     353    for(Int_t i=0;i<dim;i++)
     354    {
     355        MData &data=(*fRules)[i];
     356        *fLog<<inf<<" "<<i<<") "<<data.GetRule()<<endl<<flush;
     357    }
     358
     359    *fLog << inf <<endl;
     360
     361    //-------------------------------------------------------------------
     362    // prepare (sort) data for fast optimization algorithm
     363    if(!CreateDataSort()) return kFALSE;
     364
     365    //-------------------------------------------------------------------
     366    // access and init tree container
     367    fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    193368
    194369    if(!fRanTree)
     
    197372        return kFALSE;
    198373    }
    199     fRanTree->SetRules(fGammas->GetColumns());
     374
     375    fRanTree->SetClassify(fClassify);
     376    fRanTree->SetNdSize(fNdSize);
     377
     378    if(fNumTry==0)
     379    {
     380        double ddim = double(dim);
     381
     382        fNumTry=int(sqrt(ddim)+0.5);
     383        *fLog<<inf<<endl;
     384        *fLog<<inf<<"Set no. of trials to the recommended value of round("<<sqrt(ddim)<<") = ";
     385        *fLog<<inf<<fNumTry<<endl;
     386
     387    }
     388    fRanTree->SetNumTry(fNumTry);
     389
     390    *fLog<<inf<<endl;
     391    *fLog<<inf<<"Following settings for the tree growing are used:"<<endl;
     392    *fLog<<inf<<" Number of Trees : "<<fNumTrees<<endl;
     393    *fLog<<inf<<" Number of Trials: "<<fNumTry<<endl;
     394    *fLog<<inf<<" Final Node size : "<<fNdSize<<endl;
     395
    200396    fTreeNo=0;
    201397
    202398    return kTRUE;
    203 }
    204 
    205 void MRanForest::InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag)
    206 {
    207     for (Int_t ievt=from;ievt<to;ievt++)
    208     {
    209         if (jinbag[ievt]>0)
    210             continue;
    211         fHadEst[ievt] += fRanTree->TreeHad(m, ievt-from);
    212         fNTimesOutBag[ievt]++;
    213     }
    214399}
    215400
     
    224409    fTreeNo++;
    225410
     411    //-------------------------------------------------------------------
    226412    // initialize running output
     413
     414    float minfloat=fHadTrue[TMath::LocMin(fHadTrue.GetSize(),fHadTrue.GetArray())];
     415    Bool_t calcResolution=(minfloat>0.001);
     416
    227417    if (fTreeNo==1)
    228418    {
    229419        *fLog << inf << endl;
    230         *fLog << underline; // << "1st col        2nd col" << endl;
    231         *fLog << "no. of tree    error in % (calulated using oob-data -> overestim. of error)" << endl;
     420        *fLog << underline;
     421
     422        if(calcResolution)
     423            *fLog << "no. of tree    no. of nodes    resolution in % (from oob-data -> overest. of error)" << endl;
     424        else
     425            *fLog << "no. of tree    no. of nodes    rms in % (from oob-data -> overest. of error)" << endl;
     426                     //        12345678901234567890123456789012345678901234567890
    232427    }
    233428
    234429    const Int_t numdata = GetNumData();
    235 
     430    const Int_t nclass  = GetNclass();
     431
     432    //-------------------------------------------------------------------
    236433    // bootstrap aggregating (bagging) -> sampling with replacement:
    237     //
    238     // The integer k is randomly (uniformly) chosen from the set
    239     // {0,1,...,fNumData-1}, which is the set of the index numbers of
    240     // all events in the training sample
    241     TArrayF classpopw(2);
     434
     435    TArrayF classpopw(nclass);
    242436    TArrayI jinbag(numdata); // Initialization includes filling with 0
    243437    TArrayF winbag(numdata); // Initialization includes filling with 0
    244438
     439    float square=0; float mean=0;
     440
    245441    for (Int_t n=0; n<numdata; n++)
    246442    {
     443        // The integer k is randomly (uniformly) chosen from the set
     444        // {0,1,...,numdata-1}, which is the set of the index numbers of
     445        // all events in the training sample
     446 
    247447        const Int_t k = Int_t(gRandom->Rndm()*numdata);
    248448
    249         classpopw[fHadTrue[k]]+=fWeight[k];
     449        if(fClassify)
     450            classpopw[fClass[k]]+=fWeight[k];
     451        else
     452            classpopw[0]+=fWeight[k];
     453
     454        mean  +=fHadTrue[k]*fWeight[k];
     455        square+=fHadTrue[k]*fHadTrue[k]*fWeight[k];
     456 
    250457        winbag[k]+=fWeight[k];
    251458        jinbag[k]=1;
    252     }
    253 
     459
     460    }
     461
     462    //-------------------------------------------------------------------
    254463    // modifying sorted-data array for in-bag data:
    255     //
     464
    256465    // In bagging procedure ca. 2/3 of all elements in the original
    257466    // training sample are used to build the in-bag data
     
    261470    ModifyDataSort(datsortinbag, ninbag, jinbag);
    262471
    263     const TMatrix &hadrons=fHadrons->GetM();
    264     const TMatrix &gammas =fGammas->GetM();
    265 
    266     // growing single tree
    267     fRanTree->GrowTree(hadrons,gammas,fHadTrue,datsortinbag,
    268                        fDataRang,classpopw,jinbag,winbag);
    269 
     472    fRanTree->GrowTree(fMatrix,fHadTrue,fClass,datsortinbag,fDataRang,classpopw,mean,square,
     473                       jinbag,winbag,nclass);
     474
     475    //-------------------------------------------------------------------
    270476    // error-estimates from out-of-bag data (oob data):
    271477    //
     
    277483    // determined from oob-data is underestimated, but can still be taken as upper limit.
    278484
    279     const Int_t numhad = hadrons.GetNrows();
    280     InitHadEst(0, numhad, hadrons, jinbag);
    281     InitHadEst(numhad, numdata, gammas, jinbag);
    282     /*
    283     for (Int_t ievt=0;ievt<numhad;ievt++)
    284     {
    285         if (jinbag[ievt]>0)
    286             continue;
    287         fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt);
     485    for (Int_t ievt=0;ievt<numdata;ievt++)
     486    {
     487        if (jinbag[ievt]>0) continue;
     488
     489        fHadEst[ievt] +=fRanTree->TreeHad((*fMatrix), ievt);
    288490        fNTimesOutBag[ievt]++;
    289     }
    290 
    291     for (Int_t ievt=numhad;ievt<numdata;ievt++)
    292     {
    293         if (jinbag[ievt]>0)
    294             continue;
    295         fHadEst[ievt] += fRanTree->TreeHad(gammas, ievt-numhad);
    296         fNTimesOutBag[ievt]++;
    297     }
    298     */
     491
     492    }
     493
    299494    Int_t n=0;
    300     Double_t ferr=0;
     495    double ferr=0;
     496
    301497    for (Int_t ievt=0;ievt<numdata;ievt++)
    302         if (fNTimesOutBag[ievt]!=0)
     498    {
     499        if(fNTimesOutBag[ievt]!=0)
    303500        {
    304             const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt];
     501            float val = fHadEst[ievt]/float(fNTimesOutBag[ievt])-fHadTrue[ievt];
     502            if(calcResolution) val/=fHadTrue[ievt];
     503
    305504            ferr += val*val;
    306505            n++;
    307506        }
    308 
     507    }
    309508    ferr = TMath::Sqrt(ferr/n);
    310509
     510    //-------------------------------------------------------------------
    311511    // give running output
    312     *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl;
     512    *fLog << inf << setw(5)  << fTreeNo;
     513    *fLog << inf << setw(20) << fRanTree->GetNumEndNodes();
     514    *fLog << inf << Form("%20.2f", ferr*100.);
     515    *fLog << inf << endl;
    313516
    314517    // adding tree to forest
     
    318521}
    319522
    320 void MRanForest::CreateDataSort()
    321 {
    322     // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData,
    323     // which does actually not exist) are sorted from lowest to highest.
     523Bool_t MRanForest::CreateDataSort()
     524{
     525    // fDataSort(m,n) is the event number in which fMatrix(m,n) occurs.
     526    // fDataRang(m,n) is the rang of fMatrix(m,n), i.e. if rang = r:
     527    //   fMatrix(m,n) is the r-th highest value of all fMatrix(m,.).
    324528    //
    325     //
    326     //                   fHadrons(0,0) ... fHadrons(0,nhad-1)   fGammas(0,0) ... fGammas(0,ngam-1)
    327     //                        .                 .                   .                .
    328     //  fData(m,n)   =        .                 .                   .                .
    329     //                        .                 .                   .                .
    330     //                   fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1)
    331     //
    332     //
    333     // Then fDataSort(m,n) is the event number in which fData(m,n) occurs.
    334     // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest
    335     // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging).
     529    // There may be more then 1 event with rang r (due to bagging).
     530
    336531    const Int_t numdata = GetNumData();
     532    const Int_t dim = GetNumDim();
    337533
    338534    TArrayF v(numdata);
    339535    TArrayI isort(numdata);
    340536
    341     const TMatrix &hadrons=fHadrons->GetM();
    342     const TMatrix &gammas=fGammas->GetM();
    343 
    344     const Int_t numgam = gammas.GetNrows();
    345     const Int_t numhad = hadrons.GetNrows();
    346 
    347     for (Int_t j=0;j<numhad;j++)
    348         fHadTrue[j]=1;
    349 
    350     for (Int_t j=0;j<numgam;j++)
    351         fHadTrue[j+numhad]=0;
    352 
    353     const Int_t dim = GetNumDim();
     537
    354538    for (Int_t mvar=0;mvar<dim;mvar++)
    355539    {
    356         for(Int_t n=0;n<numhad;n++)
     540
     541        for(Int_t n=0;n<numdata;n++)
    357542        {
    358             v[n]=hadrons(n,mvar);
     543            v[n]=(*fMatrix)(n,mvar);
    359544            isort[n]=n;
    360         }
    361 
    362         for(Int_t n=0;n<numgam;n++)
    363         {
    364             v[n+numhad]=gammas(n,mvar);
    365             isort[n+numhad]=n;
     545
     546            if(TMath::IsNaN(v[n]))
     547            {
     548                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
     549                *fLog << err <<" has the value NaN."<<endl;
     550                return kFALSE;
     551            }
    366552        }
    367553
     
    371557        // of that v[n], which is the n-th from the lowest (assume the original
    372558        // event numbers are 0,1,...).
     559
     560        // control sorting
     561        for(int n=1;n<numdata;n++)
     562            if(v[isort[n-1]]>v[isort[n]])
     563            {
     564                *fLog << err <<"Event no. "<<n<<", matrix column no. "<<mvar;
     565                *fLog << err <<" not at correct sorting position."<<endl;
     566                return kFALSE;
     567            }
    373568
    374569        for(Int_t n=0;n<numdata-1;n++)
     
    388583        fDataSort[(mvar+1)*numdata-1]=isort[numdata-1];
    389584    }
     585    return kTRUE;
    390586}
    391587
  • trunk/MagicSoft/Mars/mranforest/MRanForest.h

    r7170 r7396  
    44#ifndef MARS_MParContainer
    55#include "MParContainer.h"
    6 #endif
    7 
    8 #ifndef MARS_MRanTree
    9 #include "MRanTree.h"
    10 #endif
    11 
    12 #ifndef MARS_MDataArray
    13 #include "MDataArray.h"
    146#endif
    157
     
    2618#endif
    2719
    28 #ifndef ROOT_TObjArray
    29 #include <TObjArray.h>
    30 #endif
     20class TMatrix;
     21class TVector;
     22class TObjArray;
    3123
    32 #ifndef ROOT_TRandom
    33 #include <TRandom.h>
    34 #endif
    35 
     24class MRanTree;
     25class MDataArray;
    3626class MHMatrix;
    37 class MRanTree;
    38 class TVector;
    39 class TMatrix;
     27class MParList;
    4028
    4129class MRanForest : public MParContainer
    4230{
    4331private:
    44     Int_t fNumTrees;
    45     Int_t fTreeNo;      //!
     32    Int_t fClassify;
    4633
    47     MRanTree *fRanTree; //!
    48     TObjArray *fForest;
     34    Int_t fNumTrees;       // Number of trees
     35    Int_t fNumTry;         // Number of tries
     36    Int_t fNdSize;         // Size of node
     37
     38    Int_t fTreeNo;         //! Number of tree
     39
     40    MRanTree   *fRanTree;  //! Pointer to some tree
     41    MDataArray *fRules;    //! Pointer to corresponding rules
     42    TObjArray  *fForest;   //  Array containing forest
    4943
    5044    // training data
    51     MHMatrix *fGammas;  //!
    52     MHMatrix *fHadrons; //!
     45    TMatrix *fMatrix;      //!
    5346
    5447    // true  and estimated hadronness
    55     TArrayI fHadTrue;   //!
    56     TArrayF fHadEst;    //!
     48    TArrayI fClass;        //!
     49    TArrayD fGrid;         //!
     50    TArrayF fHadTrue;      //!
     51    TArrayF fHadEst;       //!
    5752
    5853    // data sorted according to parameters
    59     TArrayI fDataSort;  //!
    60     TArrayI fDataRang;  //!
    61     TArrayI fClassPop;  //!
     54    TArrayI fDataSort;     //!
     55    TArrayI fDataRang;     //!
     56    TArrayI fClassPop;     //!
    6257
    6358    // weights
    64     Bool_t  fUsePriors; //!
    65     TArrayF fPrior;     //!
    66     TArrayF fWeight;    //!
    67     TArrayI fNTimesOutBag;//!
     59    TArrayF fWeight;       //!
     60    TArrayI fNTimesOutBag; //!
    6861
    6962    // estimates for classification error of growing forest
    70     TArrayD fTreeHad;   //
     63    TArrayD fTreeHad;      // Hadronness values
    7164
    72     Double_t fUserVal;
    73 
    74     void InitHadEst(Int_t from, Int_t to, const TMatrix &m, TArrayI &jinbag);
     65    Double_t fUserVal;     // A user value describing this tree (eg E-mc)
    7566
    7667protected:
    7768    // create and modify (->due to bagging) fDataSort
    78     void CreateDataSort();
     69    Bool_t CreateDataSort();
    7970    void ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag);
    8071
    8172public:
    8273    MRanForest(const char *name=NULL, const char *title=NULL);
     74    MRanForest(const MRanForest &rf);
     75
    8376    ~MRanForest();
    8477
    85     // initialize forest
    86     void SetPriors(Float_t prior_had, Float_t prior_gam);
     78    void SetGrid(const TArrayD &grid);
     79    void SetWeights(const TArrayF &weights);
    8780    void SetNumTrees(Int_t n);
    8881
    89     // tree growing
    90     //void   SetupForest();
    91     Bool_t SetupGrow(MHMatrix *mhad,MHMatrix *mgam);
     82    void SetNumTry(Int_t n);
     83    void SetNdSize(Int_t n);
     84
     85    void SetClassify(Int_t n){ fClassify=n; }
     86    void PrepareClasses();
     87
     88        // tree growing
     89    Bool_t SetupGrow(MHMatrix *mat,MParList *plist);
    9290    Bool_t GrowForest();
    93     void SetCurTree(MRanTree *rantree) { fRanTree=rantree; }
     91    void   SetCurTree(MRanTree *rantree) { fRanTree=rantree; }
    9492    Bool_t AddTree(MRanTree *rantree);
    95     void SetUserVal(Double_t d) { fUserVal = d; }
     93    void   SetUserVal(Double_t d) { fUserVal = d; }
    9694
    9795    // getter methods
    9896    TObjArray  *GetForest()      { return fForest; }
    9997    MRanTree   *GetCurTree()     { return fRanTree; }
    100     MRanTree   *GetTree(Int_t i) { return (MRanTree*)(fForest->At(i)); }
    101     MDataArray *GetRules() { return ((MRanTree*)(fForest->At(0)))->GetRules(); }
     98    MRanTree   *GetTree(Int_t i);
     99    MDataArray *GetRules()       { return fRules; }
     100
    102101
    103102    Int_t      GetNumTrees() const { return fNumTrees; }
    104103    Int_t      GetNumData()  const;
    105104    Int_t      GetNumDim()   const;
     105    Int_t      GetNclass()   const;
    106106    Double_t   GetTreeHad(Int_t i) const { return fTreeHad.At(i); }
    107107    Double_t   GetUserVal() const { return fUserVal; }
     
    109109    // use forest to calculate hadronness of event
    110110    Double_t CalcHadroness(const TVector &event);
     111    Double_t CalcHadroness();
    111112
    112113    Bool_t AsciiWrite(ostream &out) const;
    113114
    114     ClassDef(MRanForest, 2) // Storage container for tree structure
     115    ClassDef(MRanForest, 1) // Storage container for tree structure
    115116};
    116117
  • trunk/MagicSoft/Mars/mranforest/MRanForestGrow.cc

    r7130 r7396  
    1616!
    1717!
    18 !   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
     18!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
    1919!
    20 !   Copyright: MAGIC Software Development, 2000-2003
     20!   Copyright: MAGIC Software Development, 2000-2005
    2121!
    2222!
     
    2424
    2525/////////////////////////////////////////////////////////////////////////////
    26 //                                                                         //
    27 //  MRanForestGrow                                                         //
    28 //                                                                         //
    29 //  Grows a random forest.                                                 //
    30 //                                                                         //
     26//
     27//  MRanForestGrow
     28//
     29//  Grows a random forest.
     30//
    3131/////////////////////////////////////////////////////////////////////////////
    3232#include "MRanForestGrow.h"
     
    3838
    3939#include "MParList.h"
    40 
    41 #include "MRanTree.h"
    4240#include "MRanForest.h"
    4341
     
    4644using namespace std;
    4745
    48 static const TString gsDefName  = "MRead";
    49 static const TString gsDefTitle = "Tree Classification Loop 1/2";
     46const TString MRanForestGrow::gsDefName  = "MRead";
     47const TString MRanForestGrow::gsDefTitle = "Task to train a random forst";
    5048
    51 // --------------------------------------------------------------------------
    52 //
    53 // Setup histograms and the number of distances which are used for
    54 // avaraging in CalcDist
    55 //
    5649MRanForestGrow::MRanForestGrow(const char *name, const char *title)
    5750{
    58     //
    5951    //   set the name and title of this object
    60     //
     52
    6153    fName  = name  ? name  : gsDefName.Data();
    6254    fTitle = title ? title : gsDefTitle.Data();
    6355
    64     SetNumTrees();
    65     SetNumTry();
    66     SetNdSize();
     56    //     SetNumTrees();
     57    //     SetNumTry();
     58    //     SetNdSize();
    6759}
    6860
    69 // --------------------------------------------------------------------------
    70 //
    71 // Needs:
    72 //  - MatrixGammas  [MHMatrix]
    73 //  - MatrixHadrons {MHMatrix]
    74 //  - MHadroness
    75 //  - all data containers used to build the matrixes
    76 //
    77 // The matrix object can be filles using MFillH. And must be of the same
    78 // number of columns (with the same meaning).
    79 //
    8061Int_t MRanForestGrow::PreProcess(MParList *plist)
    8162{
    82     fMGammas = (MHMatrix*)plist->FindObject("MatrixGammas", "MHMatrix");
    83     if (!fMGammas)
     63    fMatrix = (MHMatrix*)plist->FindObject("MatrixTrain", "MHMatrix");
     64    if (!fMatrix)
    8465    {
    85         *fLog << err << dbginf << "MatrixGammas [MHMatrix] not found... aborting." << endl;
    86         return kFALSE;
    87     }
    88 
    89     fMHadrons = (MHMatrix*)plist->FindObject("MatrixHadrons", "MHMatrix");
    90     if (!fMHadrons)
    91     {
    92         *fLog << err << dbginf << "MatrixHadrons [MHMatrix] not found... aborting." << endl;
    93         return kFALSE;
    94     }
    95 
    96     if (fMGammas->GetM().GetNcols() != fMHadrons->GetM().GetNcols())
    97     {
    98         *fLog << err << dbginf << "Error matrices have different numbers of columns... aborting." << endl;
    99         return kFALSE;
    100     }
    101 
    102     fRanTree = (MRanTree*)plist->FindCreateObj("MRanTree");
    103     if (!fRanTree)
    104     {
    105         *fLog << err << dbginf << "MRanTree not found... aborting." << endl;
     66        *fLog << err << dbginf << "MatrixTrain [MHMatrix] not found... aborting." << endl;
    10667        return kFALSE;
    10768    }
     
    11475    }
    11576
    116     fRanTree->SetNumTry(fNumTry);
    117     fRanTree->SetNdSize(fNdSize);
    118     fRanForest->SetCurTree(fRanTree);
    119     fRanForest->SetNumTrees(fNumTrees);
     77    //     fRanForest->SetNumTry(fNumTry);
     78    //     fRanForest->SetNdSize(fNdSize);
     79    //     fRanForest->SetNumTrees(fNumTrees);
    12080
    121     return fRanForest->SetupGrow(fMHadrons,fMGammas);
     81    return fRanForest->SetupGrow(fMatrix,plist);
    12282}
    12383
    124 // --------------------------------------------------------------------------
    125 //
    126 //
    12784Int_t MRanForestGrow::Process()
    12885{
    129     const Bool_t not_last=fRanForest->GrowForest();
    130 
    131     fRanTree->SetReadyToSave();
    132 
    133     return not_last;
     86    return fRanForest->GrowForest();
    13487}
    13588
    13689Int_t MRanForestGrow::PostProcess()
    13790{
    138     fRanTree->SetReadyToSave();
    13991    fRanForest->SetReadyToSave();
    14092
  • trunk/MagicSoft/Mars/mranforest/MRanForestGrow.h

    r7130 r7396  
    99class MParList;
    1010class MRanForest;
    11 class MRanTree;
    1211
    1312class MRanForestGrow : public MRead
    1413{
    1514private:
    16     MRanTree   *fRanTree;
     15    static const TString gsDefName;
     16    static const TString gsDefTitle;
     17
     18    //     Int_t fNumTrees;
     19    //     Int_t fNumTry;
     20    //     Int_t fNdSize;
     21
    1722    MRanForest *fRanForest;
    18     MHMatrix   *fMGammas;   //! Gammas describing matrix
    19     MHMatrix   *fMHadrons;  //! Hadrons (non gammas) describing matrix
    20 
    21     Int_t fNumTrees;
    22     Int_t fNumTry;
    23     Int_t fNdSize;
     23    MHMatrix   *fMatrix;   //! matrix with events
    2424
    2525    Int_t PreProcess(MParList *pList);
     
    3434    MRanForestGrow(const char *name=NULL, const char *title=NULL);
    3535
    36     void SetNumTrees(Int_t n=-1) { fNumTrees=n>0?n:100; }
    37     void SetNumTry(Int_t   n=-1) { fNumTry  =n>0?n:  3; }
    38     void SetNdSize(Int_t   n=-1) { fNdSize  =n>0?n:  1; }
     36    //     void SetNumTrees(Int_t n=-1) { fNumTrees=n>0?n:100; }
     37    //     void SetNumTry(Int_t   n=-1) { fNumTry  =n>0?n:  3; }
     38    //     void SetNdSize(Int_t   n=-1) { fNdSize  =n>0?n:  1; }
    3939
    4040    ClassDef(MRanForestGrow, 0) // Task to grow a random forest
  • trunk/MagicSoft/Mars/mranforest/MRanTree.cc

    r7142 r7396  
    1616!
    1717!
    18 !   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
     18!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
    1919!
    20 !   Copyright: MAGIC Software Development, 2000-2003
     20!   Copyright: MAGIC Software Development, 2000-2005
    2121!
    2222!
     
    3838#include <TRandom.h>
    3939
    40 #include "MDataArray.h"
    41 
    4240#include "MLog.h"
    4341#include "MLogManip.h"
     
    4745using namespace std;
    4846
     47
    4948// --------------------------------------------------------------------------
    50 //
    5149// Default constructor.
    5250//
    53 MRanTree::MRanTree(const char *name, const char *title):fNdSize(0), fNumTry(3), fData(NULL)
     51MRanTree::MRanTree(const char *name, const char *title):fClassify(1),fNdSize(0), fNumTry(3)
    5452{
    5553
    5654    fName  = name  ? name  : "MRanTree";
    5755    fTitle = title ? title : "Storage container for structure of a single tree";
     56}
     57
     58// --------------------------------------------------------------------------
     59// Copy constructor
     60//
     61MRanTree::MRanTree(const MRanTree &tree)
     62{
     63    fName  = tree.fName;
     64    fTitle = tree.fTitle;
     65
     66    fClassify = tree.fClassify;
     67    fNdSize   = tree.fNdSize;
     68    fNumTry   = tree.fNumTry;
     69
     70    fNumNodes    = tree.fNumNodes;
     71    fNumEndNodes = tree.fNumEndNodes;
     72
     73    fBestVar   = tree.fBestVar;
     74    fTreeMap1  = tree.fTreeMap1;
     75    fTreeMap2  = tree.fTreeMap2;
     76    fBestSplit = tree.fBestSplit;
     77    fGiniDec   = tree.fGiniDec;
    5878}
    5979
     
    7595}
    7696
    77 void MRanTree::GrowTree(const TMatrix &mhad, const TMatrix &mgam,
    78                         const TArrayI &hadtrue, TArrayI &datasort,
    79                         const TArrayI &datarang, TArrayF &tclasspop, TArrayI &jinbag,
    80                         const TArrayF &winbag)
     97void MRanTree::GrowTree(TMatrix *mat, const TArrayF &hadtrue, const TArrayI &idclass,
     98                        TArrayI &datasort, const TArrayI &datarang, TArrayF &tclasspop,
     99                        float &mean, float &square, TArrayI &jinbag, const TArrayF &winbag,
     100                        const int nclass)
    81101{
    82102    // arrays have to be initialized with generous size, so number of total nodes (nrnodes)
    83103    // is estimated for worst case
    84     const Int_t numdim =mhad.GetNcols();
     104    const Int_t numdim =mat->GetNcols();
    85105    const Int_t numdata=winbag.GetSize();
    86106    const Int_t nrnodes=2*numdata+1;
     
    88108    // number of events in bootstrap sample
    89109    Int_t ninbag=0;
    90     for (Int_t n=0;n<numdata;n++)
    91         if(jinbag[n]==1) ninbag++;
    92 
    93     TArrayI bestsplit(nrnodes);
    94     TArrayI bestsplitnext(nrnodes);
    95 
    96     fBestVar.Set(nrnodes);
    97     fTreeMap1.Set(nrnodes);
    98     fTreeMap2.Set(nrnodes);
    99     fBestSplit.Set(nrnodes);
    100 
    101     fTreeMap1.Reset();
    102     fTreeMap2.Reset();
    103     fBestSplit.Reset();
    104 
    105     fGiniDec.Set(numdim);
    106     fGiniDec.Reset();
     110    for (Int_t n=0;n<numdata;n++) if(jinbag[n]==1) ninbag++;
     111
     112    TArrayI bestsplit(nrnodes);      bestsplit.Reset(0);
     113    TArrayI bestsplitnext(nrnodes);  bestsplitnext.Reset(0);
     114
     115    fBestVar.Set(nrnodes);    fBestVar.Reset(0);
     116    fTreeMap1.Set(nrnodes);   fTreeMap1.Reset(0);
     117    fTreeMap2.Set(nrnodes);   fTreeMap2.Reset(0);
     118    fBestSplit.Set(nrnodes);  fBestSplit.Reset(0);
     119    fGiniDec.Set(numdim);     fGiniDec.Reset(0);
     120
     121
     122    if(fClassify)
     123        FindBestSplit=&MRanTree::FindBestSplitGini;
     124    else
     125        FindBestSplit=&MRanTree::FindBestSplitSigma;
    107126
    108127    // tree growing
    109     BuildTree(datasort,datarang,hadtrue,bestsplit,
    110               bestsplitnext,tclasspop,winbag,ninbag);
     128    BuildTree(datasort,datarang,hadtrue,idclass,bestsplit, bestsplitnext,
     129              tclasspop,mean,square,winbag,ninbag,nclass);
    111130
    112131    // post processing, determine cut (or split) values fBestSplit
    113     Int_t nhad=mhad.GetNrows();
    114 
    115132    for(Int_t k=0; k<nrnodes; k++)
    116133    {
     
    122139        const Int_t &msp =fBestVar[k];
    123140
    124         fBestSplit[k]  = bsp<nhad  ? mhad(bsp, msp):mgam(bsp-nhad, msp);
    125         fBestSplit[k] += bspn<nhad ? mhad(bspn,msp):mgam(bspn-nhad,msp);
    126         fBestSplit[k] /= 2;
     141        fBestSplit[k]  = (*mat)(bsp, msp);
     142        fBestSplit[k] += (*mat)(bspn,msp);
     143        fBestSplit[k] /= 2.;
    127144    }
    128145
     
    134151}
    135152
    136 Int_t MRanTree::FindBestSplit(const TArrayI &datasort,const TArrayI &datarang,
    137                               const TArrayI &hadtrue,Int_t ndstart,Int_t ndend,TArrayF &tclasspop,
    138                               Int_t &msplit,Float_t &decsplit,Int_t &nbest,
    139                               const TArrayF &winbag)
     153int MRanTree::FindBestSplitGini(const TArrayI &datasort,const TArrayI &datarang,
     154                                const TArrayF &hadtrue,const TArrayI &idclass,
     155                                Int_t ndstart,Int_t ndend, TArrayF &tclasspop,
     156                                float &mean, float &square, Int_t &msplit,
     157                                Float_t &decsplit,Int_t &nbest, const TArrayF &winbag,
     158                                const int nclass)
    140159{
    141160    const Int_t nrnodes = fBestSplit.GetSize();
     
    143162    const Int_t mdim = fGiniDec.GetSize();
    144163
    145     // weighted class populations after split
    146     TArrayF wc(2);
    147     TArrayF wr(2); // right node
    148 
    149     // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...)
     164    TArrayF wr(nclass); wr.Reset(0);// right node
     165
     166    // For the best split, msplit is the index of the variable (e.g Hillas par.,
     167    // zenith angle ,...)
    150168    // split on. decsplit is the decreae in impurity measured by Gini-index.
    151169    // nsplit is the case number of value of msplit split on,
     
    158176    Double_t pno=0;
    159177    Double_t pdo=0;
    160     for (Int_t j=0; j<2; j++)
     178
     179    for (Int_t j=0; j<nclass; j++)
    161180    {
    162181        pno+=tclasspop[j]*tclasspop[j];
     
    165184
    166185    const Double_t crit0=pno/pdo;
    167     Int_t jstat=0;
    168186
    169187    // start main loop through variables to find best split,
     
    184202        Double_t rld=0;
    185203
    186         TArrayF wl(2); // left node
     204        TArrayF wl(nclass); wl.Reset(0.);// left node //nclass
    187205        wr = tclasspop;
    188206
    189207        Double_t critvar=-1.0e20;
    190 
     208        for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
     209        {
     210            const Int_t  &nc = datasort[mn+nsp];
     211            const Int_t   &k = idclass[nc];
     212            const Float_t &u = winbag[nc];
     213
     214            // do classification, Gini index as split rule
     215            rln+=u*(2*wl[k]+u);
     216            rrn+=u*(-2*wr[k]+u);
     217
     218            rld+=u;
     219            rrd-=u;
     220
     221            wl[k]+=u;
     222            wr[k]-=u;
     223
     224            if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]])
     225                continue;
     226
     227            if (TMath::Min(rrd,rld)<=1.0e-5)
     228                continue;
     229
     230            const Double_t crit=(rln/rld)+(rrn/rrd);
     231
     232
     233            if (crit<=critvar) continue;
     234
     235            nbestvar=nsp;
     236            critvar=crit;
     237        }
     238
     239        if (critvar<=critmax) continue;
     240
     241        msplit=mvar;
     242        nbest=nbestvar;
     243        critmax=critvar;
     244    }
     245
     246    decsplit=critmax-crit0;
     247
     248    return critmax<-1.0e10 ? 1 : 0;
     249}
     250
     251int MRanTree::FindBestSplitSigma(const TArrayI &datasort,const TArrayI &datarang,
     252                                 const TArrayF &hadtrue, const TArrayI &idclass,
     253                                 Int_t ndstart,Int_t ndend, TArrayF &tclasspop,
     254                                 float &mean, float &square, Int_t &msplit,
     255                                 Float_t &decsplit,Int_t &nbest, const TArrayF &winbag,
     256                                 const int nclass)
     257{
     258    const Int_t nrnodes = fBestSplit.GetSize();
     259    const Int_t numdata = (nrnodes-1)/2;
     260    const Int_t mdim = fGiniDec.GetSize();
     261
     262    float wr=0;// right node
     263
     264    // For the best split, msplit is the index of the variable (e.g Hillas par., zenith angle ,...)
     265    // split on. decsplit is the decreae in impurity measured by Gini-index.
     266    // nsplit is the case number of value of msplit split on,
     267    // and nsplitnext is the case number of the next larger value of msplit.
     268
     269    Int_t nbestvar=0;
     270
     271    // compute initial values of numerator and denominator of split-index,
     272
     273    // resolution
     274    //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0];
     275    //Double_t pdo= (tclasspop[0]-1.)*mean*mean;
     276
     277    // n*resolution
     278    //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0];
     279    //Double_t pdo= mean*mean;
     280
     281    // variance
     282    //Double_t pno=-(square-mean*mean/tclasspop[0]);
     283    //Double_t pdo= (tclasspop[0]-1.);
     284
     285    // n*variance
     286    Double_t pno= (square-mean*mean/tclasspop[0]);
     287    Double_t pdo= 1.;
     288
     289    // 1./(n*variance)
     290    //Double_t pno= 1.;
     291    //Double_t pdo= (square-mean*mean/tclasspop[0]);
     292
     293    const Double_t crit0=pno/pdo;
     294
     295    // start main loop through variables to find best split,
     296
     297    Double_t critmin=1.0e40;
     298
     299    // random split selection, number of trials = fNumTry
     300    for (Int_t mt=0; mt<fNumTry; mt++)
     301    {
     302        const Int_t mvar=Int_t(gRandom->Rndm()*mdim);
     303        const Int_t mn  = mvar*numdata;
     304
     305        Double_t rrn=0, rrd=0, rln=0, rld=0;
     306        Double_t esumr=0, esuml=0, e2sumr=0,e2suml=0;
     307
     308        esumr =mean;
     309        e2sumr=square;
     310        esuml =0;
     311        e2suml=0;
     312
     313        float wl=0.;// left node
     314        wr = tclasspop[0];
     315
     316        Double_t critvar=critmin;
    191317        for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
    192318        {
    193319            const Int_t &nc=datasort[mn+nsp];
    194             const Int_t &k=hadtrue[nc];
    195 
     320            const Float_t &f=hadtrue[nc];;
    196321            const Float_t &u=winbag[nc];
    197322
    198             rln+=u*(2*wl[k]+u);
    199             rrn+=u*(-2*wr[k]+u);
    200             rld+=u;
    201             rrd-=u;
    202 
    203             wl[k]+=u;
    204             wr[k]-=u;
     323            e2sumr-=u*f*f;
     324            esumr -=u*f;
     325            wr    -=u;
     326
     327            //-------------------------------------------
     328            // resolution
     329            //rrn=(wr*e2sumr-esumr*esumr)*wr;
     330            //rrd=(wr-1.)*esumr*esumr;
     331
     332            // resolution times n
     333            //rrn=(wr*e2sumr-esumr*esumr)*wr;
     334            //rrd=esumr*esumr;
     335
     336            // sigma
     337            //rrn=(e2sumr-esumr*esumr/wr);
     338            //rrd=(wr-1.);
     339
     340            // sigma times n
     341            rrn=(e2sumr-esumr*esumr/wr);
     342            rrd=1.;
     343
     344            // 1./(n*variance)
     345            //rrn=1.;
     346            //rrd=(e2sumr-esumr*esumr/wr);
     347            //-------------------------------------------
     348
     349            e2suml+=u*f*f;
     350            esuml +=u*f;
     351            wl    +=u;
     352
     353            //-------------------------------------------
     354            // resolution
     355            //rln=(wl*e2suml-esuml*esuml)*wl;
     356            //rld=(wl-1.)*esuml*esuml;
     357
     358            // resolution times n
     359            //rln=(wl*e2suml-esuml*esuml)*wl;
     360            //rld=esuml*esuml;
     361
     362            // sigma
     363            //rln=(e2suml-esuml*esuml/wl);
     364            //rld=(wl-1.);
     365
     366            // sigma times n
     367            rln=(e2suml-esuml*esuml/wl);
     368            rld=1.;
     369
     370            // 1./(n*variance)
     371            //rln=1.;
     372            //rld=(e2suml-esuml*esuml/wl);
     373            //-------------------------------------------
    205374
    206375            if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]])
    207376                continue;
     377
    208378            if (TMath::Min(rrd,rld)<=1.0e-5)
    209379                continue;
    210380
    211381            const Double_t crit=(rln/rld)+(rrn/rrd);
    212             if (crit<=critvar)
    213                 continue;
     382
     383            if (crit>=critvar) continue;
    214384
    215385            nbestvar=nsp;
     
    217387        }
    218388
    219         if (critvar<=critmax)
    220             continue;
     389        if (critvar>=critmin) continue;
    221390
    222391        msplit=mvar;
    223392        nbest=nbestvar;
    224         critmax=critvar;
    225     }
    226 
    227     decsplit=critmax-crit0;
    228 
    229     return critmax<-1.0e10 ? 1 : jstat;
    230 }
    231 
    232 void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart,
    233                         Int_t ndend,TArrayI &idmove,TArrayI &ncase,Int_t msplit,
     393        critmin=critvar;
     394    }
     395
     396    decsplit=crit0-critmin;
     397
     398    //return critmin>1.0e20 ? 1 : 0;
     399    return decsplit<0 ? 1 : 0;
     400}
     401
     402void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart, Int_t ndend,
     403                        TArrayI &idmove,TArrayI &ncase,Int_t msplit,
    234404                        Int_t nbest,Int_t &ndendl)
    235405{
     
    240410    const Int_t mdim    = fGiniDec.GetSize();
    241411
    242     TArrayI tdatasort(numdata);
     412    TArrayI tdatasort(numdata); tdatasort.Reset(0);
    243413
    244414    // compute idmove = indicator of case nos. going left
    245 
    246415    for (Int_t nsp=ndstart;nsp<=ndend;nsp++)
    247416    {
     
    252421
    253422    // shift case. nos. right and left for numerical variables.
    254 
    255423    for(Int_t msh=0;msh<mdim;msh++)
    256424    {
     
    280448}
    281449
    282 void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang,
    283                          const TArrayI &hadtrue, TArrayI &bestsplit,
    284                          TArrayI &bestsplitnext, TArrayF &tclasspop,
    285                          const TArrayF &winbag, Int_t ninbag)
     450void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang, const TArrayF &hadtrue,
     451                         const TArrayI &idclass, TArrayI &bestsplit, TArrayI &bestsplitnext,
     452                         TArrayF &tclasspop, float &tmean, float &tsquare, const TArrayF &winbag,
     453                         Int_t ninbag, const int nclass)
    286454{
    287455    // Buildtree consists of repeated calls to two void functions, FindBestSplit and MoveData.
     
    302470    const Int_t numdata = (nrnodes-1)/2;
    303471
    304     TArrayI nodepop(nrnodes);
    305     TArrayI nodestart(nrnodes);
    306     TArrayI parent(nrnodes);
    307 
    308     TArrayI ncase(numdata);
    309     TArrayI idmove(numdata);
    310     TArrayI iv(mdim);
    311 
    312     TArrayF classpop(nrnodes*2);
    313     TArrayI nodestatus(nrnodes);
    314 
    315     for (Int_t j=0;j<2;j++)
     472    TArrayI nodepop(nrnodes);     nodepop.Reset(0);
     473    TArrayI nodestart(nrnodes);   nodestart.Reset(0);
     474    TArrayI parent(nrnodes);      parent.Reset(0);
     475
     476    TArrayI ncase(numdata);       ncase.Reset(0);
     477    TArrayI idmove(numdata);      idmove.Reset(0);
     478    TArrayI iv(mdim);             iv.Reset(0);
     479
     480    TArrayF classpop(nrnodes*nclass);  classpop.Reset(0.);//nclass
     481    TArrayI nodestatus(nrnodes);       nodestatus.Reset(0);
     482
     483    for (Int_t j=0;j<nclass;j++)
    316484        classpop[j*nrnodes+0]=tclasspop[j];
     485
     486    TArrayF mean(nrnodes);   mean.Reset(0.);
     487    TArrayF square(nrnodes); square.Reset(0.);
     488
     489    mean[0]=tmean;
     490    square[0]=tsquare;
     491
    317492
    318493    Int_t ncur=0;
     
    330505          const Int_t ndstart=nodestart[kbuild];
    331506          const Int_t ndend=ndstart+nodepop[kbuild]-1;
    332           for (Int_t j=0;j<2;j++)
     507
     508          for (Int_t j=0;j<nclass;j++)
    333509              tclasspop[j]=classpop[j*nrnodes+kbuild];
     510
     511          tmean=mean[kbuild];
     512          tsquare=square[kbuild];
    334513
    335514          Int_t msplit, nbest;
    336515          Float_t decsplit=0;
    337           const Int_t jstat=FindBestSplit(datasort,datarang,hadtrue,
    338                                           ndstart,ndend,tclasspop,msplit,
    339                                           decsplit,nbest,winbag);
    340 
    341           if (jstat==1)
     516
     517          if ((*this.*FindBestSplit)(datasort,datarang,hadtrue,idclass,ndstart,
     518                                     ndend, tclasspop,tmean, tsquare,msplit,decsplit,
     519                                     nbest,winbag,nclass))
    342520          {
    343521              nodestatus[kbuild]=-1;
     
    356534
    357535          // leftnode no.= ncur+1, rightnode no. = ncur+2.
    358 
    359536          nodepop[ncur+1]=ndendl-ndstart+1;
    360537          nodepop[ncur+2]=ndend-ndendl;
     
    363540
    364541          // find class populations in both nodes
    365 
    366542          for (Int_t n=ndstart;n<=ndendl;n++)
    367543          {
    368544              const Int_t &nc=ncase[n];
    369               const Int_t &j=hadtrue[nc];
     545              const int j=idclass[nc];
     546                   
     547              mean[ncur+1]+=hadtrue[nc]*winbag[nc];
     548              square[ncur+1]+=hadtrue[nc]*hadtrue[nc]*winbag[nc];
     549
    370550              classpop[j*nrnodes+ncur+1]+=winbag[nc];
    371551          }
     
    374554          {
    375555              const Int_t &nc=ncase[n];
    376               const Int_t &j=hadtrue[nc];
     556              const int j=idclass[nc];
     557
     558              mean[ncur+2]  +=hadtrue[nc]*winbag[nc];
     559              square[ncur+2]+=hadtrue[nc]*hadtrue[nc]*winbag[nc];
     560
    377561              classpop[j*nrnodes+ncur+2]+=winbag[nc];
    378562          }
     
    385569          if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1;
    386570
     571
    387572          Double_t popt1=0;
    388573          Double_t popt2=0;
    389           for (Int_t j=0;j<2;j++)
     574          for (Int_t j=0;j<nclass;j++)
    390575          {
    391576              popt1+=classpop[j*nrnodes+ncur+1];
     
    393578          }
    394579
    395           for (Int_t j=0;j<2;j++)
     580          if(fClassify)
    396581          {
    397               if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
    398               if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
     582              // check if only members of one class in node
     583              for (Int_t j=0;j<nclass;j++)
     584              {
     585                  if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
     586                  if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
     587              }
    399588          }
    400589
     
    421610        {
    422611            fNumEndNodes++;
     612
    423613            Double_t pp=0;
    424             for (Int_t j=0;j<2;j++)
     614            for (Int_t j=0;j<nclass;j++)
    425615            {
    426616                if(classpop[j*nrnodes+kn]>pp)
    427617                {
    428618                    // class + status of node kn coded into fBestVar[kn]
    429                     fBestVar[kn]=j-2;
     619                    fBestVar[kn]=j-nclass;
    430620                    pp=classpop[j*nrnodes+kn];
    431621                }
    432622            }
    433             fBestSplit[kn] =classpop[1*nrnodes+kn];
    434             fBestSplit[kn]/=(classpop[0*nrnodes+kn]+classpop[1*nrnodes+kn]);
     623
     624                float sum=0;
     625                for(int i=0;i<nclass;i++) sum+=classpop[i*nrnodes+kn];
     626
     627                fBestSplit[kn]=mean[kn]/sum;
    435628        }
    436629}
    437630
    438 void MRanTree::SetRules(MDataArray *rules)
    439 {
    440     fData=rules;
    441 }
    442 
    443631Double_t MRanTree::TreeHad(const TVector &event)
    444 {
    445     Int_t kt=0;
    446     // to optimize on storage space node status and node class
    447     // are coded into fBestVar:
    448     // status of node kt = TMath::Sign(1,fBestVar[kt])
    449     // class  of node kt = fBestVar[kt]+2
    450     //  (class defined by larger node population, actually not used)
    451     // hadronness assigned to node kt = fBestSplit[kt]
    452 
    453     for (Int_t k=0;k<fNumNodes;k++)
    454     {
    455         if (fBestVar[kt]<0)
    456             break;
    457 
    458         const Int_t m=fBestVar[kt];
    459         kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
    460     }
    461 
    462     return fBestSplit[kt];
    463 }
    464 
    465 Double_t MRanTree::TreeHad(const TMatrixFRow_const &event)
    466632{
    467633    Int_t kt=0;
     
    485651}
    486652
     653Double_t MRanTree::TreeHad(const TMatrixRow &event)
     654{
     655    Int_t kt=0;
     656    // to optimize on storage space node status and node class
     657    // are coded into fBestVar:
     658    // status of node kt = TMath::Sign(1,fBestVar[kt])
     659    // class  of node kt = fBestVar[kt]+2 (class defined by larger
     660    //  node population, actually not used)
     661    // hadronness assigned to node kt = fBestSplit[kt]
     662
     663    for (Int_t k=0;k<fNumNodes;k++)
     664    {
     665        if (fBestVar[kt]<0)
     666            break;
     667
     668        const Int_t m=fBestVar[kt];
     669        kt = event(m)<=fBestSplit[kt] ? fTreeMap1[kt] : fTreeMap2[kt];
     670    }
     671
     672    return fBestSplit[kt];
     673}
     674
    487675Double_t MRanTree::TreeHad(const TMatrix &m, Int_t ievt)
    488676{
     
    494682}
    495683
    496 Double_t MRanTree::TreeHad()
    497 {
    498     TVector event;
    499     *fData >> event;
    500 
    501     return TreeHad(event);
    502 }
    503 
    504684Bool_t MRanTree::AsciiWrite(ostream &out) const
    505685{
    506     out.width(5);
    507     out << fNumNodes << endl;
    508 
     686    TString str;
    509687    Int_t k;
    510     for (k=0; k<fNumNodes; k++)
    511     {
    512         TString str=Form("%f", GetBestSplit(k));
     688
     689    out.width(5);out<<fNumNodes<<endl;
     690
     691    for (k=0;k<fNumNodes;k++)
     692    {
     693        str=Form("%f",GetBestSplit(k));
    513694
    514695        out.width(5);  out << k;
     
    520701        out.width(5);  out << GetNodeClass(k);
    521702    }
    522     out << endl;
    523 
    524     return kTRUE;
    525 }
     703    out<<endl;
     704
     705    return k==fNumNodes;
     706}
  • trunk/MagicSoft/Mars/mranforest/MRanTree.h

    r7142 r7396  
    1616class TMatrix;
    1717class TMatrixRow;
    18 class TMatrixFRow_const;
    1918class TVector;
    2019class TRandom;
    21 class MDataArray;
    2220
    2321class MRanTree : public MParContainer
    2422{
    2523private:
     24    Int_t fClassify;
    2625    Int_t fNdSize;
    2726    Int_t fNumTry;
     
    2928    Int_t fNumNodes;
    3029    Int_t fNumEndNodes;
    31     MDataArray *fData;
    3230
    3331    TArrayI fBestVar;
     
    3533    TArrayI fTreeMap2;
    3634    TArrayF fBestSplit;
    37 
    3835    TArrayF fGiniDec;
    3936
    40     Int_t FindBestSplit(const TArrayI &datasort, const TArrayI &datarang,
    41                         const TArrayI &hadtrue,
    42                         Int_t ndstart, Int_t ndend, TArrayF &tclasspop,
    43                         Int_t &msplit, Float_t &decsplit, Int_t &nbest,
    44                         const TArrayF &winbag);
     37    int (MRanTree::*FindBestSplit)
     38        (const TArrayI &, const TArrayI &, const TArrayF &, const TArrayI &,
     39         Int_t, Int_t , TArrayF &, float &, float &, Int_t &, Float_t &,
     40         Int_t &, const TArrayF &, const int); //!
     41
     42
     43    int FindBestSplitGini(const TArrayI &datasort, const TArrayI &datarang,
     44                          const TArrayF &hadtrue, const TArrayI &idclass,
     45                          Int_t ndstart, Int_t ndend, TArrayF &tclasspop,
     46                          float &mean, float &square, Int_t &msplit,
     47                          Float_t &decsplit, Int_t &nbest, const TArrayF &winbag,
     48                          const int nclass);
     49
     50    int FindBestSplitSigma(const TArrayI &datasort, const TArrayI &datarang,
     51                           const TArrayF &hadtrue, const TArrayI &idclass,
     52                           Int_t ndstart, Int_t ndend, TArrayF &tclasspop,
     53                           float &mean, float &square, Int_t &msplit,
     54                           Float_t &decsplit, Int_t &nbest, const TArrayF &winbag,
     55                           const int nclass);
    4556
    4657    void MoveData(TArrayI &datasort, Int_t ndstart, Int_t ndend,
     
    4859                  Int_t nbest, Int_t &ndendl);
    4960
    50     void BuildTree(TArrayI &datasort, const TArrayI &datarang,
    51                    const TArrayI &hadtrue,
    52                    TArrayI &bestsplit,TArrayI &bestsplitnext,
    53                    TArrayF &tclasspop,
    54                    const TArrayF &winbag,
    55                    Int_t ninbag);
     61    void BuildTree(TArrayI &datasort, const TArrayI &datarang, const TArrayF &hadtrue,
     62                   const TArrayI &idclass,TArrayI &bestsplit,TArrayI &bestsplitnext,
     63                   TArrayF &tclasspop, float &tmean, float &tsquare, const TArrayF &winbag,
     64                   Int_t ninbag, const int nclass);
    5665
    5766public:
    5867    MRanTree(const char *name=NULL, const char *title=NULL);
     68    MRanTree(const MRanTree &tree);
    5969
    6070    void SetNdSize(Int_t n);
    6171    void SetNumTry(Int_t n);
    62     void SetRules(MDataArray *rules);
    63 
    64     MDataArray *GetRules() { return fData;}
    6572
    6673    Int_t GetNdSize() const { return fNdSize; }
     
    7885    Float_t GetGiniDec(Int_t i)  const { return fGiniDec.At(i); }
    7986
     87    void SetClassify(Int_t n){ fClassify=n; }
     88
    8089    // functions used in tree growing process
    81     void GrowTree(const TMatrix &mhad, const TMatrix &mgam,
    82                   const TArrayI &hadtrue, TArrayI &datasort,
    83                   const TArrayI &datarang,
    84                   TArrayF &tclasspop, TArrayI &jinbag, const TArrayF &winbag);
     90    void GrowTree(TMatrix *mat, const TArrayF &hadtrue, const TArrayI &idclass,
     91                  TArrayI &datasort, const TArrayI &datarang,TArrayF &tclasspop,
     92                  float &mean, float &square, TArrayI &jinbag, const TArrayF &winbag,
     93                  const int nclass);
    8594
    8695    Double_t TreeHad(const TVector &event);
    87     Double_t TreeHad(const TMatrixFRow_const &event);
     96    Double_t TreeHad(const TMatrixRow &event);
    8897    Double_t TreeHad(const TMatrix &m, Int_t ievt);
    89     Double_t TreeHad();
    9098
    9199    Bool_t AsciiWrite(ostream &out) const;
  • trunk/MagicSoft/Mars/mranforest/Makefile

    r6479 r7396  
    2525           MRanForest.cc \
    2626           MRanForestGrow.cc \
    27            MRanForestCalc.cc \
    28            MRanForestFill.cc \
    2927           MHRanForest.cc \
    3028           MHRanForestGini.cc \
  • trunk/MagicSoft/Mars/mranforest/RanForestLinkDef.h

    r6479 r7396  
    88#pragma link C++ class MRanForest+;
    99#pragma link C++ class MRanForestGrow+;
    10 #pragma link C++ class MRanForestCalc+;
    11 #pragma link C++ class MRanForestFill+;   
    1210
    1311#pragma link C++ class MHRanForest+;
Note: See TracChangeset for help on using the changeset viewer.