/* ======================================================================== *\
!
! *
! * This file is part of MARS, the MAGIC Analysis and Reconstruction
! * Software. It is distributed to you in the hope that it can be a useful
! * and timesaving tool in analysing Data of imaging Cerenkov telescopes.
! * It is distributed WITHOUT ANY WARRANTY.
! *
! * Permission to use, copy, modify and distribute this software and its
! * documentation for any purpose is hereby granted without fee,
! * provided that the above copyright notice appear in all copies and
! * that both that copyright notice and this permission notice appear
! * in supporting documentation. It is provided "as is" without express
! * or implied warranty.
! *
!
!
!   Author(s): Thomas Hengstebeck 2/2005 <mailto:hengsteb@physik.hu-berlin.de>
!
!   Copyright: MAGIC Software Development, 2000-2005
!
!
\* ======================================================================== */

/////////////////////////////////////////////////////////////////////////////
//
//  MRFEnergyEst
//
//
////////////////////////////////////////////////////////////////////////////
#include "MRFEnergyEst.h"

#include <TFile.h>
#include <TList.h>

#include <TH1.h>
#include <TH2.h>
#include <TStyle.h>
#include <TCanvas.h>
#include <TMath.h>
#include <TVector.h>

#include "MHMatrix.h"

#include "MLog.h"
#include "MLogManip.h"

#include "MParList.h"
#include "MTaskList.h"
#include "MEvtLoop.h"

#include "MRanTree.h"
#include "MRanForest.h"
#include "MRanForestGrow.h"

#include "MData.h"
#include "MParameters.h"

ClassImp(MRFEnergyEst);

using namespace std;

static const TString gsDefName  = "MRFEnergyEst";
static const TString gsDefTitle = "RF for energy estimation";

// --------------------------------------------------------------------------
//
//  Default constructor. Set the name and title of this object
//
MRFEnergyEst::MRFEnergyEst(const char *name, const char *title)
    : fNumTrees(-1), fNumTry(-1), fNdSize(-1), fData(0), fEnergyEst(0),
    fTestMatrix(0)
{
    fName  = name  ? name  : gsDefName.Data();
    fTitle = title ? title : gsDefTitle.Data();
}

