#!/usr/bin/python
#
# Werner Lustermann
# ETH Zurich
#
from ctypes import *
import numpy as np
from scipy import signal

# get the ROOT stuff + my shared libs
from ROOT import gSystem
# fitslib.so is made from fits.h and is used to access the data
gSystem.Load('~/py/fitslib.so')
from ROOT import *


class RawData( object ):
    """ raw data access and calibration
    
    - open raw data file and drs calibration file
    - performs amplitude calibration
    - performs baseline substraction if wanted
    - provides all data in an array:
      row = number of pixel
      col = length of region of interest
      
    """

    def __init__(self, data_file_name,
                 calib_file_name, baseline_file_name=''):
        """ initialize object

        open data file and calibration data file
        get basic information about the data in data_file_name
        allocate buffers for data access

        data_file_name   : fits or fits.gz file of the data including the path
        calib_file_name : fits or fits.gz file containing DRS calibration data
        baseline_file_name : npy file containing the baseline values
        
        """

        self.data_file_name = data_file_name
        self.calib_file_name = calib_file_name
        self.baseline_file_name = baseline_file_name
        
        # baseline correction: True / False
        if len(baseline_file_name) == 0:
            self.correct_baseline = False
        else:
            self.correct_baseline = True
        
        # access data file
        try:
            data_file = fits(self.data_file_name)
        except IOError:
            print 'problem accessing data file: ', data_file_name
            raise  # stop ! no data
        #: data file (fits object)
        self.data_file = data_file
        
        # get basic information about the data file
        #: region of interest (number of DRS slices read)
        self.nroi    = data_file.GetUInt('NROI')
        #: number of pixels (should be 1440)
        self.npix    = data_file.GetUInt('NPIX')
        #: number of events in the data run
        self.nevents = data_file.GetNumRows()
        
        # allocate the data memories
        self.event_id = c_ulong()
        self.trigger_type = c_ushort()
        #: 1D array with raw data
        self.data  = np.zeros( self.npix * self.nroi, np.int16 )
        #: slice where drs readout started
        self.start_cells = np.zeros( self.npix, np.int16 )

        # set the pointers to the data++
        data_file.SetPtrAddress('Event ID', self.event_id)
        data_file.SetPtrAddress('TriggerType', self.trigger_type)
        data_file.SetPtrAddress('StartCellData', self.start_cells) 
        data_file.SetPtrAddress('Data', self.data) 
                
        # open the calibration file
        try:
            calib_file = fits(self.calib_file_name)
        except IOError:
            print 'problem accessing calibration file: ', calib_file_name
            raise
        #: drs calibration file
        self.calib_file = calib_file
        
        baseline_mean = calib_file.GetN('BaselineMean')
        gain_mean = calib_file.GetN('GainMean')
        trigger_offset_mean = calib_file.GetN('TriggerOffsetMean')

        self.blm = np.zeros(baseline_mean, np.float32)
        self.gm  = np.zeros(gain_mean, np.float32)
        self.tom = np.zeros(trigger_offset_mean, np.float32)

        self.Nblm = baseline_mean / self.npix
        self.Ngm  = gain_mean / self.npix
        self.Ntom  = trigger_offset_mean / self.npix

        calib_file.SetPtrAddress('BaselineMean', self.blm)
        calib_file.SetPtrAddress('GainMean', self.gm)
        calib_file.SetPtrAddress('TriggerOffsetMean', self.tom)
        calib_file.GetRow(0)

        self.v_bsl = np.zeros(self.npix)  # array of baseline values (all ZERO)
        self.data_saverage_out = None
        self.pulse_time_of_maximum = None
        self.pulse_amplitude = None

    def next_event(self):
        """ load the next event from disk and calibrate it
        
        """

        self.data_file.GetNextRow()
        self.calibrate_drs_amplitude()

    def calibrate_drs_amplitude(self):
        """ perform the drs amplitude calibration of the event data
        
        """

        to_mV = 2000./4096.
        #: 2D array with amplitude calibrated dat in mV
        acal_data = self.data * to_mV  # convert ADC counts to mV

        # make 2D arrays: row = pixel, col = drs_slice
        acal_data = np.reshape(acal_data, (self.npix, self.nroi) )
        blm = np.reshape(self.blm, (self.npix, 1024) )
        tom = np.reshape(self.tom, (self.npix, 1024) )
        gm  = np.reshape(self.gm,  (self.npix, 1024) )
        
        for pixel in range( self.npix ):
            # rotate the pixel baseline mean to the Data startCell
            blm_pixel = np.roll( blm[pixel,:], -self.start_cells[pixel] )
            acal_data[pixel,:] -= blm_pixel[0:self.nroi]
            acal_data[pixel,:] -= tom[pixel, 0:self.nroi]
            acal_data[pixel,:] /= gm[pixel,  0:self.nroi]
            
        self.acal_data = acal_data * 1907.35

        
    def filter_sliding_average(self, window_size=4):
        """ sliding average filter

        using:
            self.acal_data
        filling array:
            self.data_saverage_out

        """

        #scipy.signal.lfilter(b, a, x, axis=-1, zi=None)
        data_saverage_out = self.acal_data.copy()
        b = np.ones( window_size )
        a = np.zeros( window_size )
        a[0] = len(b)
        data_saverage_out[:,:] = signal.lfilter(b, a, data_saverage_out[:,:])

        #: data output of sliding average filter
        self.data_saverage_out = data_saverage_out

        
    def filter_CFD(self, length=10, ratio=0.75):
        """ constant fraction discriminator (implemented as FIR)
        
        using:
            self.data_saverage_out
        filling array:
            self.data_CFD_out

        """
        
        if self.data_saverage_out == None:
            print """error pyfact.filter_CFD was called without
            prior call to filter_sliding_average
            variable self.data_saverage_out is needed
            """
            
        data_CFD_out = self.data_saverage_out.copy()
        b = np.zeros(length)
        a = np.zeros(length)
        b[0] = -1. * ratio
        b[length-1] = 1.
        a[0] = 1.
        data_CFD_out[:,:] = signal.lfilter(b, a, data_CFD_out[:,:])
        
        #: data output of the constant fraction discriminator
        self.data_CFD_out = data_CFD_out
 
    def find_peak(self, min=30, max=250):
        """ find maximum in search window
        
        using: 
            self.data_saverage_out
        filling arrays:
            self.pulse_time_of_maximum
            self.pulse_amplitude

        """

        if self.data_saverage_out == None:
            print 'error pyfact.find_peakMax was called without prior call to filter_sliding_average'
            print ' variable self.data_saverage_out is needed '
            pass

        pulse_time_of_maximum = np.argmax( self.data_saverage_out[:,min:max], 1)
        pulse_amplitude = np.max( self.data_saverage_out[:,min:max], 1)
        self.pulse_time_of_maximum = pulse_time_of_maximum
        self.pulse_amplitude = pulse_amplitude

    def sum_around_peak(self, left=13, right=23):
        """ integrate signal in gate around Peak

        using:
            self.pulse_time_of_maximum
            self.acal_data
        filling array:
            self.pulse_integral_simple
            
        """
        
        if self.pulse_time_of_maximum == None:
            print 'error pyfact.sum_around_peak was called without prior call of find_peak'
            print ' variable self.pulse_time_of_maximum is needed'
            pass

        # find left and right limit and sum the amplitudes in the range
        pulse_integral_simple = np.empty(self.npix)
        for pixel in range(self.npix):
            min = self.pulse_time_of_maximum[pixel]-left
            max = self.pulse_time_of_maximum[pixel]+right
            pulse_integral_simple[pixel] = self.acal_data[pixel,min:max].sum()
        
        self.pulse_integral_simple = pulse_integral_simple
        
    def baseline_read_values(self, file, bsl_hist='bsl_sum/hplt_mean'):
        """
        
        open ROOT file with baseline histogram and read baseline values
        file       name of the root file
        bsl_hist   path to the histogram containing the basline values

        """

        try:
            f = TFile(file)
        except:
            print 'Baseline data file could not be read: ', file
            return
        
        h = f.Get(bsl_hist)

        for i in range(self.npix):
            self.v_bsl[i] = h.GetBinContent(i+1)

        f.Close()
        
    def baseline_correct(self):
        """ subtract baseline from the data

        """
        
        for pixel in range(self.npix):
            self.acal_data[pixel,:] -= self.v_bsl[pixel]
                    
    def info(self):
        """ print run information
        
        """
        
        print 'data file:  ', data_file_name
        print 'calib file: ', calib_file_name
        print 'calibration file'
        print 'N baseline_mean: ', self.Nblm
        print 'N gain mean: ', self.Ngm
        print 'N TriggeroffsetMean: ', self.Ntom
        
