/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@physik.hu-berlin.de>
!
!   Copyright: MAGIC Software Development, 2000-2005
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//
// MRanTree
//
// ParameterContainer for Tree structure
//
/////////////////////////////////////////////////////////////////////////////
#include "MRanTree.h"

#include <iostream>

#include <TVector.h>
#include <TMatrix.h>
#include <TRandom.h>

#include "MLog.h"
#include "MLogManip.h"

ClassImp(MRanTree);

using namespace std;


// --------------------------------------------------------------------------
// Default constructor.
//
MRanTree::MRanTree(const char *name, const char *title):fClassify(kTRUE),fNdSize(0), fNumTry(3)
{
    // Creates an empty tree (nothing grown yet) in classification mode
    // with fNdSize=0 and fNumTry=3.
    //
    //  name, title: container name and title; a null pointer selects the
    //               default text below.

    fName  = name  ? name  : "MRanTree";
    fTitle = title ? title : "Storage container for structure of a single tree";

    // No tree has been grown yet. The node counters must be well defined
    // because TreeHad() and AsciiWrite() use fNumNodes as a loop bound;
    // leaving them uninitialised made those calls read indeterminate values
    // on a freshly constructed object.
    fNumNodes    = 0;
    fNumEndNodes = 0;
}

// --------------------------------------------------------------------------
// Copy constructor
//
MRanTree::MRanTree(const MRanTree &tree)
{
    // Member-wise copy of 'tree': container identity, growing options,
    // node counters and the stored tree structure.
    // The FindBestSplit member-function pointer is not copied here;
    // GrowTree() assigns it before it is used.

    // stored tree structure (per-node arrays, per-variable Gini decrease)
    fGiniDec   = tree.fGiniDec;
    fBestSplit = tree.fBestSplit;
    fTreeMap2  = tree.fTreeMap2;
    fTreeMap1  = tree.fTreeMap1;
    fBestVar   = tree.fBestVar;

    // node counters
    fNumEndNodes = tree.fNumEndNodes;
    fNumNodes    = tree.fNumNodes;

    // growing options
    fNumTry   = tree.fNumTry;
    fNdSize   = tree.fNdSize;
    fClassify = tree.fClassify;

    // container identity
    fTitle = tree.fTitle;
    fName  = tree.fName;
}

void MRanTree::SetNdSize(Int_t n)
{
    // Sets the size threshold for terminal nodes: the training data is
    // split until a subset (= terminal node) contains only pure data or
    // holds at most n events.
    // Values below 1 are raised to 1, i.e. at least one event per node.

    fNdSize = n<1 ? 1 : n;
}

void MRanTree::SetNumTry(Int_t n)
{
    // Sets the number of trials of the random split selection, i.e. how
    // many randomly drawn variables are tested per node.
    // Values below 1 are raised to 1 so that at least one variable is tried.

    fNumTry = n<1 ? 1 : n;
}

