#!/usr/bin/python
#
# Dominik Neise, Werner Lustermann
# TU Dortmund, ETH Zurich
#
import numpy as np
import matplotlib.pyplot as plt
from generator import *
from fir_filter import *

class GlobalMaxFinder(object):
    """ Pulse Extractor
        Finds the global maximum in the given window.
        (Best used with filtered data)
    """
    
    def __init__(self, min=30, max=250 , name = 'GlobalMaxFinder'):
        """ initialize search Window
        
        """
        self.min = min
        self.max = max
        self.name = name
        
    def __call__(self, data):
        if data.ndim > 1:
            time = np.argmax( data[ : , self.min:self.max ], 1)
            amplitude = np.max( data[ : , self.min:self.max], 1)
        else:
            time = np.argmax( data[self.min:self.max])
            amplitude = np.max( data[self.min:self.max])
        return amplitude, time+self.min

    def __str__(self):
        s = self.name + '\n'
        s += 'window:\n'
        s += '(min,max) = (' + str(self.min) + ',' + str(self.max) + ')'
        return s
        
    def test(self):
        pass


class FixedWindowIntegrator(object):
    """ Integrates in a given intergration window
    """
    
    def __init__(self, min=55, max=105 , name = 'FixedWindowIntegrator'):
        """ initialize integration Window
        """
        self.min = min
        self.max = max
        self.name = name
        
    def __call__(self, data):
        integral = np.empty( data.shape[0] )
        for pixel in range( data.shape[0] ):
            integral[pixel] = data[pixel, self.min:self.max].sum()
        return integral

    def __str__(self):
        s = self.name + '\n'
        s += 'window:\n'
        s += '(min,max) = (' + str(self.min) + ',' + str(self.max) + ')'
        return s    

class ZeroXing(object):
    """ Finds zero crossings in given data
        (should be used on CFD output for peak finding)
        returns list of lists of time_of_zero_crossing
    """
    def __init__(self, slope=1, name = 'ZeroXing'):
        if (slope >= 0):
            self.slope = 1  # search for rising edge crossing
        elif (slope < 0):
            self.slope = -1 # search for falling edge crossing
        self.name = name


    def __call__(self, data):
        all_hits = []
        for pix_data in data:
            hits = []
            for i in range( data.shape[1]-1 ):
                if ( self.slope > 0 ):
                    if ( pix_data[i] > 0 ):
                        continue
                else:
                    if ( pix_data[i] < 0):
                        continue
                if ( pix_data[i] * pix_data[i+1] <= 0 ):
                    # interpolate time of zero crossing with 
                    # linear polynomial: y = ax + b
                    a = (pix_data[i+1] - pix_data[i]) / ((i+1) - i)
                    time = -1.0/a * pix_data[i] + i
                    hits.append(time)
            all_hits.append(hits)
        return all_hits

    def __str__(self):
        s = self.name + '\n'
        if (self.slope == 1):
            s += 'search for rising edge crossing.\n'
        else:
            s += 'search for falling edge crossing.\n'
        return s


def plotter(signal, text):
    x=range(len(signal))
    ax = plt.plot(x, signal, 'b.', label='signal')
    plt.title(text)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()
    
def histplotter(signal, text):
    plt.xlabel('time of zero crossing')
    plt.title(text)
    plt.axis([0, 300,0,1440])
    plt.grid(True)
    n, bins, patches = plt.hist(signal, 3000, facecolor='r', alpha=0.75)
    plt.show()


if __name__ == '__main__':
    """ test the extractors """
    sg = SignalGenerator()
    pulse_str = 'len 300 bsl -0.5 noise 0.5 triangle 65 10 8 100' 
    pulse = sg(pulse_str)
    event = []
    for i in range(1440):
        event.append(sg(pulse_str))
    event = np.array(event)
    print 'test event with 1000 pixel generated, like this:'
    print pulse_str
    print
    
    # Test GlobalMaxFinder
    gmf = GlobalMaxFinder(30,250)
    print gmf
    amplitude, time = gmf(event)
    if abs(amplitude.mean() - 10) < 0.5:
        print "Test 1: OK GlobalMaxFinder found amplitude correctly", amplitude.mean()
    if abs(time.mean() - 65) < 2:
        print "Test 1: OK GlobalMaxFinder found time correctly", time.mean()
    else:
        print "BAD: time mean:", time.mean()
    
    print 
    
    # Test FixedWindowIntegrator
    fwi = FixedWindowIntegrator(50,200)
    print fwi
    integral = fwi(event)
    #value of integral should be: 150*bsl + 8*10/2 + 100*10/2 = 465
    if abs( integral.mean() - 465) < 2:
        print "Test 2: OK FixedWindowIntegrator found integral correctly", integral.mean()
    else:
        print "Test 2:  X FixedWindowIntegrator integral.mean failed:", integral.mean()
    
    print
    cfd = CFD()
    sa = SlidingAverage(8)
    print sa
    cfd_out = sa(event)
    cfd_out = cfd(cfd_out   )
    cfd_out = sa(cfd_out)
    zx = ZeroXing()
    print zx
    list_of_list_of_times = zx(cfd_out)
    times = []
    for list_of_times in list_of_list_of_times:
        times.extend(list_of_times)
    times = np.array(times)
    
    hist,bins = np.histogram(times,3000,(0,300))
    most_probable_time = np.argmax(hist)
    print 'most probable time of zero-crossing', most_probable_time/10.
    print 'this includes filter delays ... for average filter setting 8 this turns out to be 78.8 most of the time'
    
    #histplotter(times.flatten(), 'times.flatten()')
    #plotter(cfd_out[0], 'cfd_out')
    #plotter (pulse, pulse_str)
    


