#!/usr/bin/python
#
# Dominik Neise, Werner Lustermann
# TU Dortmund, ETH Zurich
#
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

class FirFilter(object):
    """ finite impulse response filter 
    
    """
    
    def __init__(self, b, a, name = 'general FIR filter'):
        """ initialize filter coefficients
        
        """
        self.a = a
        self.b = b
        self.name = name
        
    def __call__(self, data):
        length = max(len(self.a),len(self.b))-1
        if length > 0:
            if ( data.ndim == 1):
                initial = np.ones(length)
                initial *= data[0]
            elif ( data.ndim == 2):
                initial = np.ones( (data.shape[0], length) )
                for i in range(data.shape[0]):
                    initial[i,:] *= data[i,0]
            else:
                print 'HELP.'
                pass
            
            filtered, zf = signal.lfilter(self.b, self.a, data, zi=initial)
        else:
            filtered= signal.lfilter(self.b, self.a, data)            
        filtered = filtered.reshape(data.shape)
        return filtered

    def __str__(self):
        s = self.name + '\n'
        s += 'initial condition for filter: signal@rest = 1st sample\n'
        s += 'filter, coefficients:\n'
        s += 'nominator ' + str(self.b) + '\n'
        s += 'denominator ' + str(self.a)
        return s
        
class SlidingAverage(FirFilter):
    """ data smoothing in the time domain with a sliding average
    
    """
    
    def __init__(self, length=8):
        """ initialize the object
        length:  lenght of the averaging window
    
        """
        b = np.ones(length)
        a = np.zeros(length)
        if length > 0:
            a[0] = len(b)
        FirFilter.__init__(self, b, a, 'sliding average')
            

class CFD(FirFilter):
    """ Constant Fraction Discriminator """
    def __init__(self, length = 10., ratio = 0.75):
        
        b = np.zeros(length)
        a = np.zeros(length)
        if length > 0:
            b[0] = -1. * ratio
            b[length-1] = 1.
            a[0] = 1.
        FirFilter.__init__(self, b, a, 'constant fraction discriminator')


class RemoveSignal(FirFilter):
    """ estimator to identify DRS4 spikes
    
    """
    
    def __init__(self):
        """ initialize the object """
        
        b = np.array((-0.5, 1., -0.5)) 
        a = np.zeros(len(b))
        a[0] = 1.0
        FirFilter.__init__(self, b, a, 'remove signal')       


def _test_SlidingAverage():
    """ test the sliding average function
    use a step function as input

    """
    npoints = 100
    safilter = SlidingAverage(8)
    signal = np.zeros(npoints)
    signal[10:50] += 20.

    # add noise to the signal
    sigma = 1.5
    signal += np.random.randn(npoints) * sigma
    
    print safilter
    #print 'signal in:  ', signal
    #print 'signal out: ', rsfilter(signal)
    x=range(npoints)
    plt.plot(x, signal, 'b', label='original')
    plt.plot(x, safilter(signal), 'r', label='filtered')
    plt.title(safilter.name)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()


def _test_CFD():
    """ test the remove signal function
    
    """
    
    filt = CFD(8, 0.6)

    npoints = 100
    signal = np.ones(npoints) * 10.
    signal[20:30] += np.linspace(0., 100., 10)
    signal[30:90] += np.linspace(100., 0., 60)
    # add noise to the signal
    sigma = 1.5
    signal += np.random.randn(npoints) * sigma
    
    print filt
    #print 'signal in:  ', signal
    #print 'signal out: ', rsfilter(signal)
    x=range(npoints)
    plt.plot(x, signal, 'b.', label='original')
    plt.plot(x, filt(signal), 'r.', label='filtered')
    plt.title(filt.name)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()


def _test_RemoveSignal():
    """ test the remove signal function
    
    """
    
    rsfilter = RemoveSignal()

    npoints = 100
    signal = np.ones(npoints) * 10.
    signal[10:20] += np.linspace(0., 100., 10)
    signal[20:80] += np.linspace(100., 0., 60)
    # add noise to the signal
    sigma = 3.
    signal += np.random.randn(npoints) * sigma
    
    print rsfilter
    #print 'signal in:  ', signal
    #print 'signal out: ', rsfilter(signal)
    x=range(npoints)
    plt.plot(x, signal, 'b.', label='original')
    plt.plot(x, rsfilter(signal), 'r.', label='filtered')
    plt.title(rsfilter.name)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()

def _test(filter_type, sig, noise_sigma = 1.):
    
    filt = filter_type
    samples = len(sig)
    # add noise to the signal
    sig += np.random.randn(samples) * noise_sigma
    
    print filt
    x=range(samples)
    plt.plot(x, sig, 'b.', label='original')
    plt.plot(x, filt(sig), 'r.', label='filtered')
    plt.title(filt.name)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()
    

if __name__ == '__main__':
    """ test the class """
    
    _test_SlidingAverage()
    _test_RemoveSignal()
    _test_CFD()
    tsig = np.ones(100)
    _test(filter_type=SlidingAverage(8), sig=tsig, noise_sigma=3.) 
