#!/usr/bin/python
#
# Werner Lustermann, Dominik Neise
# ETH Zurich, TU Dortmund
#
# plotter.py

import numpy as np
import matplotlib.pyplot as plt

class Plotter(object):
    """ simple x-y plot """
    def __init__(self, name, x=None, style = '.:', xlabel='x', ylabel='y', ion=True, grid=True, fname=None):
        """ initialize the object """
                
        self.name  = name
        self.x = x
        self.style = style
        self.xlabel = xlabel
        self.ylabel = ylabel
        
        #not sure if this should go here
        if ion:
            plt.ion()

        self.figure = plt.figure()
        self.fig_id = self.figure.number
        
        plt.grid(grid)
        self.fname = fname
           
    def __call__(self, ydata, label=None):
        """ set ydata of plot """
        style = self.style
        
        # make acitve and clear
        plt.figure(self.fig_id)
        plt.cla()
        
        # the following if else stuff is horrible,
        # but I want all those possibilities, .... still working on it.
        
        # check if 1Dim oder 2Dim
        ydata = np.array(ydata)
        if ydata.ndim ==1:
            if self.x==None:
                plt.plot(ydata, self.style, label=label)
            else:
                plt.plot(self.x, ydata, self.style, label=label)
        else:
            for i in range(len(ydata)):
                if self.x==None:
                    if label:
                        plt.plot(ydata[i], style, label=label[i])
                    else:
                        plt.plot(ydata[i], style)
                else:
                    if label:
                        plt.plot(self.x, ydata[i], style, label=label[i])
                    else:
                        plt.plot(self.x, ydata[i], style)
        plt.title(self.name)
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        if label:
            plt.legend()
        
        if self.fname != None:
            plt.savefig(self.fname)
        
        plt.draw()
            
        
class CamPlotter(object):
    """ plotting data color-coded into FACT-camera  """
    def __init__(self, name, ion=True, grid=True, fname=None, map_file_path = '../map_dn.txt', vmin=None, vmax=None):
        """ initialize the object """
        self.name  = name
        
        if ion:
            plt.ion()

        chid, y,x,xe,ye,yh,xh,softid,hardid = np.loadtxt(map_file_path ,unpack=True)
        self.xe = xe
        self.ye = ye

        self.H = (6,0,30./180.*3.1415926)
        
        self.figure = plt.figure(figsize=(6, 6), dpi=80)
        self.fig_id = self.figure.number
        
        self.grid = grid
        self.fname = fname
        self.vmin = vmin
        self.vmax = vmax
        
    def __call__(self, data):
        xe = self.xe
        ye = self.ye
        H = self.H
        name = self.name
        grid = self.grid
        vmin = self.vmin
        vmax = self.vmax

        plt.figure(self.fig_id)
        plt.clf()
        self.ax = self.figure.add_subplot(111, aspect='equal')
        self.ax.axis([-22,22,-22,22])
        self.ax.set_title(name)
        self.ax.grid(grid)
        
        result = self.ax.scatter(xe,ye,s=25,alpha=1, c=data, marker=H, linewidths=0., vmin=vmin, vmax=vmax)
        self.figure.colorbar( result, shrink=0.8, pad=-0.04 )

        plt.draw()

class HistPlotter(object):
    
    def __init__(self, name, bins, range, grid=True, ion=True):
        """ initialize the object """
        self.bins = bins
        self.range = range
        self.name  = name
        self.figure = plt.figure()
        self.fig_id = self.figure.number
        self.grid = grid
        
        if ion:
            plt.ion()
        
    def __call__(self, ydata, label=None, log=False):
        plt.figure(self.fig_id)
        plt.cla()

        bins = self.bins
        range = self.range
        grid = self.grid
        
        ydata = np.array(ydata)
        
        if ydata.ndim > 1:
            ydata = ydata.flatten()
        
        plt.hist(ydata, bins, range, label=label, log=log)
        if label:
            plt.legend()
        plt.title(self.name)
        
        plt.draw()
            


def _test_Plotter():
    """ test of maintaining two independant plotter instances 
        with different examples for init and call
    """
    x = np.linspace(0., 2*np.pi , 100)
    plot1 = Plotter('plot1', x, 'r.:')
    plot2 = Plotter('plot2')
    
    y1 = np.sin(x) * 7
    plot1(y1)
    
    number_of_graphs_in_plot2 = 3
    no = number_of_graphs_in_plot2  # short form
    
    # this is where you do your analysis...
    y2 = np.empty( (no, len(x)) )   # prepare some space
    y2_labels = []                  # prepare labels
    for k in range(no):
        y2[k] = np.sin( (k+1)*x )
        y2_labels.append('sin(%d*x)' % (k+1) )
        
    # plot the result of your analysis
    plot2(y2, y2_labels)
    raw_input('next')       # do not forget this line, or your graph is lost
    
    plot1(np.cos(x) * 3.)
    plot2.name += ' without labels!!!' # changing titles 'on the fly' is possible
    plot2(y2)
    raw_input('next')       # DO NOT forget 


def _test_CamPlotter():
    """ test of CamPlotter """
    
    c1 = np.random.randn(1440)
    c2 = np.linspace(0., 1., num=1440)
    plot1 = CamPlotter('plot1')
    plot2 = CamPlotter('plot2')
    
    plot1(c1)
    plot2(c2)
    raw_input('next')
    
def _test_HistPlotter():
    """ test of the HistPlotter """
    plt.ion()

    data = np.random.randn(1000)
    hp = HistPlotter('test hist plotter',34, (-5,4))
    
    hp(data, 'test-label')
    raw_input('next')

if __name__ == '__main__':
    """ test the class """
    _test_Plotter()
    _test_CamPlotter()
    _test_HistPlotter()