/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
!
!   Copyright: MAGIC Software Development, 2000-2003
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//                                                                         //
// MRanForest                                                              //
//                                                                         //
// ParameterContainer for Forest structure                                 //
//                                                                         //
// A random forest can be grown by calling GrowForest.                     //
// In advance SetupGrow must be called in order to initialize arrays and   //
// do some preprocessing.                                                  //
// GrowForest() provides the training data for a single tree (bootstrap    //
// aggregate procedure)                                                    //
//                                                                         //
// Essentially two random elements serve to provide a "random" forest,     //
// namely bootstrap aggregating (which is done in GrowForest()) and random //
// split selection (which is subject to MRanTree::GrowTree())              //
//                                                                         //
/////////////////////////////////////////////////////////////////////////////
#include "MRanForest.h"

#include <TMatrix.h>
#include <TRandom3.h>

#include "MHMatrix.h"
#include "MRanTree.h"

#include "MLog.h"
#include "MLogManip.h"

ClassImp(MRanForest);

using namespace std;

// --------------------------------------------------------------------------
//
// Default constructor.
//
MRanForest::MRanForest(const char *name, const char *title) : fNumTrees(100), fRanTree(NULL),fUsePriors(kFALSE)
{
    fName  = name  ? name  : "MRanForest";
    fTitle = title ? title : "Storage container for Random Forest";

    fForest=new TObjArray();
    fForest->SetOwner(kTRUE);
}

// --------------------------------------------------------------------------
//
// Destructor. 
//
MRanForest::~MRanForest()
{
    delete fForest;
}

void MRanForest::SetNumTrees(Int_t n)
{
    //at least 1 tree
    fNumTrees=TMath::Max(n,1);
    fTreeHad.Set(fNumTrees);
    fTreeHad.Reset();
}

void MRanForest::SetPriors(Float_t prior_had, Float_t prior_gam)
{
    const Float_t sum=prior_gam+prior_had;

    prior_gam/=sum;
    prior_had/=sum;

    fPrior[0]=prior_had;
    fPrior[1]=prior_gam;

    fUsePriors=kTRUE;
}

Double_t MRanForest::CalcHadroness(const TVector &event)
{
    Double_t hadroness=0;
    Int_t ntree=0;

    TIter Next(fForest);

    MRanTree *tree;
    while ((tree=(MRanTree*)Next()))
    {
        fTreeHad[ntree]=tree->TreeHad(event);
        hadroness+=fTreeHad[ntree];
        ntree++;
    }
    return hadroness/ntree;
}

Bool_t MRanForest::AddTree(MRanTree *rantree=NULL)
{
    if (rantree)
        fRanTree=rantree;
    if (!fRanTree)
        return kFALSE;

    fForest->Add((MRanTree*)fRanTree->Clone());

    return kTRUE;
}

Bool_t MRanForest::SetupGrow(MHMatrix *mhad,MHMatrix *mgam)
{
    // pointer to training data
    fHadrons=mhad;
    fGammas=mgam;

    // determine data entries and dimension of Hillas-parameter space
    fNumHad=fHadrons->GetM().GetNrows();
    fNumGam=fGammas->GetM().GetNrows();
    fNumDim=fHadrons->GetM().GetNcols();

    if (fNumDim!=fHadrons->GetM().GetNcols()) return kFALSE;

    fNumData=fNumHad+fNumGam;

    // allocating and initializing arrays
    fHadTrue.Set(fNumData);
    fHadTrue.Reset();
    fHadEst.Set(fNumData);

    fPrior.Set(2);
    fClassPop.Set(2);
    fWeight.Set(fNumData);
    fNTimesOutBag.Set(fNumData);
    fNTimesOutBag.Reset();

    fDataSort.Set(fNumDim*fNumData);
    fDataRang.Set(fNumDim*fNumData);

    // calculating class populations (= no. of gammas and hadrons)
    fClassPop.Reset();
    for(Int_t n=0;n<fNumData;n++)
        fClassPop[fHadTrue[n]]++;

    // setting weights taking into account priors
    if (!fUsePriors)
    {
        fWeight.Reset(1.);
    }else{
        for(Int_t j=0;j<2;j++)
            fPrior[j] *= (fClassPop[j]>=1) ?
                Float_t(fNumData)/Float_t(fClassPop[j]):0;

        for(Int_t n=0;n<fNumData;n++)
            fWeight[n]=fPrior[fHadTrue[n]];
    }

    // create fDataSort
    CreateDataSort();

    if(!fRanTree)
    {
        *fLog << err << dbginf << "MRanForest, fRanTree not initialized... aborting." << endl;
        return kFALSE;
    }
    fRanTree->SetRules(fGammas->GetColumns());
    fTreeNo=0;

    return kTRUE;
}

