/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
!   Author(s): Thomas Bretz 8/2005 <mailto:tbretz@astro.uni-wuerzburg.de>
!
!   Copyright: MAGIC Software Development, 2000-2005
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//
//  MRanForestCalc
//
//
////////////////////////////////////////////////////////////////////////////
#include "MRanForestCalc.h"

#include <TVector.h>

#include "MHMatrix.h"

#include "MLog.h"
#include "MLogManip.h"

#include "MData.h"
#include "MDataArray.h"

#include "MRanForest.h"
#include "MParameters.h"

#include "MParList.h"
#include "MTaskList.h"
#include "MEvtLoop.h"
#include "MRanForestGrow.h"
#include "MFillH.h"

ClassImp(MRanForestCalc);

using namespace std;

// Name and title the constructor falls back to if none are given
const TString MRanForestCalc::gsDefName    = "MRanForestCalc";
const TString MRanForestCalc::gsDefTitle   = "RF for energy estimation";

// NOTE(review): presumably the default name of the output container
// (fNameOutput) - confirm against the class declaration
const TString MRanForestCalc::gsNameOutput = "RanForestOut";

// --------------------------------------------------------------------------
//
// Constructor. Sets the name and title of the task (gsDefName and
// gsDefTitle are used if no name/title is given), the default training
// parameters and the default name of the output container.
//
MRanForestCalc::MRanForestCalc(const char *name, const char *title)
    : fDebug(kFALSE), fData(0), fRFOut(0),
    // FIXME: Hard coded defaults of the training parameters.
    // fNumTry=0 means: in MRanForest the estimated best value will be calculated
    fNumTrees(100), fNumTry(0), fNdSize(1), fNumObsoleteVariables(1),
    fTestMatrix(0), fEstimationMode(kMean)
{
    fName  = name  ? name  : gsDefName.Data();
    fTitle = title ? title : gsDefTitle.Data();

    // Without a default, PreProcess would ask the parameter list for an
    // output container with an empty name unless SetNameOutput was called
    fNameOutput = gsNameOutput;
}

// --------------------------------------------------------------------------
//
// Destructor. Calls Delete() on the list of forests (fEForests).
//
MRanForestCalc::~MRanForestCalc()
{
    fEForests.Delete();
}

