#!/usr/bin/python -tti

from pyfact import RawData
from drs_spikes import DRSSpikes
import sys
import numpy as np
from ctypes import *
import os
import os.path as path
from ROOT import TFile, TCanvas, TH2F, TTree, TStyle, TObject
import time

dfn = sys.argv[1]
cfn = sys.argv[2]
run = RawData( dfn, cfn) 


file_base_name = path.splitext(path.basename(dfn))[0]
root_filename = '/home_nfs/isdc/neise/' + file_base_name + '_' + time.strftime('%Y%m%d_%H%M%S') +'_spikeana.root'

rootfile = TFile(root_filename, "RECREATE")

baum = TTree('spiketree', 'spike ana tree')

# prepare some vars for the tree
chid               =  c_int(0)
startcell          =  c_int(0)
number_of_singles  =  c_int(0)
position_of_spikes_in_logical_pipeline   =  np.zeros(100, np.int32)
position_of_spikes_in_physical_pipeline  =  np.zeros(100, np.int32)
time_to_previous   =  c_long(0)


baum.Branch('chid',chid,'chid/I')
baum.Branch('sc',startcell,'sc/I')
baum.Branch('n',number_of_singles,'n/I')
baum.Branch('logpos',position_of_spikes_in_logical_pipeline,'logpos[n]/I')
baum.Branch('physpos',position_of_spikes_in_physical_pipeline,'physpos[n]/I')
baum.Branch('time',time_to_previous,'time/I')


def spikecallback(candidates, singles, doubles, data, ind):
    if len(singles) >0 :
        for s in singles:
            s = np.unravel_index(s, data.shape)
            hs.Fill( s[0], s[1])
    if len(doubles) >0 :
        for d in doubles:
            d = np.unravel_index(d, data.shape)
            hd.Fill( d[0], d[1])
            
            
    
despike = DRSSpikes(user_action = spikecallback)

def mars_spikes( data ):
    """
    should search for spikes, just as it is implemented in 
    mcore/DrsCalib.h in DrsCalib::RemoveSpikes
    static void RemoveSpikes(float *vec, uint32_t roi)
    {
        if (roi<4)
            return;

        for (size_t ch=0; ch<1440; ch++)
        {
            float *p = vec + ch*roi;

            for (size_t i=1; i<roi-2; i++)
            {
                if (p[i]-p[i-1]>25 && p[i]-p[i+1]>25)
                {
                    p[i] = (p[i-1]+p[i+1])/2;
                }

                if (p[i]-p[i-1]>22 && fabs(p[i]-p[i+1])<4 && p[i+1]-p[i+2]>22)
                {
                    p[i] = (p[i-1]+p[i+2])/2;
                    p[i+1] = p[i];
                }
            }
        }
    }    
        
    """

    #: list, conaining the (chid,slice) tuple of the single spike positions
    singles = []
    #: list, conaining the (chid,slice) tuple of the 1st double spike positions
    doubles = []
    
    for chid, pdata in enumerate(data):
        single_cand = np.where( np.diff(pdata[:-1]) > 25)[0]
        for cand in single_cand:
            if -np.diff(pdata[1:])[cand] > 25:
                singles.append( (chid, cand) )
        
        double_cand = np.where( np.diff(pdata[:-1]) > 22 )[0]
        for cand in double_cand:
            if cand+1 < len(np.diff(pdata[1:])):
                if abs(-np.diff(pdata[1:])[cand])<4 and -np.diff(pdata[1:])[cand+1] > 22:
                    doubles.append( (chid, cand) )
                
    
    return singles, doubles        





event = run.next()
bt_old = event['board_times'].copy

for event in run:
    data = event['data']
    s, d = mars_spikes(data)
    sc = event['start_cells']
    bt = event['board_times'].copy()
    
    if len(s) >0 :
        chid.value  = s[0][0]
        startcell.value = sc[ chid.value ]
        number_of_singles.value = len(s)
        time.value = bt[s[0][0]/9]-bt_old[s[0][0]/9]
        for i in s:
            log = i[1]
            phys  = (startcell.value+log)%1024
            position_of_spikes_in_logical_pipeline[i] = log
            position_of_spikes_in_physical_pipeline[i] = phys
        baum.Fill()
    
    bt_old = event['board_times'].copy()

baum.Write()

rootfile.Close()