 #!/usr/bin/python
#
# Dominik Neise, Werner Lustermann
# TU Dortmund, ETH Zurich
#
import numpy as np
import matplotlib.pyplot as plt
#import * from fir_filter

class GlobalMaxFinder(object):
    """ Pulse Extractor
        Finds the global maximum in the given window.
        (Best used with filtered data)
    """
    
    def __init__(self, min=30, max=250 , name = 'GlobalMaxFinder'):
        """ initialize search Window
        
        """
        self.min = min
        self.max = max
        self.name = name
        
    def __call__(self, data):
        time = np.argmax( data[ : , self.min:self.max ], 1)
        amplitude = np.max( data[ : , self.min:self.max], 1)
        return amplitude, time

    def __str__(self):
        s = self.name + '\n'
        s += 'window:\n'
        s += '(min,max) = (' + str(self.min) + ',' + str(self.max) + ')\n'
        return s


class FixedWindowIntegrator(object):
    """ Integrates in a given intergration window
    """
    
    def __init__(self, min=55, max=105 , name = 'FixedWindowIntegrator'):
        """ initialize integration Window
        """
        self.min = min
        self.max = max
        self.name = name
        
    def __call__(self, data):
        integral = np.empty( data.shape[0] )
        for pixel in range( data.shape[0] ):
            integral[pixel] = data[pixel, self.min:self.max].sum()
        return integtral

    def __str__(self):
        s = self.name + '\n'
        s += 'window:\n'
        s += '(min,max) = (' + str(self.min) + ',' + str(self.max) + ')\n'
        return s    

class ZeroXing(object):
    """ Finds zero crossings in given data
        (should be used on CFD output for peak finding)
        returns list of lists of time_of_zero_crossing
    """
    def __init__(self, slope=1, name = 'ZeroXing'):
        if (slope >= 0):
            self.slope = 1  # search for rising edge crossing
        elif (slope < 0):
            self.slope = -1 # search for falling edge crossing
        self.name = name


    def __call__(self, data):
        all_hits = []
        for pix_data in data:
            hits = []
            for i in range( data.shape[1] ):
                if ( slope > 0 ):
                    if ( pix_data[i] > 0 ):
                        continue
                else:
                    if ( pix_data[i] < 0):
                        continue
                if ( pix_data[i] * pix_data[i+1] <= 0 ):
                    # interpolate time of zero crossing with 
                    # linear polynomial: y = ax + b
                    a = (pix_data[i+1] - pix_data[i]) / ((i+1) - i)
                    time = -1.0/a * pix_data[i] + i
                    hits.append(time)
            all_hits.append(hits)
        return all_hits

    def __str__(self):
        s = self.name + '\n'
        if (self.slope == 1):
            s += 'search for rising edge crossing.\n'
        else:
            s += 'search for falling edge crossing.\n'
        return s



if __name__ == '__main__':
    """ test the extractors """
    
    gmf = GlobalMaxFinder((12,40))
    print gmf
    fwi = FixedWindowIntegrator(1,3)
    print fwi
    zx = ZeroXing()
    print zx