void MRanTree::GrowTree(TMatrix *mat, const TArrayF &hadtrue, const TArrayI &idclass,
                        TArrayI &datasort, const TArrayI &datarang, TArrayF &tclasspop,
                        float &mean, float &square, TArrayI &jinbag, const TArrayF &winbag,
                        const int nclass)
{
    // Grows the tree and stores its structure in fBestVar, fTreeMap1,
    // fTreeMap2, fBestSplit and fGiniDec.
    //
    //  mat       training matrix, read as (*mat)(event, variable); its number
    //            of columns is taken as the number of variables
    //  jinbag    per-event flag; entries equal to 1 are counted as in-bag
    //  winbag    per-event weights; its size is taken as the number of events
    //  all other arguments are handed through unchanged to BuildTree()
    //
    // Depending on fClassify the split criterion is FindBestSplitGini
    // (classification) or FindBestSplitSigma (otherwise).

    const Int_t nevents  = winbag.GetSize();
    const Int_t nvars    = mat->GetNcols();

    // The work arrays need a generous size: the total number of nodes is
    // estimated for the worst case.
    const Int_t maxnodes = 2*nevents+1;

    // size of the bootstrap sample
    Int_t ninbag=0;
    for (Int_t i=0; i<nevents; i++)
    {
        if (jinbag[i]==1)
            ++ninbag;
    }

    // per node: event at the cut position and the event following it
    TArrayI cutevt(maxnodes);      cutevt.Reset(0);
    TArrayI cutevtnext(maxnodes);  cutevtnext.Reset(0);

    fGiniDec.Set(nvars);       fGiniDec.Reset(0);
    fBestSplit.Set(maxnodes);  fBestSplit.Reset(0);
    fTreeMap2.Set(maxnodes);   fTreeMap2.Reset(0);
    fTreeMap1.Set(maxnodes);   fTreeMap1.Reset(0);
    fBestVar.Set(maxnodes);    fBestVar.Reset(0);

    FindBestSplit = fClassify ? &MRanTree::FindBestSplitGini
                              : &MRanTree::FindBestSplitSigma;

    // grow the tree
    BuildTree(datasort,datarang,hadtrue,idclass,cutevt,cutevtnext,
              tclasspop,mean,square,winbag,ninbag,nclass);

    // Post processing: the cut value of a node is the mean of the split
    // variable of the two events enclosing the cut. Nodes reported with
    // status -1 keep the value BuildTree() stored for them.
    for (Int_t node=0; node<maxnodes; node++)
    {
        if (GetNodeStatus(node)!=-1)
        {
            const Int_t var = fBestVar[node];

            fBestSplit[node]  = (*mat)(cutevt[node],     var);
            fBestSplit[node] += (*mat)(cutevtnext[node], var);
            fBestSplit[node] /= 2.;
        }
    }

    // shrink the per-node arrays to the nodes actually created
    fBestSplit.Set(fNumNodes);
    fTreeMap2.Set(fNumNodes);
    fTreeMap1.Set(fNumNodes);
    fBestVar.Set(fNumNodes);
}

int MRanTree::FindBestSplitGini(const TArrayI &datasort,const TArrayI &datarang,
                                const TArrayF &hadtrue,const TArrayI &idclass,
                                Int_t ndstart,Int_t ndend, TArrayF &tclasspop,
                                float &mean, float &square, Int_t &msplit,
                                Float_t &decsplit,Int_t &nbest, const TArrayF &winbag,
                                const int nclass)
{
    const Int_t nrnodes = fBestSplit.GetSize();
    const Int_t numdata = (nrnodes-1)/2;
    const Int_t mdim = fGiniDec.GetSize();

    TArrayF wr(nclass); wr.Reset(0);// right node

    // For the best split, msplit is the index of the variable (e.g Hillas par.,
    // zenith angle ,...)
    // split on. decsplit is the decreae in impurity measured by Gini-index.
    // nsplit is the case number of value of msplit split on,
    // and nsplitnext is the case number of the next larger value of msplit.

    Int_t nbestvar=0;

    // compute initial values of numerator and denominator of Gini-index,
    // Gini index= pno/dno
    Double_t pno=0;
    Double_t pdo=0;

    for (Int_t j=0; j<nclass; j++)
    {
        pno+=tclasspop[j]*tclasspop[j];
        pdo+=tclasspop[j];
    }

    const Double_t crit0=pno/pdo;

    // start main loop through variables to find best split,
    // (Gini-index as criterium crit)

    Double_t critmax=-FLT_MAX;

    // random split selection, number of trials = fNumTry
    for (Int_t mt=0; mt<fNumTry; mt++)
    {
        const Int_t mvar=Int_t(gRandom->Rndm()*mdim);
        const Int_t mn  = mvar*numdata;

        // Gini index = rrn/rrd+rln/rld
        Double_t rrn=pno;
        Double_t rrd=pdo;
        Double_t rln=0;
        Double_t rld=0;

        TArrayF wl(nclass); wl.Reset(0.);// left node //nclass
        wr = tclasspop;

        Double_t critvar=-1.0e20;
        for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
        {
            const Int_t  &nc = datasort[mn+nsp];
            const Int_t   &k = idclass[nc];
            const Float_t &u = winbag[nc];

            // do classification, Gini index as split rule
            rln+=u*(2*wl[k]+u);
            rrn+=u*(-2*wr[k]+u);

            rld+=u;
            rrd-=u;

            wl[k]+=u;
            wr[k]-=u;

            if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]])
                continue;

            if (TMath::Min(rrd,rld)<=1.0e-5)
                continue;

            const Double_t crit=(rln/rld)+(rrn/rrd);


            if (crit<=critvar) continue;

            nbestvar=nsp;
            critvar=crit;
        }

        if (critvar<=critmax) continue;

        msplit=mvar;
        nbest=nbestvar;
        critmax=critvar;
    }

    decsplit=critmax-crit0;

    return critmax<-1.0e10 ? 1 : 0;
}