Bool_t MRanForest::GrowForest()
{
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;
        return kFALSE;
    }

    fTreeNo++;

    // initialize running output
    if (fTreeNo==1)
    {
        *fLog << inf << endl;
        *fLog << underline; // << "1st col        2nd col" << endl;
        *fLog << "no. of tree    error in % (calulated using oob-data -> overestim. of error)" << endl;
    }

    TArrayF classpopw(2);
    TArrayI jinbag(fNumData); // Initialization includes filling with 0
    TArrayF winbag(fNumData); // Initialization includes filling with 0

    // bootstrap aggregating (bagging) -> sampling with replacement:
    //
    // The integer k is randomly (uniformly) chosen from the set
    // {0,1,...,fNumData-1}, which is the set of the index numbers of
    // all events in the training sample
    for (Int_t n=0; n<fNumData; n++)
    {
        const Int_t k = Int_t(gRandom->Rndm()*fNumData);

        classpopw[fHadTrue[k]]+=fWeight[k];
        winbag[k]+=fWeight[k];
        jinbag[k]=1;
    }

    // modifying sorted-data array for in-bag data:
    //
    // In bagging procedure ca. 2/3 of all elements in the original
    // training sample are used to build the in-bag data
    TArrayI datsortinbag=fDataSort;
    Int_t ninbag=0;

    ModifyDataSort(datsortinbag, ninbag, jinbag);

    const TMatrix &hadrons=fHadrons->GetM();
    const TMatrix &gammas =fGammas->GetM();

    // growing single tree
    fRanTree->GrowTree(hadrons,gammas,fNumData,fNumDim,fHadTrue,datsortinbag,
                       fDataRang,classpopw,jinbag,winbag,fWeight);

    // error-estimates from out-of-bag data (oob data):
    //
    // For a single tree the events not(!) contained in the bootstrap sample of
    // this tree can be used to obtain estimates for the classification error of
    // this tree.
    // If you take a certain event, it is contained in the oob-data of 1/3 of
    // the trees (see comment to ModifyData). This means that the classification error
    // determined from oob-data is underestimated, but can still be taken as upper limit.

    for (Int_t ievt=0;ievt<fNumHad;ievt++)
    {
        if (jinbag[ievt]>0)
            continue;
        fHadEst[ievt] += fRanTree->TreeHad(hadrons, ievt);
        fNTimesOutBag[ievt]++;
    }
    for (Int_t ievt=0;ievt<fNumGam;ievt++)
    {
        if (jinbag[fNumHad+ievt]>0)
            continue;
        fHadEst[fNumHad+ievt] += fRanTree->TreeHad(gammas, ievt);
        fNTimesOutBag[fNumHad+ievt]++;
    }

    Int_t n=0;
    Double_t ferr=0;
    for (Int_t ievt=0;ievt<fNumData;ievt++)
        if (fNTimesOutBag[ievt]!=0)
        {
            const Double_t val = fHadEst[ievt]/fNTimesOutBag[ievt]-fHadTrue[ievt];
            ferr += val*val;
            n++;
        }

    ferr = TMath::Sqrt(ferr/n);

    // give running output
    *fLog << inf << setw(5) << fTreeNo << Form("%15.2f", ferr*100) << endl;

    // adding tree to forest
    AddTree();

    return fTreeNo<fNumTrees;
}

