/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 3/2003 <mailto:hengsteb@alwa02.physik.uni-siegen.de>
!
!   Copyright: MAGIC Software Development, 2000-2003
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//                                                                         //
// MRanTree                                                                //
//                                                                         //
// ParameterContainer for Tree structure                                   //
//                                                                         //
//                                                                         //
/////////////////////////////////////////////////////////////////////////////
#include "MRanTree.h"

#include <iostream.h>

#include <TVector.h>
#include <TMatrix.h>
#include <TRandom.h>

#include "MDataArray.h"

#include "MLog.h"
#include "MLogManip.h"

// ROOT ClassImp macro invocation for MRanTree (presumably the implementation
// counterpart of a ClassDef in MRanTree.h -- confirm against the header).
ClassImp(MRanTree);

// --------------------------------------------------------------------------
//
// Default constructor.
//
// Starts with a terminal-node size threshold of 0, three random split
// trials per node and no rules (fData) attached. A NULL name or title
// selects the built-in default string.
//
MRanTree::MRanTree(const char *name, const char *title)
    : fNdSize(0), fNumTry(3), fData(NULL)
{
    if (name)
        fName = name;
    else
        fName = "MRanTree";

    if (title)
        fTitle = title;
    else
        fTitle = "Storage container for structure of a single tree";
}

void MRanTree::SetNdSize(Int_t n)
{
    // Threshold node size of terminal nodes: the training data is split
    // until a subset (= terminal node) contains only pure data or its
    // size is less than or equal to n.
    //
    // Values below 1 are raised to 1 (at least one event per node).

    fNdSize = n<1 ? 1 : n;
}

void MRanTree::SetNumTry(Int_t n)
{
    // Number of trials in the random split selection, i.e. how many
    // randomly drawn variables FindBestSplit examines per node.
    //
    // Values below 1 are raised to 1 (at least one variable to split in).

    if (n<1)
        n = 1;

    fNumTry = n;
}

void MRanTree::GrowTree(TMatrix &mhad,TMatrix &mgam,Int_t numdata, Int_t numdim,TArrayI &hadtrue,
                        TArrayI &datasort,TArrayI &datarang,TArrayF &tclasspop,TArrayI &jinbag,
                        TArrayF &winbag,TArrayF &weight)
{
    // Grow a single tree from the bootstrap sample and store its structure
    // in fBestVar, fTreeMap1, fTreeMap2 and fBestSplit. fGiniDec receives
    // the accumulated decrease of the Gini criterion per variable.
    //
    //  mhad, mgam : training matrices (row = event, column = variable).
    //               A case number c refers to row c of mhad if
    //               c < mhad.GetNrows(), otherwise to row c-nhad of mgam.
    //  numdata    : number of training cases
    //  numdim     : number of variables
    //  hadtrue    : class index (0 or 1) of each case
    //  datasort   : datasort[var*numdata+pos] = case number at sorted
    //               position pos of variable var (reordered while growing)
    //  datarang   : datarang[var*numdata+case] = rank of the case in var
    //  tclasspop  : weighted class populations of the root node
    //               (overwritten while the tree is built)
    //  jinbag     : jinbag[n]==1 marks case n as part of the bootstrap sample
    //  winbag     : weight of each case in the bootstrap sample
    //  weight     : not used in this function

    // arrays have to be initialized with generous size, so the total number
    // of nodes (nrnodes) is estimated for the worst case
    const Int_t nrnodes=2*numdata+1;

    // number of events in bootstrap sample
    Int_t ninbag=0;
    for (Int_t n=0;n<numdata;n++)
        if(jinbag[n]==1) ninbag++;

    // weighted class populations after split (work space for the split search)
    TArrayF wl(2); // left node
    TArrayF wc(2);
    TArrayF wr(2); // right node

    TArrayI bestsplit(nrnodes);
    TArrayI bestsplitnext(nrnodes);
    TArrayI nodepop(nrnodes);
    TArrayI parent(nrnodes);
    TArrayI nodestart(nrnodes);

    TArrayI ncase(numdata);
    TArrayI iv(numdim);
    TArrayI idmove(numdata);

    idmove.Reset();

    fBestVar.Set(nrnodes);
    fTreeMap1.Set(nrnodes);
    fTreeMap2.Set(nrnodes);
    fBestSplit.Set(nrnodes);

    // Set() keeps the old contents, so every array describing the tree has
    // to be cleared explicitly. This includes fBestVar: otherwise entries of
    // a previously grown tree survive when the container is reused.
    fBestVar.Reset();
    fTreeMap1.Reset();
    fTreeMap2.Reset();
    fBestSplit.Reset();

    fGiniDec.Set(numdim);
    fGiniDec.Reset();

    // tree growing
    BuildTree(datasort,datarang,hadtrue,numdim,numdata,bestsplit,
              bestsplitnext,nodepop,nodestart,tclasspop,nrnodes,
              idmove,ncase,parent,jinbag,iv,winbag,wr,wc,wl,ninbag);

    // post processing, determine cut (or split) values fBestSplit:
    // the cut is placed halfway between the value of the split case and the
    // value of the next case in the split variable. Only the fNumNodes nodes
    // which really exist are processed; everything behind them is dropped
    // by the resizing below anyway.
    const Int_t nhad=mhad.GetNrows();

    for(Int_t k=0;k<fNumNodes;k++)
    {
        // terminal nodes keep the value assigned by BuildTree
        if (GetNodeStatus(k)==-1)
            continue;

        const Int_t bsp =bestsplit[k];
        const Int_t bspn=bestsplitnext[k];
        const Int_t msp =fBestVar[k];

        fBestSplit[k]  = bsp<nhad  ? mhad(bsp,msp) :mgam(bsp-nhad,msp);
        fBestSplit[k] += bspn<nhad ? mhad(bspn,msp):mgam(bspn-nhad,msp);
        fBestSplit[k] /=2.;
    }

    // resizing arrays to save memory
    fBestVar.Set(fNumNodes);
    fTreeMap1.Set(fNumNodes);
    fTreeMap2.Set(fNumNodes);
    fBestSplit.Set(fNumNodes);
}