int MRanTree::FindBestSplitSigma(const TArrayI &datasort,const TArrayI &datarang,
                                 const TArrayF &hadtrue, const TArrayI &idclass,
                                 Int_t ndstart,Int_t ndend, TArrayF &tclasspop,
                                 float &mean, float &square, Int_t &msplit,
                                 Float_t &decsplit,Int_t &nbest, const TArrayF &winbag,
                                 const int nclass)
{
    // Searches the best split of the node covering the positions
    // ndstart..ndend. The criterion is the sum over both child nodes of
    //     sum(w*f*f) - (sum(w*f))^2 / sum(w)          ("n*variance")
    // where f=hadtrue and w=winbag of the events of the child. The smallest
    // value wins.
    //
    // Input read by this function:
    //  mean, square   used as start values of sum(w*f) and sum(w*f*f) of the
    //                 whole node
    //  tclasspop[0]   used as start value of sum(w) of the whole node
    // Output:
    //  msplit    index of the variable of the best split
    //  nbest     position of the last event going to the left child
    //  decsplit  criterion of the unsplit node minus the best criterion
    // Returns 1 if decsplit is negative (node is not split), 0 otherwise.
    //
    // idclass and nclass are not used by this criterion.

    // fBestSplit holds 2*numdata+1 entries while the tree is grown
    const Int_t nrnodes = fBestSplit.GetSize();
    const Int_t numdata = (nrnodes-1)/2;
    const Int_t mdim = fGiniDec.GetSize();

    float wr=0;// summed weight of the right node

    // position of the best cut found for the variable currently tested
    Int_t nbestvar=0;

    // Compute the initial values of numerator and denominator of the
    // split index. The alternatives are kept for reference; exactly one
    // pair must be active and it must match the pairs used inside the loop.

    // resolution
    //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0];
    //Double_t pdo= (tclasspop[0]-1.)*mean*mean;

    // n*resolution
    //Double_t pno=-(tclasspop[0]*square-mean*mean)*tclasspop[0];
    //Double_t pdo= mean*mean;

    // variance
    //Double_t pno=-(square-mean*mean/tclasspop[0]);
    //Double_t pdo= (tclasspop[0]-1.);

    // n*variance (active)
    Double_t pno= (square-mean*mean/tclasspop[0]);
    Double_t pdo= 1.;

    // 1./(n*variance)
    //Double_t pno= 1.;
    //Double_t pdo= (square-mean*mean/tclasspop[0]);

    // criterion of the unsplit node
    const Double_t crit0=pno/pdo;

    // start main loop through variables to find best split

    // start value; it is only lowered when a trial finds a valid cut
    Double_t critmin=1.0e40;

    // random split selection, number of trials = fNumTry
    for (Int_t mt=0; mt<fNumTry; mt++)
    {
        // variable drawn for this trial and offset of its block in
        // datasort/datarang
        const Int_t mvar=Int_t(gRandom->Rndm()*mdim);
        const Int_t mn  = mvar*numdata;

        Double_t rrn=0, rrd=0, rln=0, rld=0;
        Double_t esumr=0, esuml=0, e2sumr=0,e2suml=0;

        // all events start in the right node
        esumr =mean;
        e2sumr=square;
        esuml =0;
        e2suml=0;

        float wl=0.;// summed weight of the left node
        wr = tclasspop[0];

        // only cuts better than the best one of the previous trials count
        Double_t critvar=critmin;
        for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
        {
            const Int_t &nc=datasort[mn+nsp];   // event at position nsp
            const Float_t &f=hadtrue[nc];;
            const Float_t &u=winbag[nc];

            // remove the event from the right node ...
            e2sumr-=u*f*f;
            esumr -=u*f;
            wr    -=u;

            //-------------------------------------------
            // resolution
            //rrn=(wr*e2sumr-esumr*esumr)*wr;
            //rrd=(wr-1.)*esumr*esumr;

            // resolution times n
            //rrn=(wr*e2sumr-esumr*esumr)*wr;
            //rrd=esumr*esumr;

            // sigma
            //rrn=(e2sumr-esumr*esumr/wr);
            //rrd=(wr-1.);

            // sigma times n (active)
            rrn=(e2sumr-esumr*esumr/wr);
            rrd=1.;

            // 1./(n*variance)
            //rrn=1.;
            //rrd=(e2sumr-esumr*esumr/wr);
            //-------------------------------------------

            // ... and add it to the left node
            e2suml+=u*f*f;
            esuml +=u*f;
            wl    +=u;

            //-------------------------------------------
            // resolution
            //rln=(wl*e2suml-esuml*esuml)*wl;
            //rld=(wl-1.)*esuml*esuml;

            // resolution times n
            //rln=(wl*e2suml-esuml*esuml)*wl;
            //rld=esuml*esuml;

            // sigma
            //rln=(e2suml-esuml*esuml/wl);
            //rld=(wl-1.);

            // sigma times n (active)
            rln=(e2suml-esuml*esuml/wl);
            rld=1.;

            // 1./(n*variance)
            //rln=1.;
            //rld=(e2suml-esuml*esuml/wl);
            //-------------------------------------------

            // a cut is only possible where the next event has a higher rank
            if (datarang[mn+nc]>=datarang[mn+datasort[mn+nsp+1]])
                continue;

            // with the active variant both denominators are fixed to 1,
            // so this check never triggers
            if (TMath::Min(rrd,rld)<=1.0e-5)
                continue;

            const Double_t crit=(rln/rld)+(rrn/rrd);

            if (crit>=critvar) continue;

            nbestvar=nsp;
            critvar=crit;
        }

        // no improvement by this variable
        if (critvar>=critmin) continue;

        msplit=mvar;
        nbest=nbestvar;
        critmin=critvar;
    }

    // NOTE(review): if no trial found a valid cut, critmin is still 1.0e40
    // and the difference does not fit into a Float_t; the node is then
    // rejected by the check below.
    decsplit=crit0-critmin;

    //return critmin>1.0e20 ? 1 : 0;
    return decsplit<0 ? 1 : 0;
}

