#!/usr/bin/python
#
# Werner Lustermann
# ETH Zurich
#
import numpy as np

import fir_filter as fir

class DRSSpikes(object):
    """ remove spikes (single or double false readings) from DRS4 data
    Strategy:
    * filter the data, removing the signal, thus spike(s) are clearly visible
    * search single and double spikes
    * replace the spike by a value derived from the neighbors
    
    """
    
    def __init__(self, threshold, 
                 single_pattern=np.array( [0.5, 1.0, 0.5]) ,
                 double_pattern=np.array([1., 1., 1., 1.]), 
                 debug = False):
        """ initialize spike filter 
        template_single: template of a single slice spike
        template_double: template of a two slice spike

        """

        self.threshold = threshold
        self.single_pattern = list(single_pattern * threshold)
        self.double_pattern = list(double_pattern * threshold)
        
        self.remove_signal = fir.RemoveSignal()
        
    def __call__(self, data):

        self.row, self.col = data.shape
        indicator = self.remove_signal(data)
        a = indicator.flatten()
        singles = []
        doubles = []
       
        #: find single spikes
        p = self.single_pattern
        for i, x in enumerate(zip(-a[:-2], a[1:-1], -a[2:])):
            if ( (x[0]>p[0]) & (x[1] > p[1]) & (x[2] > p[2]) ):
                singles.append(i)

        #: find double spike
        p = self.double_pattern
        for i, x in enumerate(zip(-a[:-3], a[1:-2], a[2:-1], -a[3:])):
            if (x[0] > p[0]) & (x[1] > p[1]) & (x[2] > p[2]) & (x[3] > p[3]):
                doubles.append(i)

        print 'singles: ', singles
        print 'doubles: ', doubles

        data = self.remove_single_spikes(singles, data)
        data = self.remove_double_spikes(doubles, data)
        return data

    def remove_single_spikes(self, singles, data):
        data = data.flatten() 
        for spike in singles:
            data[spike] = (data[spike-1] + data[spike+1]) / 2.
        return data.reshape(self.row, self.col)
    
    def remove_double_spikes(self, doubles, data):
        data = data.flatten() 
        for spike in doubles:
            data[spike:spike+2] = (data[spike-1] + data[spike+2]) / 2.
        return data.reshape(self.row, self.col)


def _test():
  
    a = np.ones((3,12)) * 3.
    a[0,3] = 7.
    a[1,7] = 14.
    a[1,8] = 14.
    a[2,4] = 50.
    a[2,5] = 45.
    a[2,8] = 20.
    
    print a 

    SpikeRemover = DRSSpikes(3.)
    print 'single spike pattern ', SpikeRemover.single_pattern
    print 'double spike pattern ', SpikeRemover.double_pattern
    afilt = SpikeRemover(a)
    print afilt
    
if __name__ == '__main__':
    """ test the class """
    _test()
    
    