// --------------------------------------------------------------------------
//
// Train the RF with the goven MHMatrix. The last column must contain the
// True energy.
//
Int_t MRFEnergyEst::Train(const MHMatrix &matrixtrain, const TArrayD &grid)
{
    gLog.Separator("MRFEnergyEst - Train");

    if (!matrixtrain.GetColumns())
    {
        *fLog << err << "ERROR - MHMatrix does not contain rules... abort." << endl;
        return kFALSE;
    }

    const Int_t ncols = matrixtrain.GetM().GetNcols();
    const Int_t nrows = matrixtrain.GetM().GetNrows();
    if (ncols<=0 || nrows <=0)
    {
        *fLog << err << "ERROR - No. of columns or no. of rows of matrixtrain equal 0 ... abort." << endl;
        return kFALSE;
    }

    const Int_t nbins = grid.GetSize()-1;
    if (nbins<=0)
    {
        *fLog << err << "ERROR - Energy grid not vaild... abort." << endl;
        return kFALSE;
    }

    // rules (= combination of image par) to be used for energy estimation
    TFile fileRF(fFileName, "recreate");
    if (!fileRF.IsOpen())
    {
        *fLog << err << "ERROR - File to store RFs could not be opened... abort." << endl;
        return kFALSE;
    }

    MDataArray usedrules;
    for(Int_t i=0; i<ncols-3; i++) // -3 is important!!!
        usedrules.AddEntry((*matrixtrain.GetColumns())[i].GetRule());

    const TMatrix &m = matrixtrain.GetM();

    // training of RF for each energy bin
    for (Int_t ie=0; ie<nbins; ie++)
    {
        TMatrix mat1(nrows, ncols);
        TMatrix mat0(nrows, ncols);

        // prepare matrix for current energy bin
        Int_t irow1=0;
        Int_t irow0=0;

        for (Int_t j=0; j<nrows; j++)
        {
            const Double_t energy = m(j,ncols-1);

            if (energy>grid[ie] && energy<=grid[ie+1])
                TMatrixFRow(mat1, irow1++) = TMatrixFRow_const(m,j);
            else
                TMatrixFRow(mat0, irow0++) = TMatrixFRow_const(m,j);
        }

        const Bool_t valid = irow1*irow0>0;

        if (!valid)
            *fLog << warn << "WARNING - Skipping";
        else
            *fLog << inf << "Training RF for";

        *fLog << " energy bin " << ie << " (" << grid[ie] << ", " << grid[ie+1] << ")" << endl;

        if (!valid)
            continue;

        gLog.SetNullOutput(kTRUE);

        // last column must be removed (true energy col.)
        mat1.ResizeTo(irow1, ncols-1);
        mat0.ResizeTo(irow0, ncols-1);

        // create MHMatrix as input for RF
        MHMatrix matrix1(mat1, "MatrixHadrons");
        MHMatrix matrix0(mat0, "MatrixGammas");

        matrix1.AddColumns(&usedrules);
        matrix0.AddColumns(&usedrules);

        // training of RF
        MTaskList tlist;
        MParList plist;
        plist.AddToList(&tlist);
        plist.AddToList(&matrix0);
        plist.AddToList(&matrix1);

        MRanForestGrow rfgrow;
        rfgrow.SetNumTrees(fNumTrees); // number of trees
        rfgrow.SetNumTry(fNumTry);     // number of trials in random split selection
        rfgrow.SetNdSize(fNdSize);     // limit for nodesize

        tlist.AddToList(&rfgrow);
    
        MEvtLoop evtloop;
        evtloop.SetDisplay(fDisplay);
        evtloop.SetParList(&plist);

        gLog.SetNullOutput(kFALSE);

        if (!evtloop.Eventloop())
            return kFALSE;

        // save whole forest
        MRanForest *forest=(MRanForest*)plist.FindObject("MRanForest");
        const TString title = Form("%f", TMath::Sqrt(grid[ie]*grid[ie+1]));
        //const TString title = Form("%f", (grid[ie]+grid[ie+1])/2);
        forest->SetTitle(title);
        forest->Write(title);
    }

    // save rules
    usedrules.Write("rules");

    fileRF.Close();

    return kTRUE;
}
/*
Int_t MRFEnergyEst::Test(const MHMatrix &matrixtest)
{
    gLog.Separator("MRFEnergyEst - Test");

    const Int_t ncols = matrixtest.GetM().GetNcols();
    const Int_t nrows = matrixtest.GetM().GetNrows();

    if (ncols<=0 || nrows <=0)
    {
        *fLog << err << dbginf << "No. of columns or no. of rows of matrixtrain equal 0 ... aborting." << endl;
        return kFALSE;
    }

    if (!ReadForests())
    {
        *fLog << err << dbginf << "Reading RFs failed... aborting." << endl;
        return kFALSE;
    }

    const Int_t nbins=fEForests.GetSize();

    Double_t elow =  FLT_MAX;
    Double_t eup  = -FLT_MAX;

    for(Int_t j=0;j<nbins;j++)
    {
        elow = TMath::Min(atof(fEForests[j]->GetTitle()), elow);
        eup  = TMath::Max(atof(fEForests[j]->GetTitle()), eup);
    }

    TH1F hres("hres", "", 1000, -10, 10);
    TH2F henergy("henergy", "",100, elow, eup,100, elow, eup);

    const TMatrix &m=matrixtest.GetM();
    for(Int_t i=0;i<nrows;i++)
    {
        Double_t etrue = m(i,ncols-1);
        Double_t eest  = 0;
        Double_t hsum  = 0;

        for(Int_t j=0;j<nbins;j++)
        {
            const TVector  v = TMatrixFRow_const(m,i);

            const Double_t h = ((MRanForest*)fEForests[j])->CalcHadroness(v);
            const Double_t e = atof(fEForests[j]->GetTitle());

            hsum += h;
            eest += h*e;
        }
        eest /= hsum;
        eest  = pow(10., eest);

        //if (etrue>80.)
        //    hres.Fill((eest-etrue)/etrue);

        henergy.Fill(log10(etrue),log10(eest));
    }

    if(gStyle)
        gStyle->SetOptFit(1);

    // show results
    TCanvas *c=MH::MakeDefCanvas("c","",300,900);

    c->Divide(1,3);
    c->cd(1);
    henergy.SetTitle("Estimated vs true energy");
    henergy.GetXaxis()->SetTitle("log10(E_{true}[GeV])");
    henergy.GetYaxis()->SetTitle("log10(E_{est}[GeV])");
    henergy.DrawCopy();

    c->cd(2);
    TH1F *hptr=(TH1F*)henergy.ProfileX("_px",1,100,"S");
    hptr->SetTitle("Estimated vs true energy - profile");
    hptr->GetXaxis()->SetTitle("log10(E_{true}[GeV])");
    hptr->GetYaxis()->SetTitle("log10(E_{est}[GeV])");
    hptr->DrawCopy();

    c->cd(3);
    hres.Fit("gaus");
    hres.SetTitle("Energy resolution for E_{true}>80Gev");
    hres.GetXaxis()->SetTitle("(E_{estimated}-E_{true})/E_{true})");
    hres.GetYaxis()->SetTitle("counts");
    hres.DrawCopy();

    c->GetPad(1)->SetGrid();
    c->GetPad(2)->SetGrid();
    c->GetPad(3)->SetGrid();

    return kTRUE;
}
*/
Int_t MRFEnergyEst::ReadForests(MParList *plist)
{
    TFile fileRF(fFileName,"read");
    if (!fileRF.IsOpen())
    {
        *fLog << err << dbginf << "File containing RFs could not be opened... aborting." << endl;
        return kFALSE;
    }

    fEForests.Delete();

    TIter Next(fileRF.GetListOfKeys());
    TObject *o=0;
    while ((o=Next()))
    {
        MRanForest *forest;
        fileRF.GetObject(o->GetName(), forest);
        if (!forest)
            continue;

        forest->SetTitle(o->GetTitle());
        forest->SetBit(kCanDelete);

        fEForests.Add(forest);

    }

    if (plist)
    {
        fData = (MDataArray*)plist->FindCreateObj("MDataArray");
        fData->Read("rules");
    }

    fileRF.Close();

    return kTRUE;
}