// --------------------------------------------------------------------------
//
// Train random forest(s) with the events of matrixtrain and write them,
// followed by the rules used for the training (key "rules"), into the
// file fFileName. The file is recreated.
//
// The last column of matrixtrain is used as the quantity to be estimated
// (true energy). The last fNumObsoleteVariables columns are not used as
// input for the training.
//
//  ver=0: One forest is trained for each bin of grid. The training target
//         is 1 if the energy is inside the bin (lower edge excluded, upper
//         edge included) and 0 otherwise. Bins without events are skipped.
//         Each forest gets the mean of the log10 of its bin edges as
//         name and user value.
//  ver=1: A single forest is trained on the energy column in classification
//         mode (SetClassify(1)), the grid is handed to the forest.
//  ver=2: A single forest is trained on the energy column with
//         SetClassify(0).
//
// Returns kTRUE on success, kFALSE in case of an error.
//
Int_t MRanForestCalc::Train(const MHMatrix &matrixtrain, const TArrayD &grid, Int_t ver)
{
    gLog.Separator("MRanForestCalc - Train");

    if (!matrixtrain.GetColumns())
    {
        *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
        return kFALSE;
    }

    const Int_t ncols = matrixtrain.GetM().GetNcols();
    const Int_t nrows = matrixtrain.GetM().GetNrows();
    if (ncols<=0 || nrows <=0)
    {
        *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
        return kFALSE;
    }

    // Number of obsolete columns. It must fit the matrix, otherwise
    // the resized matrix below would get a negative number of columns
    const Int_t nobs = fNumObsoleteVariables;
    if (nobs<0 || nobs>ncols)
    {
        *fLog << err << "ERROR - No. of obsolete variables (" << nobs << ") does not fit no. of columns (" << ncols << ")... abort." << endl;
        return kFALSE;
    }

    // Number of forests to be trained. Checked before the output file
    // is recreated to not truncate it without writing any forest
    const Int_t nbins = ver>0 ? 1 : grid.GetSize()-1;
    if (nbins<=0)
    {
        *fLog << err << "ERROR - Energy grid must contain at least two bin edges... abort." << endl;
        return kFALSE;
    }

    // File to which the forests and the rules are written
    TFile fileRF(fFileName, "recreate");
    if (!fileRF.IsOpen())
    {
        *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
        return kFALSE;
    }

    const MDataArray &dcol = *matrixtrain.GetColumns();

    // rules (= combination of image par) to be used for energy estimation
    MDataArray usedrules;
    for (Int_t i=0; i<ncols; i++)
        if (i<ncols-nobs)  // the last nobs columns are not used for the training
            usedrules.AddEntry(dcol[i].GetRule());
        else
            *fLog << inf << "Skipping " << dcol[i].GetRule() << " for training" << endl;

    // The rules of the training matrix: used rules plus training target
    MDataArray rules(usedrules);
    rules.AddEntry(ver<2?"Classification":dcol[ncols-1].GetRule());

    // prepare matrix for current energy bin
    TMatrix mat(matrixtrain.GetM());

    // obsolete columns are removed, the column ncols-nobs
    // is (re)used for the training target
    mat.ResizeTo(nrows, ncols-nobs+1);

    // NOTE(review): the log output is suppressed if fDebug is set (and
    // only until the first forest is trained) - this looks inverted, confirm
    if (fDebug)
        gLog.SetNullOutput(kTRUE);

    for (Int_t ie=0; ie<nbins; ie++)
    {
        switch (ver)
        {
        case 0: // Replace Energy Grid by classification
            {
                Int_t irows=0;
                for (Int_t j=0; j<nrows; j++)
                {
                    const Double_t energy = matrixtrain.GetM()(j,ncols-1);
                    const Bool_t   inside = energy>grid[ie] && energy<=grid[ie+1];

                    mat(j, ncols-nobs) = inside ? 1 : 0;

                    if (inside)
                        irows++;
                }
                if (irows==0)
                    *fLog << warn << "WARNING - Skipping";
                else
                    *fLog << inf << "Training RF for";

                *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ") " << irows << "/" << nrows << endl;

                if (irows==0)
                    continue;
            }
            break;

        case 1: // Use Energy as classifier
        case 2:
            for (Int_t j=0; j<nrows; j++)
                mat(j, ncols-nobs) = matrixtrain.GetM()(j,ncols-1);
            break;
        }

        MHMatrix matrix(mat, &rules, "MatrixTrain");

        MParList plist;
        MTaskList tlist;
        plist.AddToList(&tlist);
        plist.AddToList(&matrix);

        MRanForest rf;
        rf.SetNumTrees(fNumTrees);
        rf.SetNumTry(fNumTry);
        rf.SetNdSize(fNdSize);
        rf.SetClassify(ver<2 ? 1 : 0);
        if (ver==1)
            rf.SetGrid(grid);

        plist.AddToList(&rf);

        MRanForestGrow rfgrow;
        tlist.AddToList(&rfgrow);

        MFillH fillh("MHRanForestGini");
        tlist.AddToList(&fillh);

        MEvtLoop evtloop;
        evtloop.SetParList(&plist);
        evtloop.SetDisplay(fDisplay);
        evtloop.SetLogStream(fLog);

        if (!evtloop.Eventloop())
        {
            // Do not leave the global log stream muted on failure
            if (fDebug)
                gLog.SetNullOutput(kFALSE);
            return kFALSE;
        }

        if (fDebug)
            gLog.SetNullOutput(kFALSE);

        if (ver==0)
        {
            // Calculate bin center
            const Double_t E = (TMath::Log10(grid[ie])+TMath::Log10(grid[ie+1]))/2;

            // save whole forest
            rf.SetUserVal(E);
            rf.SetName(Form("%.10f", E));
        }

        rf.Write();
    }

    // If all bins have been skipped the log output is still switched off
    if (fDebug)
        gLog.SetNullOutput(kFALSE);

    // save rules
    usedrules.Write("rules");

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Read all forests and the rules from the file fFileName.
//
// The list of forests (fEForests) is emptied first. Then all keys of the
// file are scanned, keys for which no MRanForest can be retrieved are
// skipped. The user value of each forest is set to the numerical value
// of its key name (atof). Finally the rules (key "rules") are read into
// fData, which must have been set before.
//
// The argument plist is not used.
//
// Returns kFALSE if the file could not be opened, if it does not contain
// any forest or if the rules could not be read, kTRUE otherwise.
//
Int_t MRanForestCalc::ReadForests(MParList &plist)
{
    TFile fileRF(fFileName, "read");
    if (!fileRF.IsOpen())
    {
        *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
        return kFALSE;
    }

    fEForests.Delete();

    TIter Next(fileRF.GetListOfKeys());
    TObject *o=0;
    while ((o=Next()))
    {
        MRanForest *forest=0;
        fileRF.GetObject(o->GetName(), forest);
        if (!forest)
            continue;

        forest->SetUserVal(atof(o->GetName()));

        fEForests.Add(forest);
    }

    // Without any forest Process would calculate its result from
    // empty sums (0/0)
    if (fEForests.GetEntries()==0)
    {
        *fLog << err << "ERROR - No RF found in file " << fFileName << endl;
        return kFALSE;
    }

    // Maybe fEForests[0].fRules could be used instead?

    if (fData->Read("rules")<=0)
    {
        *fLog << err << "ERROR - Reading 'rules' from file " << fFileName << endl;
        return kFALSE;
    }

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Find or create the output container (MParameterD with the name
// fNameOutput) and the MDataArray in the parameter list and read the
// forests and rules from the file fFileName.
//
// If no test matrix is set, the MDataArray is printed and preprocessed
// with the parameter list.
//
// Returns kTRUE on success, kFALSE otherwise.
//
Int_t MRanForestCalc::PreProcess(MParList *plist)
{
    // Container receiving the result of Process
    fRFOut = (MParameterD*)plist->FindCreateObj("MParameterD", fNameOutput);
    if (fRFOut==0)
        return kFALSE;

    // Container receiving the rules read from the file
    fData = (MDataArray*)plist->FindCreateObj("MDataArray");
    if (fData==0)
        return kFALSE;

    const Int_t rc = ReadForests(*plist);
    if (!rc)
    {
        *fLog << err << "Reading RFs failed... aborting." << endl;
        return kFALSE;
    }

    *fLog << inf << "RF read from " << fFileName << endl;

    // With a test matrix the events are taken from the matrix
    // in Process and the rules need no preprocessing
    if (!fTestMatrix)
    {
        fData->Print();

        const Bool_t ok = fData->PreProcess(plist);
        if (!ok)
        {
            *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
            return kFALSE;
        }
    }

    return kTRUE;
}

#include <TGraph.h>
#include <TF1.h>
// --------------------------------------------------------------------------
//
// Evaluate the forest(s) for the current event and store the result in
// the output container.
//
// The event is taken from the test matrix if one is set, otherwise from
// the data array (fData).
//
// If exactly one forest is available the output of its CalcHadroness is
// stored directly. Otherwise CalcHadroness (h) is evaluated for all
// forests and combined with their user values (e) depending on the
// estimation mode:
//
//   kMean:    10^(sum(e*h)/sum(h))
//   kMaximum: 10^e of the forest with the largest h
//   kFit:     10^mean of a gaussian fitted to the points (e, h)
//
// Always returns kTRUE.
//
Int_t MRanForestCalc::Process()
{
    TVector event;
    if (fTestMatrix)
        *fTestMatrix >> event;
    else
        *fData >> event;

    // --------------- Single forest -------------------
    if (fEForests.GetEntries()==1)
    {
        MRanForest *rf = (MRanForest*)fEForests[0];
        fRFOut->SetVal(rf->CalcHadroness(event));
        fRFOut->SetReadyToSave();

        return kTRUE;
    }

    // --------------- Several forests -------------------
    static TF1 f1("f1", "gaus");

    // The points (e, h) are only needed if a fit is requested
    const Bool_t dofit = fEstimationMode==kFit;

    Double_t sume = 0;
    Double_t sumh = 0;
    Double_t maxh = 0;
    Double_t maxe = 0;

    TIter Next(&fEForests);
    MRanForest *rf = 0;

    TGraph g;
    while ((rf=(MRanForest*)Next()))
    {
        const Double_t h = rf->CalcHadroness(event);
        const Double_t e = rf->GetUserVal();

        if (dofit)
            g.SetPoint(g.GetN(), e, h);

        sume += e*h;
        sumh += h;

        if (h>maxh)
        {
            maxh = h;
            maxe = e;
        }
    }

    switch (fEstimationMode)
    {
    case kMean:
        // NOTE(review): sumh==0 is not handled here - confirm whether
        // it can occur
        fRFOut->SetVal(pow(10, sume/sumh));
        break;
    case kMaximum:
        fRFOut->SetVal(pow(10, maxe));
        break;
    case kFit:
        // Start values: amplitude and mean from the maximum
        f1.SetParameter(0, maxh);
        f1.SetParameter(1, maxe);
        f1.SetParameter(2, 0.125);
        g.Fit(&f1, "Q0N");
        fRFOut->SetVal(pow(10, f1.GetParameter(1)));
        break;
    }

    fRFOut->SetReadyToSave();

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Read the setup from a TEnv. The following resources are checked (via
// IsEnvDefined with the given prefix):
//
//   FileName:        handed to SetFileName
//   Debug:           handed to SetDebug
//   NameOutput:      handed to SetNameOutput
//   EstimationMode:  mean, maximum or fit (leading/trailing blanks and
//                    upper/lower case are ignored). Any other value
//                    leaves the mode unchanged and gives a warning.
//
// Returns kTRUE if at least one of the resources was found,
// kFALSE otherwise.
//
Int_t MRanForestCalc::ReadEnv(const TEnv &env, TString prefix, Bool_t print)
{
    Bool_t rc = kFALSE;
    if (IsEnvDefined(env, prefix, "FileName", print))
    {
        rc = kTRUE;
        SetFileName(GetEnvValue(env, prefix, "FileName", fFileName));
    }
    if (IsEnvDefined(env, prefix, "Debug", print))
    {
        rc = kTRUE;
        SetDebug(GetEnvValue(env, prefix, "Debug", fDebug));
    }
    if (IsEnvDefined(env, prefix, "NameOutput", print))
    {
        rc = kTRUE;
        SetNameOutput(GetEnvValue(env, prefix, "NameOutput", fNameOutput));
    }
    if (IsEnvDefined(env, prefix, "EstimationMode", print))
    {
        TString txt = GetEnvValue(env, prefix, "EstimationMode", "");
        txt = txt.Strip(TString::kBoth);
        txt.ToLower();
        if (txt=="mean")
            fEstimationMode = kMean;
        else if (txt=="maximum")
            fEstimationMode = kMaximum;
        else if (txt=="fit")
            fEstimationMode = kFit;
        else
            // Make a typo in the resource file visible instead of silently ignoring it
            *fLog << warn << "WARNING - EstimationMode '" << txt << "' unknown... ignored." << endl;
        rc = kTRUE;
    }
    return rc;
}