Int_t MRanTree::FindBestSplit(TArrayI &datasort,TArrayI &datarang,TArrayI &hadtrue,Int_t mdim,
                             Int_t numdata,Int_t ndstart,Int_t ndend,TArrayF &tclasspop,
                             Int_t &msplit,Float_t &decsplit,Int_t &nbest,TArrayI &ncase,
                             TArrayI &jinbag,TArrayI &iv,TArrayF &winbag,TArrayF &wr,
                             TArrayF &wc,TArrayF &wl,Int_t kbuild)
{
    // Search the best split of the node occupying the positions
    // ndstart..ndend of datasort, trying fNumTry randomly drawn variables.
    //
    // Output (only meaningful if 0 is returned):
    //  msplit   : index of the variable (e.g. Hillas par., zenith angle, ...)
    //             split on
    //  nbest    : position in datasort of the last case going to the left
    //             node; the case at position nbest+1 is the first one going
    //             to the right
    //  decsplit : decrease in impurity measured by the Gini index
    //
    // wl and wr are used as work space (weighted class populations left
    // and right of the tested split position).
    // ncase, jinbag, iv, wc and kbuild are not used here.
    //
    // Returns 0 if a split was found and 1 if the node cannot be split
    // (no admissible split position, or gRandom not available).

    // compute initial values of numerator and denominator of Gini-index,
    // Gini index= pno/pdo
    Float_t pno=0;
    Float_t pdo=0;

    for (Int_t j=0;j<2;j++)
    {
        pno+=tclasspop[j]*tclasspop[j];
        pdo+=tclasspop[j];
    }
    const Float_t crit0=pno/pdo;

    // random split selection, number of trials = fNumTry
    if(!gRandom)
    {
        *fLog << err << dbginf << "gRandom not initialized... aborting." << endl;

        // Must be 1 ("no split"): returning 0 would make the caller use
        // the unset msplit and nbest.
        return 1;
    }

    // start main loop through variables to find best split,
    // (Gini-index as criterion crit). -1.0e20 marks "nothing found yet".
    Float_t critmax=-1.0e20;

    for(Int_t mt=0;mt<fNumTry;mt++)
    {
        Int_t mvar=Int_t(mdim*gRandom->Rndm());

        // Safety net in case the generator returns exactly 1, which would
        // select the non-existing variable mdim.
        if (mdim>0 && mvar>=mdim)
            mvar=mdim-1;

        // Gini index = rrn/rrd+rln/rld
        // Start with everything in the right node and nothing in the left.
        Float_t rrn=pno;
        Float_t rrd=pdo;
        Float_t rln=0;
        Float_t rld=0;
        wl.Reset();

        for (Int_t j=0;j<2;j++)
            wr[j]=tclasspop[j];

        Int_t   nbestvar=0;
        Float_t critvar=-1.0e20;

        for(Int_t nsp=ndstart;nsp<=ndend-1;nsp++)
        {
            // move the case at position nsp from the right to the left node
            const Int_t nc=datasort[mvar*numdata+nsp];

            const Float_t u=winbag[nc];
            const Int_t   k=hadtrue[nc];

            rln=rln+u*(2*wl[k]+u);
            rrn=rrn+u*(-2*wr[k]+u);
            rld=rld+u;
            rrd=rrd-u;

            wl[k]=wl[k]+u;
            wr[k]=wr[k]-u;

            // a cut is only possible between cases of different rank,
            // i.e. with different values of the variable
            if (datarang[mvar*numdata+nc]>=datarang[mvar*numdata+datasort[mvar*numdata+nsp+1]])
                continue;

            // both sides need a non-vanishing weight
            if (TMath::Min(rrd,rld)<=1.0e-5)
                continue;

            const Float_t crit=(rln/rld)+(rrn/rrd);
            if (crit>critvar)
            {
                nbestvar=nsp;
                critvar=crit;
            }
        }

        if (critvar>critmax)
        {
            msplit=mvar;
            nbest=nbestvar;
            critmax=critvar;
        }
    }

    decsplit=critmax-crit0;

    // critmax still at its start value: no trial yielded a valid split
    return critmax<-1.0e10 ? 1 : 0;
}