Int_t MRFEnergyEst::PreProcess(MParList *plist)
{
    fEnergyEst = (MParameterD*)plist->FindCreateObj("MParameterD", "MEnergyEst");
    if (!fEnergyEst)
        return kFALSE;

    if (!ReadForests(plist))
    {
        *fLog << err << "Reading RFs failed... aborting." << endl;
        return kFALSE;
    }

    if (fTestMatrix)
        return kTRUE;

    if (!fData)
    {
        *fLog << err << "MDataArray not found... aborting." << endl;
        return kFALSE;
    }

    if (!fData->PreProcess(plist))
    {
        *fLog << err << "PreProcessing of the MDataArray failed... aborting." << endl;
        return kFALSE;
    }

    return kTRUE;
}

// --------------------------------------------------------------------------
//
//
Int_t MRFEnergyEst::Process()
{
    TVector event;
    if (fTestMatrix)
        *fTestMatrix >> event;
    else
        *fData >> event;

    Double_t eest = 0;
    Double_t hsum = 0;

    TIter Next(&fEForests);
    MRanForest *rf = 0;
    while ((rf=(MRanForest*)Next()))
    {
        const Double_t h = rf->CalcHadroness(event);
        const Double_t e = atof(rf->GetTitle());

        hsum += h;
        eest += h*e;
    }

    fEnergyEst->SetVal(eest/hsum);
    fEnergyEst->SetReadyToSave();

    return kTRUE;
}
