#!/usr/bin/python
#
# Dominik Neise, Werner Lustermann
# TU Dortmund, ETH Zurich
#
import numpy as np
import matplotlib.pyplot as plt

class SignalGenerator(object):
    """ Signal Generator
        generates signals for testing several helper classes like:
            * fir filters
            * signal extractors
    """
    
    def __init__(self, option_str = 'len 100 noise 3', name = 'SignalGenerator'):
        """ initialize the generator
            sets default signal to generate
        """
        self.option_str = option_str.lower()
        self.options = make_options_from_str(option_str)
        self.parse_options()
        self.name = name
        
    def parse_options(self):
        o = self.options #shortcut
        if 'len' in o:
            self.npoints = int(o['len'][0])
        else: 
            self.npoints = 100
        if 'noise' in o:
            self.sigma = float(o['noise'][0])
        else:
            self.sigma = 1
        if 'bsl' in o:
            self.bsl = float(o['bsl'][0])
        else:
            self.bsl = -0.5
        
        if 'step' in o:
            self.step_height = float(o['step'][0])
            self.step_start = int(o['step'][1])
            self.step_stop = int(o['step'][2])

        if 'triangle' in o:
            self.triangle_height = float(o['triangle'][0])
            self.triangle_pos = float(o['triangle'][1])
            self.triangle_rise = int(o['triangle'][2])
            self.triangle_fall = int(o['triangle'][3])

        if 'spike' in o:
            self.spikes = []
            for i in range(len(o['spike'])/2):
                self.spikes.append( ( int(o['spike'][2*i]), float(o['spike'][2*i+1]) ) )

    def __call__(self, option_str = ''):
        self.option_str = option_str.lower()
        if self.option_str:
            self.options = make_options_from_str(self.option_str)
            self.parse_options()

        signal = np.zeros(self.npoints)
        signal += self.bsl
        signal += np.random.randn(self.npoints) * self.sigma
        if 'step' in self.options:
            signal[self.step_start:self.step_stop] += self.step_height
        if 'triangle' in self.options:
            start = self.triangle_pos - self.triangle_rise
            stop = self.triangle_pos + self.triangle_fall
            pos = self.triangle_pos
            height = self.triangle_height
            signal[start:pos] += np.linspace(0., height, self.triangle_rise)
            signal[pos:stop] += np.linspace(height, 0. , self.triangle_fall)
        for spike in self.spikes:
            signal[spike[0]] += spike[1]
        return signal 

    def __str__(self):
        s = self.name + '\n'
        s += 'possible options and parameters\n'
        s += ' * len:      number of samples (100)\n'
        s += ' * noise:    sigma (1)\n'
        s += ' * bsl:      level (-0.5)\n'
        s += ' * step:     height, start, end\n'
        s += ' * triangle: height, position, risingedge, fallingedge\n'
        s += ' * spike:    pos height [pos height ...]\n'
        
        s += 'current options are:\n'
        for key in self.options.keys():
            s += key + ':' + str(self.options[key]) + '\n'
        return s


# Helper function to parse signalname and create a dictionary
# dictionary layout :
# key : string
# value : [list of parameters]
def make_options_from_str(signalname):
    options = {}
    for word in (signalname.lower()).split():
        if word.isalpha():
            current_key = word
            options[current_key] = []
#        if word.isdigit():
        else:
            options[current_key].append(word)
#        else:
#            print '-nothing'
    return options
    
def plotter(signal, text):
    x=range(len(signal))
    ax = plt.plot(x, signal, 'b.', label='signal')
    plt.title('test of SignalGenerator with option string:\n' + text)
    plt.xlabel('sample')
    plt.legend()
    plt.grid(True)
    plt.show()
    
if __name__ == '__main__':
    """ test the class """
    myGenerator = SignalGenerator()
    sig = myGenerator('len 100 noise 0.3 bsl -2.5 step 20.3 12 24 triangle 10.2 50 10 30 spike 2 50. 20 50')
    print myGenerator
    plotter(sig, myGenerator.option_str)