/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
!
!   Copyright: MAGIC Software Development, 2000-2005
!
!
\* ======================================================================== */

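/////////////////////////////////////////////////////////////////////////////
//
//  MRFEnergyEst
//
//  Energy estimation with random forests (RF). Train() grows one
//  MRanForest per bin of the energy grid (bin edges given in log10(E)),
//  each separating the events of that bin from all other events, and
//  stores the forests together with the used rules in fRFfileName.
//  In the event loop (PreProcess/Process) the forests are read back and
//  the energy of each event is estimated as 10^(weighted mean of the bin
//  centres), using the CalcHadroness() output of each forest as weight.
//  The result is written to MEnergyEst [MParameterD].
//
////////////////////////////////////////////////////////////////////////////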
#include "MRFEnergyEst.h"

#include <TFile.h>
#include <TList.h>

#include <TH1.h>
#include <TH2.h>
#include <TProfile.h>
#include <TStyle.h>
#include <TCanvas.h>
#include <TMath.h>
#include <TVector.h>

#include "MHMatrix.h"

#include "MLog.h"
#include "MLogManip.h"

#include "MParList.h"
#include "MTaskList.h"
#include "MEvtLoop.h"

#include "MRanTree.h"
#include "MRanForest.h"
#include "MRanForestGrow.h"

#include "MData.h"
#include "MParameters.h"


// ROOT class implementation macro for MRFEnergyEst.
ClassImp(MRFEnergyEst);

using namespace std;

// Name and title used by the constructor when NULL is passed for them.
static const TString gsDefName  = "MRFEnergyEst";
static const TString gsDefTitle = "RF for energy estimation";

// --------------------------------------------------------------------------
//
// Constructor. Sets name and title (falling back to gsDefName/gsDefTitle
// when NULL is given) and puts the pointer members into a defined (NULL)
// state, so that the sanity checks in Train(), Test() and PreProcess()
// ("if(!fMatrixTrain)", "if(!fData)", ...) never read indeterminate values.
//
// NOTE(review): fNumTry and fNdSize are not initialised here because their
// "unset" convention for MRanForestGrow is not visible; confirm and
// initialise them like fNumTrees.
//
MRFEnergyEst::MRFEnergyEst(const char *name, const char *title):fNumTrees(-1)
{
    //
    //   set the name and title of this object
    //
    fName  = name  ? name  : gsDefName.Data();
    fTitle = title ? title : gsDefTitle.Data();

    fMatrixTrain = NULL;
    fMatrixTest  = NULL;
    fData        = NULL;
    fEnergyEst   = NULL;
}

// --------------------------------------------------------------------------
//
// Destructor. Deletes the forests which ReadForests() cloned into
// fEForests. fData is NOT deleted: it is obtained via
// MParList::FindCreateObj() in ReadForests() and therefore presumably
// owned by that parameter list.
//
MRFEnergyEst::~MRFEnergyEst()
{
    // Delete() frees the objects regardless of the array's owner flag and
    // clears the slots, so no double deletion can happen afterwards.
    fEForests.Delete();
}

// --------------------------------------------------------------------------
//
// Train one random forest per bin of fEnergyGrid.
//
// fMatrixTrain must hold the image-parameter rules in its first ncols-1
// columns and the true energy (a rule containing "fEnergy") in its last
// column. fEnergyGrid holds the bin edges in log10(E): an event with true
// energy E belongs to bin ie if 10^grid[ie] < E <= 10^grid[ie+1].
//
// For every bin which has both in-bin and out-of-bin events a forest is
// grown; the in-bin events are passed as "MatrixHadrons", all others as
// "MatrixGammas" (presumably so that CalcHadroness() later yields the
// in-bin probability). Each forest is written to fRFfileName under the key
// "<n>" (n counts only the trained bins) with the bin centre in log10(E)
// as title. The used rules are stored under the key "rules".
//
// Returns kFALSE on any error, kTRUE otherwise.
//
Int_t MRFEnergyEst::Train()
{
    if(!fMatrixTrain)
    {
        *fLog << err << dbginf << "fMatrixTrain not found... aborting." << endl;
        return kFALSE;
    }

    MDataArray *rules = fMatrixTrain->GetColumns();
    if(!rules)
    {
        *fLog << err << dbginf << "fMatrixTrain does not contain rules... aborting." << endl;
        return kFALSE;
    }

    const Int_t ncols = (fMatrixTrain->GetM()).GetNcols();
    const Int_t nrows = (fMatrixTrain->GetM()).GetNrows();

    *fLog << inf << "ncols=" << ncols << " nrows=" << nrows << endl;

    if(ncols<=0 || nrows <=0)
    {
        *fLog << err << dbginf << "No. of columns or no. of rows of fMatrixTrain equal 0 ... aborting." << endl;
        return kFALSE;
    }

    // at least one rule is needed besides the true-energy column
    if(ncols<2)
    {
        *fLog << err << dbginf << "fMatrixTrain needs at least one rule besides the true energy... aborting." << endl;
        return kFALSE;
    }

    // rules (= combination of image par) to be used for energy estimation;
    // the last column is the true energy
    MDataArray used_rules;
    for(Int_t i=0;i<ncols-1;i++)
        used_rules.AddEntry((*rules)[i].GetRule());

    const TString energy_rule = (*rules)[ncols-1].GetRule();

    if(!energy_rule.Contains("fEnergy"))
    {
        *fLog << err << dbginf << "Can not accept "<<energy_rule<<" as true energy ... aborting." << endl;
        return kFALSE;
    }

    // validate the grid BEFORE recreating (= truncating) the output file
    const Int_t nbins = fEnergyGrid.GetSize()-1;
    if(nbins<=0)
    {
        *fLog << err << dbginf << "Energy grid not vaild... aborting." << endl;
        return kFALSE;
    }

    TFile fileRF(fRFfileName,"recreate");
    if(!fileRF.IsOpen())
    {
        *fLog << err << dbginf << "File to store RFs could not be opened... aborting." << endl;
        return kFALSE;
    }

    TMatrix *mptr=(TMatrix*)&(fMatrixTrain->GetM());

    // training of RF for each energy bin
    Int_t numbins=0;
    for(Int_t ie=0;ie<nbins;ie++)
    {
        // bin edges in linear energy; loop-invariant for the row loop
        const Double_t elow = pow(10., fEnergyGrid[ie]);
        const Double_t eup  = pow(10., fEnergyGrid[ie+1]);

        TMatrix mat1(nrows,ncols);
        TMatrix mat0(nrows,ncols);

        // split the rows into in-bin (mat1) and out-of-bin (mat0) events
        Int_t irow1=0;
        Int_t irow0=0;

        for(Int_t j=0;j<nrows;j++)
        {
            const Double_t energy=(*mptr)(j,ncols-1);
            if(energy>elow && energy<=eup)
            {
                TMatrixRow(mat1, irow1) = TMatrixRow(*mptr,j);
                irow1++;
            }else{
                TMatrixRow(mat0, irow0) = TMatrixRow(*mptr,j);
                irow0++;
            }
        }

        // a forest can only be grown if both samples are populated
        if(irow1==0 || irow0==0)
            continue;

        *fLog << inf << dbginf << "Training RF for energy bin "<<ie<< endl;

        // last column must be removed (true energy col.)
        mat1.ResizeTo(irow1,ncols-1);
        mat0.ResizeTo(irow0,ncols-1);

        // create MHMatrix as input for RF. No column rules are attached to
        // these matrices: the rules are stored once in the output file.
        MHMatrix matrix1(mat1,"MatrixHadrons");
        MHMatrix matrix0(mat0,"MatrixGammas");

        // training of RF
        MParList plist;
        MTaskList tlist;
        plist.AddToList(&tlist);
        plist.AddToList(&matrix0);
        plist.AddToList(&matrix1);

        MRanForestGrow rfgrow;
        rfgrow.SetNumTrees(fNumTrees); // number of trees
        rfgrow.SetNumTry(fNumTry);     // number of trials in random split selection
        rfgrow.SetNdSize(fNdSize);     // limit for nodesize

        tlist.AddToList(&rfgrow);

        MEvtLoop evtloop;
        evtloop.SetParList(&plist);

        if (!evtloop.Eventloop()) return kFALSE;

        MRanForest *forest=(MRanForest*)plist.FindObject("MRanForest");
        if(!forest)
        {
            *fLog << err << dbginf << "MRanForest not found after training... aborting." << endl;
            return kFALSE;
        }

        // save whole forest; cd() in case the event loop changed gDirectory
        fileRF.cd();
        forest->SetTitle(Form("%f",0.5*(fEnergyGrid[ie]+fEnergyGrid[ie+1])));
        forest->Write(Form("%d",numbins));
        numbins++;
    }

    if(numbins==0)
        *fLog << warn << dbginf << "No energy bin contained both in-bin and out-of-bin events - no RF trained." << endl;

    // save rules
    fileRF.cd();
    used_rules.Write("rules");

    fileRF.Close();

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Evaluate the forests stored in fRFfileName on fMatrixTest.
//
// The last column of fMatrixTest is taken as the true energy. For every
// row the energy is estimated as 10^(sum_j h_j*e_j / sum_j h_j), where h_j
// is the CalcHadroness() output of forest j and e_j its title (bin centre
// in log10(E)). Rows for which no forest gives a positive weight are
// skipped, because the weighted mean is undefined for them.
//
// The relative resolution (for E_true > 80) and estimated vs. true
// energy are histogrammed and, if kEnableGraphicalOutput is set, drawn.
//
// Returns kFALSE on any error, kTRUE otherwise.
//
Int_t MRFEnergyEst::Test()
{
    if(!fMatrixTest)
    {
        *fLog << err << dbginf << "fMatrixTest not found... aborting." << endl;
        return kFALSE;
    }

    const Int_t ncols = (fMatrixTest->GetM()).GetNcols();
    const Int_t nrows = (fMatrixTest->GetM()).GetNrows();

    if(ncols<=0 || nrows <=0)
    {
        *fLog << err << dbginf << "No. of columns or no. of rows of fMatrixTest equal 0 ... aborting." << endl;
        return kFALSE;
    }

    TMatrix *mptr=(TMatrix*)&(fMatrixTest->GetM());

    if(!ReadForests())
    {
        *fLog << err << dbginf << "Reading RFs failed... aborting." << endl;
        return kFALSE;
    }

    const Int_t nbins=fEForests.GetSize();
    if(nbins<=0)
    {
        *fLog << err << dbginf << "No RFs available... aborting." << endl;
        return kFALSE;
    }

    // histogram range: smallest and largest bin centre (log10(E))
    Double_t e_low = 100;
    Double_t e_up  = 0;

    for(Int_t j=0;j<nbins;j++)
    {
        e_low = TMath::Min(atof((fEForests[j])->GetTitle()),e_low);
        e_up  = TMath::Max(atof((fEForests[j])->GetTitle()),e_up);
    }

    TH1F hres("hres","",1000,-10,10);
    TH2F henergy("henergy","",100,e_low,e_up,100,e_low,e_up);

    Int_t nskipped = 0;
    for(Int_t i=0;i<nrows;i++)
    {
        const Double_t e_true = (*mptr)(i,ncols-1);

        // Row vector is the same for all forests. It still contains the
        // true energy as last element.
        // NOTE(review): assumes the trees only index the first ncols-1
        // elements, as the forests were trained without that column.
        const TVector v=TMatrixRow(*mptr,i);

        Double_t hsum  = 0;
        Double_t lgsum = 0;
        for(Int_t j=0;j<nbins;j++)
        {
            const Double_t h = ((MRanForest*) (fEForests[j]))->CalcHadroness(v);
            const Double_t e = atof((fEForests[j])->GetTitle());
            hsum  += h;
            lgsum += h*e;
        }

        // no forest claimed the event: avoid 0/0 = NaN
        if(hsum<=0)
        {
            nskipped++;
            continue;
        }

        const Double_t e_est = pow(10., lgsum/hsum);

        if(e_true>80.) hres.Fill((e_est-e_true)/e_true);
        henergy.Fill(log10(e_true),log10(e_est));
    }

    if(nskipped>0)
        *fLog << warn << nskipped << " of " << nrows << " test events skipped (no RF gave a positive weight)." << endl;

    if(TestBit(kEnableGraphicalOutput))
    {
        if(gStyle) gStyle->SetOptFit(1);
        // show results
        TCanvas *c=MH::MakeDefCanvas("c","",300,900);

        c->Divide(1,3);
        c->cd(1);
        henergy.SetTitle("Estimated vs true energy");
        henergy.GetXaxis()->SetTitle("log10(E_{true}[GeV])");
        henergy.GetYaxis()->SetTitle("log10(E_{est}[GeV])");
        henergy.DrawCopy();

        c->cd(2);

        // ProfileX returns a TProfile (derived from TH1D, not TH1F)
        TProfile *hptr=henergy.ProfileX("_px",1,100,"S");
        hptr->SetTitle("Estimated vs true energy - profile");
        hptr->GetXaxis()->SetTitle("log10(E_{true}[GeV])");
        hptr->GetYaxis()->SetTitle("log10(E_{est}[GeV])");
        hptr->DrawCopy();
        delete hptr; // only the drawn copy is needed

        c->cd(3);
        if(hres.GetEntries()>0)
            hres.Fit("gaus");
        hres.SetTitle("Energy resolution for E_{true}>80GeV");
        hres.GetXaxis()->SetTitle("(E_{estimated}-E_{true})/E_{true}");
        hres.GetYaxis()->SetTitle("counts");
        hres.DrawCopy();


        c->GetPad(1)->SetGrid();
        c->GetPad(2)->SetGrid();
        c->GetPad(3)->SetGrid();

    }

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Read the forests written by Train() from fRFfileName into fEForests.
//
// All keys but one are expected to be forests named "0", "1", ...; the
// remaining key is the MDataArray "rules". The key title (the bin centre
// in log10(E)) and name are copied to each cloned forest. Forests from a
// previous call are deleted first.
//
// If plist is given, the MDataArray is fetched/created in plist, stored
// in fData and filled from the key "rules".
//
// Returns kFALSE on any error, kTRUE otherwise.
//
Int_t MRFEnergyEst::ReadForests(MParList *plist)
{
    TFile fileRF(fRFfileName,"read");
    if(!fileRF.IsOpen())
    {
        *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
        return kFALSE;
    }

    TList *list=(TList*)fileRF.GetListOfKeys();
    const Int_t n=list->GetSize()-1;// subtract 1 since 1 key belongs to MDataArray
    if(n<=0)
    {
        *fLog << err << dbginf << "File " << fRFfileName << " contains no RFs... aborting." << endl;
        return kFALSE;
    }

    // e.g. Test() followed by PreProcess(): don't leak the old clones
    fEForests.Delete();
    fEForests.Expand(n);

    MRanForest forest;
    for(Int_t i=0;i<n;i++)
    {
        // copy: Form() returns a shared, reused buffer
        const TString keyname = Form("%d",i);

        TObject *key = list->FindObject(keyname);
        if(!key || !forest.Read(keyname))
        {
            *fLog << err << dbginf << "RF '" << keyname << "' could not be read... aborting." << endl;
            return kFALSE;
        }

        MRanForest *curforest=(MRanForest*)forest.Clone();
        curforest->SetTitle(key->GetTitle());
        curforest->SetName(key->GetName());

        fEForests[i]=curforest;
    }

    if(plist)
    {
        fData=(MDataArray*)plist->FindCreateObj("MDataArray");
        if(!fData)
        {
            *fLog << err << dbginf << "MDataArray could not be created... aborting." << endl;
            return kFALSE;
        }
        if(!fData->Read("rules"))
        {
            *fLog << err << dbginf << "Rules could not be read... aborting." << endl;
            return kFALSE;
        }
    }
    fileRF.Close();

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Get (or create) the output container MEnergyEst [MParameterD]. Read the
// forests and the rules from fRFfileName via ReadForests(plist), and
// preprocess the rules (fData) so that Process() can evaluate them.
//
// At least one forest is required: Process() divides by the summed forest
// weights.
//
Int_t MRFEnergyEst::PreProcess(MParList *plist)
{
    fEnergyEst = (MParameterD*)plist->FindCreateObj("MParameterD", "MEnergyEst");
    if (!fEnergyEst)
    {
        *fLog << err << dbginf << "MEnergyEst [MParameterD] not found... aborting." << endl;
        return kFALSE;
    }

    if(!ReadForests(plist))
    {
        *fLog << err << dbginf << "Reading RFs failed... aborting." << endl;
        return kFALSE;
    }

    if(fEForests.GetSize()<=0)
    {
        *fLog << err << dbginf << "No RFs found in " << fRFfileName << "... aborting." << endl;
        return kFALSE;
    }

    if (!fData)
    {
        *fLog << err << dbginf << "MDataArray not found... aborting." << endl;
        return kFALSE;
    }

    if (!fData->PreProcess(plist))
    {
        *fLog << err << dbginf << "PreProcessing of the MDataArray failed for the columns failed... aborting." << endl;
        return kFALSE;
    }

    return kTRUE;
}

// --------------------------------------------------------------------------
//
// Estimate the energy of the current event. The event vector is built from
// the rules in fData and every forest is evaluated on it. The estimate is
// 10^(sum_j h_j*e_j / sum_j h_j), where h_j is the CalcHadroness() output
// of forest j and e_j its title (bin centre in log10(E)). The estimate is
// stored in MEnergyEst.
//
// If no forest gives a positive weight, the weighted mean is undefined
// (0/0). The event is then skipped (kCONTINUE) instead of storing NaN.
//
Int_t MRFEnergyEst::Process()
{
    TVector event;
    *fData >> event;

    const Int_t nforests = fEForests.GetSize();

    Double_t hsum  = 0;
    Double_t lgsum = 0;
    for(Int_t j=0;j<nforests;j++)
    {
        MRanForest *forest = (MRanForest*)fEForests[j];

        const Double_t h = forest->CalcHadroness(event);
        hsum  += h;
        lgsum += h*atof(forest->GetTitle());
    }

    if(hsum<=0)
        return kCONTINUE;

    fEnergyEst->SetVal(pow(10., lgsum/hsum));
    fEnergyEst->SetReadyToSave();

    return kTRUE;
}