# --------------------------------------------------------------------------------
class fnames( object ):
    """ organize file names of a FACT data run

    """
    
    def __init__(self, specifier = ['012', '023', '2011', '11', '24'],
                 rpath = '/scratch_nfs/res/bsl/',
                 zipped = True):
        """
        specifier : list of strings defined as:
            [ 'DRS calibration file', 'Data file', 'YYYY', 'MM', 'DD']
            
        rpath     : directory path for the results; YYYYMMDD will be appended to rpath
        zipped    : use zipped (True) or unzipped (Data) 

        """
        
        self.specifier = specifier
        self.rpath     = rpath
        self.zipped    = zipped
        
        self.make( self.specifier, self.rpath, self.zipped )


    def make( self, specifier, rpath, zipped ):
        """ create (make) the filenames

        names   : dictionary of filenames, tags { 'data', 'drscal', 'results' }
        data    : name of the data file
        drscal  : name of the drs calibration file
        results : radikal of file name(s) for results (to be completed  by suffixes)
        """

        self.specifier = specifier
        
        if zipped:
            dpath = '/data00/fact-construction/raw/'
            ext   = '.fits.gz'
        else:
            dpath = '/data03/fact-construction/raw/'
            ext   = '.fits'
    
        year  = specifier[2]
        month = specifier[3]
        day   = specifier[4]
        
        yyyymmdd = year + month + day
        dfile = specifier[1]
        cfile = specifier[0]

        rpath = rpath + yyyymmdd + '/'
        self.rpath = rpath 
        self.names = {}

        tmp = dpath + year + '/' + month + '/' + day + '/' + yyyymmdd + '_'
        self.names['data']  =  tmp + dfile + ext
        self.names['drscal'] = tmp + cfile + '.drs' + ext
        self.names['results'] =  rpath + yyyymmdd + '_' + dfile + '_' + cfile 

        self.data    = self.names['data']
        self.drscal  = self.names['drscal']
        self.results = self.names['results']

    def info( self ):
        """ print complete filenames

        """
        
        print 'file names:'
        print 'data:    ', self.names['data']
        print 'drs-cal: ', self.names['drscal']
        print 'results: ', self.names['results']

# end of class definition: fnames( object )

if __name__ == '__main__':
    """
    create an instance
    """
    data_file_name = '/data03/fact-construction/raw/2011/11/24/20111124_121.fits'
    calib_file_name = '/data03/fact-construction/raw/2011/11/24/20111124_111.drs.fits'
    rd = rawdata( data_file_name, calib_file_name )
    rd.info()
    rd.next()
    
# for i in range(10):
#    df.GetNextRow() 

#    print 'evNum: ', evNum.value
#    print 'start_cells[0:9]: ', start_cells[0:9]
#    print 'evData[0:9]: ', evData[0:9]