void MRanTree::MoveData(TArrayI &datasort,Int_t ndstart, Int_t ndend,
                        TArrayI &idmove,TArrayI &ncase,Int_t msplit,
                        Int_t nbest,Int_t &ndendl)
{
    // Heart of the tree construction: based on the best split (variable
    // msplit, cut after position nbest) the part of datasort belonging to
    // the current node (positions ndstart..ndend) is rearranged for every
    // variable, so that the events of the left child come first, followed
    // by the events of the right child. The relative order inside each
    // child is kept.
    //
    //  idmove   filled for the events of the node: 1 = goes left, 0 = right
    //  ncase    receives, for ndstart..ndend, the events in the new order
    //           of the split variable
    //  ndendl   receives the last position of the left child (= nbest)

    const Int_t nvars   = fGiniDec.GetSize();
    const Int_t nevents = ncase.GetSize();

    TArrayI buffer(nevents); buffer.Reset(0);

    // flag the events going left, following the order of the split variable
    const Int_t splitoff = msplit*nevents;
    for (Int_t pos=ndstart; pos<=ndend; pos++)
        idmove[datasort[splitoff+pos]] = pos<=nbest ? 1 : 0;

    ndendl=nbest;

    // partition the slice of every variable
    for (Int_t var=0; var<nvars; var++)
    {
        const Int_t off = var*nevents;
        Int_t fill=ndstart;

        // first all events going left ...
        for (Int_t pos=ndstart; pos<=ndend; pos++)
        {
            if (idmove[datasort[off+pos]]==1)
                buffer[fill++]=datasort[off+pos];
        }

        // ... then all events going right
        for (Int_t pos=ndstart; pos<=ndend; pos++)
        {
            if (idmove[datasort[off+pos]]==0)
                buffer[fill++]=datasort[off+pos];
        }

        for (Int_t pos=ndstart; pos<=ndend; pos++)
            datasort[off+pos]=buffer[pos];
    }

    // event numbers of the left and right child node
    for (Int_t pos=ndstart; pos<=ndend; pos++)
        ncase[pos]=datasort[splitoff+pos];
}