void MRanTree::MoveData(TArrayI &datasort,Int_t mdim,Int_t numdata,Int_t ndstart,
                        Int_t ndend,TArrayI &idmove,TArrayI &ncase,Int_t msplit,
                        Int_t nbest,Int_t &ndendl)
{
    // Heart of the tree construction: based on the best split (variable
    // msplit, position nbest) the part ndstart..ndend of datasort which
    // belongs to the current node is rearranged for every variable such
    // that the cases of the left child come first and those of the right
    // child follow. The sorted order within each child is kept.
    //
    // idmove is set to 1 (left) or 0 (right) for every case of the node,
    // ncase receives the case numbers of the node in the new order and
    // ndendl the last position belonging to the left child.

    const Int_t splitrow=msplit*numdata;

    // mark the cases going to the left ...
    for (Int_t pos=ndstart;pos<=nbest;pos++)
        idmove[datasort[splitrow+pos]]=1;

    // ... and those going to the right
    for (Int_t pos=nbest+1;pos<=ndend;pos++)
        idmove[datasort[splitrow+pos]]=0;

    ndendl=nbest;

    // rearrange the case numbers of all variables, left cases first
    TArrayI buffer(numdata);

    for (Int_t var=0;var<mdim;var++)
    {
        const Int_t row=var*numdata;

        Int_t fill=ndstart-1;

        // first pass collects the left cases (flag 1), second the right (flag 0)
        for (Int_t flag=1;flag>=0;flag--)
        {
            for (Int_t pos=ndstart;pos<=ndend;pos++)
            {
                const Int_t id=datasort[row+pos];
                if (idmove[id]!=flag)
                    continue;

                fill++;
                buffer[fill]=id;
            }
        }

        for (Int_t pos=ndstart;pos<=ndend;pos++)
            datasort[row+pos]=buffer[pos];
    }

    // case numbers of the node, now ordered left child / right child
    for (Int_t pos=ndstart;pos<=ndend;pos++)
        ncase[pos]=datasort[splitrow+pos];
}