void MRanForest::CreateDataSort()
{
    // The values of concatenated data arrays fHadrons and fGammas (denoted in the following by fData,
    // which does actually not exist) are sorted from lowest to highest.
    //
    //
    //                   fHadrons(0,0) ... fHadrons(0,nhad-1)   fGammas(0,0) ... fGammas(0,ngam-1)
    //                        .                 .                   .                .
    //  fData(m,n)   =        .                 .                   .                .
    //                        .                 .                   .                .
    //                   fHadrons(m-1,0)...fHadrons(m-1,nhad-1) fGammas(m-1,0)...fGammas(m-1,ngam-1)
    //
    //
    // Then fDataSort(m,n) is the event number in which fData(m,n) occurs.
    // fDataRang(m,n) is the rang of fData(m,n), i.e. if rang = r, fData(m,n) is the r-th highest
    // value of all fData(m,.). There may be more then 1 event with rang r (due to bagging).

    TArrayF v(fNumData);
    TArrayI isort(fNumData);

    const TMatrix &hadrons=fHadrons->GetM();
    const TMatrix &gammas=fGammas->GetM();

    for (Int_t j=0;j<fNumHad;j++)
        fHadTrue[j]=1;

    for (Int_t j=0;j<fNumGam;j++)
        fHadTrue[j+fNumHad]=0;

    for (Int_t mvar=0;mvar<fNumDim;mvar++)
    {
        for(Int_t n=0;n<fNumHad;n++)
        {
            v[n]=hadrons(n,mvar);
            isort[n]=n;
        }

        for(Int_t n=0;n<fNumGam;n++)
        {
            v[n+fNumHad]=gammas(n,mvar);
            isort[n+fNumHad]=n;
        }

        TMath::Sort(fNumData,v.GetArray(),isort.GetArray(),kFALSE);

        // this sorts the v[n] in ascending order. isort[n] is the event number
        // of that v[n], which is the n-th from the lowest (assume the original
        // event numbers are 0,1,...).

        for(Int_t n=0;n<fNumData-1;n++)
        {
            const Int_t n1=isort[n];
            const Int_t n2=isort[n+1];

            fDataSort[mvar*fNumData+n]=n1;
            if(n==0) fDataRang[mvar*fNumData+n1]=0;
            if(v[n]<v[n+1])
            {
                fDataRang[mvar*fNumData+n2]=fDataRang[mvar*fNumData+n1]+1;
            }else{
                fDataRang[mvar*fNumData+n2]=fDataRang[mvar*fNumData+n1];
            }
        }
        fDataSort[(mvar+1)*fNumData-1]=isort[fNumData-1];
    }
}

void MRanForest::ModifyDataSort(TArrayI &datsortinbag, Int_t ninbag, const TArrayI &jinbag)
{
    ninbag=0;
    for (Int_t n=0;n<fNumData;n++)
        if(jinbag[n]==1) ninbag++;

    for(Int_t m=0;m<fNumDim;m++)
    {
        Int_t k=0;
        Int_t nt=0;
        for(Int_t n=0;n<fNumData;n++)
        {
            if(jinbag[datsortinbag[m*fNumData+k]]==1)
            {
                datsortinbag[m*fNumData+nt]=datsortinbag[m*fNumData+k];
                k++;
            }else{
                for(Int_t j=1;j<fNumData-k;j++)
                {
                    if(jinbag[datsortinbag[m*fNumData+k+j]]==1)
                    {
                        datsortinbag[m*fNumData+nt]=datsortinbag[m*fNumData+k+j];
                        k+=j+1;
                        break;
                    }
                }
            }
            nt++;
            if(nt>=ninbag) break;
        }
    }
}

Bool_t MRanForest::AsciiWrite(ostream &out) const
{
    Int_t n=0;
    MRanTree *tree;
    TIter forest(fForest);

    while ((tree=(MRanTree*)forest.Next()))
    {
        tree->AsciiWrite(out);
        n++;
    }

    return n==fNumTrees;
}