void MRanTree::BuildTree(TArrayI &datasort,const TArrayI &datarang, const TArrayF &hadtrue,
                         const TArrayI &idclass, TArrayI &bestsplit, TArrayI &bestsplitnext,
                         TArrayF &tclasspop, float &tmean, float &tsquare, const TArrayF &winbag,
                         Int_t ninbag, const int nclass)
{
    // BuildTree consists of repeated calls to FindBestSplit and MoveData.
    // FindBestSplit finds the best split of the current node. MoveData moves
    // the data of the split node left and right, so that the data
    // corresponding to each child node is contiguous.
    //
    // Bookkeeping:
    // ncur is the number of the last node created so far.
    // nodestatus[k]= 1 if node k has been split,
    //              = 2 if the node exists but has not yet been split,
    //              =-1 if the node is terminal,
    //              = 0 if the node has never been created.
    // A node is made terminal if its size is not above the threshold fNdSize,
    // if (classification only) all its weight belongs to one class, or if
    // FindBestSplit reports that no split is possible. If the current node k
    // is split, its children are numbered ncur+1 (left) and ncur+2 (right),
    // ncur increases by 2 and the next node to be treated is k+1. When no
    // more nodes can be split, BuildTree finishes.
    //
    // On entry tclasspop, tmean and tsquare hold the values of the root node;
    // they are overwritten with the values of each node being treated.
    //
    // Results:
    //  fBestVar      split variable of a split node; for terminal nodes the
    //                code (class-nclass) of the class with the largest
    //                population
    //  fTreeMap1/2   left/right child of a split node
    //  fBestSplit    for terminal nodes mean[node]/(summed class population)
    //  fGiniDec      summed decsplit per split variable
    //  bestsplit, bestsplitnext
    //                event at the cut position and the following event
    //  fNumNodes, fNumEndNodes
    //                number of created nodes and of terminal nodes
    const Int_t mdim    = fGiniDec.GetSize();
    // fBestSplit holds 2*numdata+1 entries while the tree is grown
    const Int_t nrnodes = fBestSplit.GetSize();
    const Int_t numdata = (nrnodes-1)/2;

    TArrayI nodepop(nrnodes);     nodepop.Reset(0);    // number of events per node
    TArrayI nodestart(nrnodes);   nodestart.Reset(0);  // first position of the node in datasort
    TArrayI parent(nrnodes);      parent.Reset(0);     // filled below, not read in this function

    TArrayI ncase(numdata);       ncase.Reset(0);      // filled by MoveData
    TArrayI idmove(numdata);      idmove.Reset(0);     // work array of MoveData
    TArrayI iv(mdim);             iv.Reset(0);         // not used in this function

    // weighted class populations, stored as [class*nrnodes+node]
    TArrayF classpop(nrnodes*nclass);  classpop.Reset(0.);
    TArrayI nodestatus(nrnodes);       nodestatus.Reset(0);

    for (Int_t j=0;j<nclass;j++)
        classpop[j*nrnodes+0]=tclasspop[j];

    // per node: sum of hadtrue*weight and of hadtrue*hadtrue*weight
    TArrayF mean(nrnodes);   mean.Reset(0.);
    TArrayF square(nrnodes); square.Reset(0.);

    mean[0]=tmean;
    square[0]=tsquare;


    // the root node (number 0) holds the whole bootstrap sample
    Int_t ncur=0;
    nodepop[0]=ninbag;
    nodestatus[0]=2;

    // start main loop
    for (Int_t kbuild=0; kbuild<nrnodes; kbuild++)
    {
          // all existing nodes have been treated
          if (kbuild>ncur) break;
          if (nodestatus[kbuild]!=2) continue;

          // initialize for next call to FindBestSplit

          const Int_t ndstart=nodestart[kbuild];
          const Int_t ndend=ndstart+nodepop[kbuild]-1;

          for (Int_t j=0;j<nclass;j++)
              tclasspop[j]=classpop[j*nrnodes+kbuild];

          tmean=mean[kbuild];
          tsquare=square[kbuild];

          // msplit and nbest are only meaningful if FindBestSplit returns 0
          Int_t msplit, nbest;
          Float_t decsplit=0;

          // a non-zero return value means: no split found, node is terminal
          if ((*this.*FindBestSplit)(datasort,datarang,hadtrue,idclass,ndstart,
                                     ndend, tclasspop,tmean, tsquare,msplit,decsplit,
                                     nbest,winbag,nclass))
          {
              nodestatus[kbuild]=-1;
              continue;
          }

          fBestVar[kbuild]=msplit;
          fGiniDec[msplit]+=decsplit;

          // remember the two events enclosing the cut (before MoveData
          // rearranges datasort)
          bestsplit[kbuild]=datasort[msplit*numdata+nbest];
          bestsplitnext[kbuild]=datasort[msplit*numdata+nbest+1];

          Int_t ndendl;
          MoveData(datasort,ndstart,ndend,idmove,ncase,
                   msplit,nbest,ndendl);

          // leftnode no.= ncur+1, rightnode no. = ncur+2.
          nodepop[ncur+1]=ndendl-ndstart+1;
          nodepop[ncur+2]=ndend-ndendl;
          nodestart[ncur+1]=ndstart;
          nodestart[ncur+2]=ndendl+1;

          // find class populations (and weighted sums) in both nodes
          for (Int_t n=ndstart;n<=ndendl;n++)
          {
              const Int_t &nc=ncase[n];
              const int j=idclass[nc];

              mean[ncur+1]+=hadtrue[nc]*winbag[nc];
              square[ncur+1]+=hadtrue[nc]*hadtrue[nc]*winbag[nc];

              classpop[j*nrnodes+ncur+1]+=winbag[nc];
          }

          for (Int_t n=ndendl+1;n<=ndend;n++)
          {
              const Int_t &nc=ncase[n];
              const int j=idclass[nc];

              mean[ncur+2]  +=hadtrue[nc]*winbag[nc];
              square[ncur+2]+=hadtrue[nc]*hadtrue[nc]*winbag[nc];

              classpop[j*nrnodes+ncur+2]+=winbag[nc];
          }

          // check on nodestatus: children not larger than fNdSize are terminal

          nodestatus[ncur+1]=2;
          nodestatus[ncur+2]=2;
          if (nodepop[ncur+1]<=fNdSize) nodestatus[ncur+1]=-1;
          if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1;


          // summed weight of both children
          Double_t popt1=0;
          Double_t popt2=0;
          for (Int_t j=0;j<nclass;j++)
          {
              popt1+=classpop[j*nrnodes+ncur+1];
              popt2+=classpop[j*nrnodes+ncur+2];
          }

          if(fClassify)
          {
              // check if only members of one class in node
              for (Int_t j=0;j<nclass;j++)
              {
                  if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
                  if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
              }
          }

          // link the children to the node just split
          fTreeMap1[kbuild]=ncur+1;
          fTreeMap2[kbuild]=ncur+2;
          parent[ncur+1]=kbuild;
          parent[ncur+2]=kbuild;
          nodestatus[kbuild]=1;
          ncur+=2;
          if (ncur>=nrnodes) break;
    }

    // determine number of nodes: nodes never created still have status 0;
    // nodes left unsplit when the loop ended become terminal
    fNumNodes=nrnodes;
    for (Int_t k=nrnodes-1;k>=0;k--)
    {
        if (nodestatus[k]==0) fNumNodes-=1;
        if (nodestatus[k]==2) nodestatus[k]=-1;
    }

    // final treatment of the terminal nodes
    fNumEndNodes=0;
    for (Int_t kn=0;kn<fNumNodes;kn++)
        if(nodestatus[kn]==-1)
        {
            fNumEndNodes++;

            // find the class with the largest population; fBestVar[kn] is
            // only overwritten if some class has a population above zero
            Double_t pp=0;
            for (Int_t j=0;j<nclass;j++)
            {
                if(classpop[j*nrnodes+kn]>pp)
                {
                    // class + status of node kn coded into fBestVar[kn]
                    fBestVar[kn]=j-nclass;
                    pp=classpop[j*nrnodes+kn];
                }
            }

                // summed weight of the node
                float sum=0;
                for(int i=0;i<nclass;i++) sum+=classpop[i*nrnodes+kn];

                // value assigned to the terminal node: weighted mean of hadtrue
                fBestSplit[kn]=mean[kn]/sum;
        }
}