void MRanTree::BuildTree(TArrayI &datasort,TArrayI &datarang,TArrayI &hadtrue,Int_t mdim,
                         Int_t numdata,TArrayI &bestsplit,TArrayI &bestsplitnext,
                         TArrayI &nodepop,TArrayI &nodestart,TArrayF &tclasspop,
                         Int_t nrnodes,TArrayI &idmove,TArrayI &ncase,TArrayI &parent,
                         TArrayI &jinbag,TArrayI &iv,TArrayF &winbag,TArrayF &wr,TArrayF &wc,
                         TArrayF &wl,Int_t ninbag)
{
    // BuildTree consists of repeated calls to two functions, FindBestSplit and MoveData.
    // FindBestSplit does just that--it finds the best split of the current node and
    // returns 1 if the node cannot be split.
    // MoveData moves the data in the split node right and left so that the data
    // corresponding to each child node is contiguous.
    //
    // BuildTree bookkeeping:
    // ncur is the total number of nodes to date.  nodestatus(k)=1 if the kth node has been split.
    // nodestatus(k)=2 if the node exists but has not yet been split, and =-1 if the node is
    // terminal.  A node is terminal if its size is below a threshold value, or if it is all
    // one class, or if all the data-values are equal.  If the current node k is split, then its
    // children are numbered ncur+1 (left), and ncur+2(right), ncur increases to ncur+2 and
    // the next node to be split is numbered k+1.  When no more nodes can be split, BuildTree
    // returns.
    //
    // Results:
    //  fBestVar, fTreeMap1, fTreeMap2 : split variable and left/right child of split nodes
    //  fGiniDec                       : summed decrease of the split criterion per variable
    //  bestsplit, bestsplitnext       : case numbers at the split position and behind it
    //  fNumNodes, fNumEndNodes        : number of nodes / of terminal nodes
    //  fBestVar, fBestSplit of terminal nodes : see the end of this function
    // tclasspop is overwritten with the class populations of each processed node.

    Int_t msplit,nbest,ndendl,nc,jstat,ndend,ndstart;
    Float_t decsplit=0;
    Float_t popt1,popt2,pp;

    // classpop[j*nrnodes+k] = weighted population of class j in node k
    TArrayF classpop;
    TArrayI nodestatus;

    nodestatus.Set(nrnodes);
    classpop.Set(2*nrnodes);

    // status 0 = node does not exist (yet)
    nodestatus.Reset();
    nodestart.Reset();
    nodepop.Reset();
    classpop.Reset();


    // the root node (0) holds the whole bootstrap sample
    for (Int_t j=0;j<2;j++)
        classpop[j*nrnodes+0]=tclasspop[j];

    Int_t ncur=0;
    nodestart[0]=0;
    nodepop[0]=ninbag;
    nodestatus[0]=2;

    // start main loop
    for (Int_t kbuild=0;kbuild<nrnodes;kbuild++)
    {
          // all existing nodes have been processed
          if (kbuild>ncur) break;
          // only nodes waiting for a split (status 2) are treated
          if (nodestatus[kbuild]!=2) continue;

          // initialize for next call to FindBestSplit

          // the node occupies the positions ndstart..ndend of datasort
          ndstart=nodestart[kbuild];
          ndend=ndstart+nodepop[kbuild]-1;
          for (Int_t j=0;j<2;j++)
            tclasspop[j]=classpop[j*nrnodes+kbuild];

          jstat=FindBestSplit(datasort,datarang,hadtrue,mdim,numdata,
                              ndstart,ndend,tclasspop,msplit,decsplit,
                              nbest,ncase,jinbag,iv,winbag,wr,wc,wl,
                              kbuild);

          if(jstat==1) {
              // no split possible: the node becomes terminal
              nodestatus[kbuild]=-1;
              continue;
          }else{
              fBestVar[kbuild]=msplit;
              fGiniDec[msplit]+=decsplit;

              // remember the two cases enclosing the cut (has to be done
              // before MoveData reorders datasort)
              bestsplit[kbuild]=datasort[msplit*numdata+nbest];
              bestsplitnext[kbuild]=datasort[msplit*numdata+nbest+1];
          }

          MoveData(datasort,mdim,numdata,ndstart,ndend,idmove,ncase,
                   msplit,nbest,ndendl);

          // leftnode no.= ncur+1, rightnode no. = ncur+2.
          // left child: positions ndstart..ndendl, right child: ndendl+1..ndend

          nodepop[ncur+1]=ndendl-ndstart+1;
          nodepop[ncur+2]=ndend-ndendl;
          nodestart[ncur+1]=ndstart;
          nodestart[ncur+2]=ndendl+1;

          // find class populations in both nodes

          for (Int_t n=ndstart;n<=ndendl;n++)
          {
              nc=ncase[n];
              Int_t j=hadtrue[nc];
              classpop[j*nrnodes+ncur+1]+=winbag[nc];
          }

          for (Int_t n=ndendl+1;n<=ndend;n++)
          {
              nc=ncase[n];
              Int_t j=hadtrue[nc];
              classpop[j*nrnodes+ncur+2]+=winbag[nc];
          }

          // check on nodestatus

          // a child is terminal if it is not larger than fNdSize ...
          nodestatus[ncur+1]=2;
          nodestatus[ncur+2]=2;
          if (nodepop[ncur+1]<=fNdSize) nodestatus[ncur+1]=-1;
          if (nodepop[ncur+2]<=fNdSize) nodestatus[ncur+2]=-1;
          popt1=0;
          popt2=0;
          for (Int_t j=0;j<2;j++)
          {
            popt1+=classpop[j*nrnodes+ncur+1];
            popt2+=classpop[j*nrnodes+ncur+2];
          }

          // ... or if one class carries its whole weight (pure node)
          for (Int_t j=0;j<2;j++)
          {
            if (classpop[j*nrnodes+ncur+1]==popt1) nodestatus[ncur+1]=-1;
            if (classpop[j*nrnodes+ncur+2]==popt2) nodestatus[ncur+2]=-1;
          }

          // link the children to the node just split
          fTreeMap1[kbuild]=ncur+1;
          fTreeMap2[kbuild]=ncur+2;
          parent[ncur+1]=kbuild;
          parent[ncur+2]=kbuild;
          nodestatus[kbuild]=1;
          ncur+=2;
          if (ncur>=nrnodes) break;
    }

    // determine number of nodes: slots which were never used still have
    // status 0, nodes left unsplit (status 2) are declared terminal
    fNumNodes=nrnodes;
    for (Int_t k=nrnodes-1;k>=0;k--)
    {
        if (nodestatus[k]==0) fNumNodes-=1;
        if (nodestatus[k]==2) nodestatus[k]=-1;
    }

    // terminal nodes: fBestVar becomes negative (-2 if class 0 has the
    // larger weight, -1 if class 1 has) and fBestSplit holds the weight
    // fraction of class 1, which TreeHad returns as hadronness
    fNumEndNodes=0;
    for (Int_t kn=0;kn<fNumNodes;kn++)
        if(nodestatus[kn]==-1)
        {
            fNumEndNodes++;
            pp=0;
            for (Int_t j=0;j<2;j++)
            {
                if(classpop[j*nrnodes+kn]>pp)
                {
                    // class + status of node kn coded into fBestVar[kn]
                    fBestVar[kn]=j-2;
                    pp=classpop[j*nrnodes+kn];
                }
            }
            fBestSplit[kn] =classpop[1*nrnodes+kn];
            fBestSplit[kn]/=(classpop[0*nrnodes+kn]+classpop[1*nrnodes+kn]);
        }

    return;
}