Double_t MRanTree::TreeHad(const TVector &event)
{
    // Passes the event down the tree and returns the value stored for the
    // node at which the descent stops.
    //
    // To save storage space, node status and node class are coded into
    // fBestVar:
    //  - fBestVar[node]>=0: split node, the value is the index of the
    //    split variable and fBestSplit[node] the cut value
    //  - fBestVar[node]<0 : terminal node (BuildTree stores class-nclass,
    //    the class itself is not used here), fBestSplit[node] is the
    //    value assigned to the node
    //
    // The descent takes at most fNumNodes steps.

    Int_t node  =0;
    Int_t nsteps=0;

    while (nsteps<fNumNodes && fBestVar[node]>=0)
    {
        const Int_t var=fBestVar[node];

        // values up to the cut go left, larger ones go right
        node = event(var)<=fBestSplit[node] ? fTreeMap1[node] : fTreeMap2[node];
        nsteps++;
    }

    return fBestSplit[node];
}

Double_t MRanTree::TreeHad(const TMatrixRow &event)
{
    // Same as TreeHad(const TVector&), the event being given as a matrix
    // row: passes the event down the tree and returns the value stored for
    // the node at which the descent stops.
    //
    // To save storage space, node status and node class are coded into
    // fBestVar:
    //  - fBestVar[node]>=0: split node, the value is the index of the
    //    split variable and fBestSplit[node] the cut value
    //  - fBestVar[node]<0 : terminal node (BuildTree stores class-nclass,
    //    the class itself is not used here), fBestSplit[node] is the
    //    value assigned to the node
    //
    // The descent takes at most fNumNodes steps.

    Int_t node  =0;
    Int_t nsteps=0;

    while (nsteps<fNumNodes && fBestVar[node]>=0)
    {
        const Int_t var=fBestVar[node];

        // values up to the cut go left, larger ones go right
        node = event(var)<=fBestSplit[node] ? fTreeMap1[node] : fTreeMap2[node];
        nsteps++;
    }

    return fBestSplit[node];
}