void MRanTree::SetRules(MDataArray *rules)
{
    // Store the pointer to the data array from which TreeHad() (the
    // version without argument) reads the values of the event. Only the
    // pointer is stored here; nothing is copied.

    fData=rules;
}

Double_t MRanTree::TreeHad(TVector &event)
{
    // Pass the event (one entry per variable) through the tree and return
    // the hadronness of the terminal node in which it ends.
    //
    // To optimize on storage space node status and node class are coded
    // into fBestVar:
    //  - fBestVar[node] <  0 : terminal node, fBestSplit[node] is the
    //                          hadronness assigned to the node
    //  - fBestVar[node] >= 0 : variable split on, fBestSplit[node] is the
    //                          cut value (<= cut: left child fTreeMap1,
    //                          otherwise right child fTreeMap2)
    //
    // At most fNumNodes steps are performed.

    Int_t node=0;
    Int_t step=0;

    while (step<fNumNodes && fBestVar[node]>=0)
    {
        const Int_t var=fBestVar[node];

        node = event(var)<=fBestSplit[node] ? fTreeMap1[node] : fTreeMap2[node];

        step++;
    }

    return fBestSplit[node];
}

Double_t MRanTree::TreeHad()
{
    // Evaluate the tree for the event described by the data array set via
    // SetRules: entry i of the array is used as value of variable i.
    //
    // Returns the hadronness of the terminal node reached (see
    // TreeHad(TVector&) for the traversal), or -1 if no data array has
    // been set.

    if (!fData)
    {
        *fLog << err << dbginf << "No rules set (SetRules)... aborting." << endl;
        return -1;
    }

    // number of variables; determined once instead of in every iteration
    const Int_t ncols = fData->GetNumEntries();

    TVector event(ncols);
    for (Int_t i=0; i<ncols; i++)
        event(i) = (*fData)(i);

    // the traversal itself is shared with the TVector version
    return TreeHad(event);
}

Bool_t MRanTree::AsciiWrite(ostream &out) const
{
    // Write the tree to the stream: first the number of nodes, then for
    // each node its index, status, left and right child, split variable
    // and the split value (printed with "%f"), followed by the node class.
    //
    // NOTE(review): the node class is written after the line break, so it
    // shows up at the beginning of the following line -- confirm that
    // this is what readers of the file expect.
    //
    // Returns kTRUE if all fNumNodes nodes have been written.

    out.width(5);
    out << fNumNodes << endl;

    Int_t node=0;
    while (node<fNumNodes)
    {
        const TString split=Form("%f",GetBestSplit(node));

        out.width(5);
        out << node;

        out.width(5);
        out << GetNodeStatus(node);

        out.width(5);
        out << GetTreeMap1(node);

        out.width(5);
        out << GetTreeMap2(node);

        out.width(5);
        out << GetBestVar(node);

        out.width(15);
        out << split << endl;

        out.width(5);
        out << GetNodeClass(node);

        node++;
    }
    out << endl;

    return node==fNumNodes;
}