Double_t MRanTree::TreeHad(const TMatrix &m, Int_t ievt)
{
    // Evaluates the tree for row ievt of the matrix m: the row is wrapped
    // into the row class selected by the ROOT version and handed to the
    // matching TreeHad overload.
#if ROOT_VERSION_CODE >= ROOT_VERSION(4,00,8)
    return TreeHad(TMatrixFRow_const(m, ievt));
#else
    return TreeHad(TMatrixRow(m, ievt));
#endif
}

Bool_t MRanTree::AsciiWrite(ostream &out) const
{
    // Writes the tree as text to 'out'.
    //
    // The first line contains the number of nodes. It is followed by one
    // line per node with the columns
    //    node index, GetNodeStatus, GetTreeMap1, GetTreeMap2, GetBestVar,
    //    GetBestSplit (printed with "%f"), GetNodeClass
    // and by an empty line closing the block.
    //
    // Returns kTRUE once the loop has run over all fNumNodes nodes; the
    // state of the stream is not checked.

    TString str;
    Int_t k;

    out.width(5);out<<fNumNodes<<endl;

    for (k=0;k<fNumNodes;k++)
    {
        str=Form("%f",GetBestSplit(k));

        out.width(5);  out << k;
        out.width(5);  out << GetNodeStatus(k);
        out.width(5);  out << GetTreeMap1(k);
        out.width(5);  out << GetTreeMap2(k);
        out.width(5);  out << GetBestVar(k);
        out.width(15); out << str;
        // The line break belongs after the last column. It used to be
        // written before the node class, which pushed the class of node k
        // to the start of the line of node k+1. The sequence of
        // whitespace-separated values is unchanged.
        out.width(5);  out << GetNodeClass(k) << endl;
    }
    out<<endl;

    return k==fNumNodes;
}
