#!/usr/bin/python -itt

import struct
import sys
import numpy as np
from pprint import pprint
import rlcompleter
import readline
readline.parse_and_bind('tab: complete')
from ROOT import *
import readcorsika
import matplotlib.pyplot as plt

from Turnable import *

import simulation_helpers as sh

class Mirror( Turnable ):
    """ Mirror description/abstraction class
        
        from Turnable:
            turn( axis, angle) - turn the mirror around *axis* by *angle*
        
        *pos* - the position of the center of the mirror plane
        
        *turnables*:
            *pos* 
            *dir*

    """
    def __init__(self, index, pos, normal_vector, focal_length, hex_size ):
        self.pos = pos
        self.dir = normal_vector
        
        # list for Turnable.turn()
        self._turnables = [ "pos", "dir"]
        
        self.index = index
        self.focal_length = focal_length
        self.hex_size = hex_size
    
    # standard __repr__
    def __repr__( self ):
        return "%s(%r)" % (self.__class__, self.__dict__)

class Photon( Turnable ):
    """ Photon description/abstraction class
        
        from Turnable:
            turn( axis, angle) - turn the mirror around *axis* by *angle*
        
        *pos* - the position of the photon in 3D
        *dir* - the direction of flight in 3D
        *time*- time since 2st interaction [ns]
        *wavelength* - wavelength in [nm]
        *mother* - particle ID of mother particle
        
        *turnables*:
            *pos* 
            *dir*
        
    """
    def __init__(self, photon_definition_array ):
        """ Construct Photon form 10-element 1D-np.array
        
        the photon constructor understands the 10-element 1D-np.array
        which is stored inside a run.event.photons 2D-np.array

        the *photon_definition_array* pda contains:
            pda[0] - encoded info
            pda[1:3] - x,y position in cm
            pda[3:5] - u,v cosines to x,y axis  --> so called direction cosines
            pda[5] - time since first interaction [ns]
            pda[6] - height of production in cm
            pda[7] - j ??
            pda[8] - imov ??
            pda[9] - wavelength [nm]
        """
        pda = photon_definition_array
        
        self.pos = np.array([pda[1],pda[2],0.])
        self.dir = np.array([pda[3],pda[4], np.sqrt(1-pda[3]**2-pda[4]**2) ])
        self._turnables = ("pos", "dir")
        
        self.wavelength = pda[9]
        self.time = pda[5]
    
    # standard __repr__
    def __repr__( self ):
        return "%s(%r)" % (self.__class__, self.__dict__)


class Focal_Plane( Turnable ):
    """
    """
    def __init__(self, pos, dir, size ):
        self.pos = pos
        self.dir = dir
        self.size = size
        self._turnables = ["pos", "dir"]
        
class Point( Turnable ):
    """
    """
    def __init__(self, pos):
        self.pos = pos
        self._turnables = ["pos"]
    


def read_reflector_definition_file( filename ):
    """
    """
    mirrors = []
    
    f = open( filename )
    for index, line in enumerate(f):
        if line[0] == '#':
            continue
        line = line.split()
        if len(line) < 8:
            continue
        #print line
        
        # first 3 colums in the file are x,y,z coordinates of the center
        # of this mirror in cm, I guess
        pos = np.array(map(float, line[0:3]))
        
        # the next 3 elements are the elements of the normal vector
        # should be normalized already, so the unit is of no importance.
        normal_vector = np.array(map(float,line[3:6]))
        
        # focal length of this mirror in mm
        focal_length = float(line[6])
        
        # size of the hexagonal shaped facette mirror, measured as the radius
        # of the hexagons *inner* circle.
        hex_size = float(line[8])
        mirror = Mirror( index, pos, normal_vector, focal_length, hex_size )
        
        mirrors.append(mirror)
        
    return mirrors




def reflect_photon( photon, mirrors):
    """ finds out: 
            which mirror is hit by photon 
            and where 
            and in which angle relative to mirror 
    """
    
    # the line defined by the photon is used to find the intersection point 
    # with the plane of each facette mirror. Then I check,
    # if the intersection point lies within the limits of the facette mirrors 
    # hexagonal boundaries.
    # If this is the case I have found the mirror, which is hit, and 
    # can calculate:
    # the distance of the intersection point from the center of the facette mirror
    # and the angle relative to the mirror (normal or plane not sure yet)
    
    for mirror in mirrors:
        #facette mirror plane, defined as n . x = d1 . n
        n = mirror.dir
        d1 = mirror.pos
        
        # line of photon defined as r = lambda * v + d2
        v = photon.dir
        d2 = photon.pos
        
        # the intersection coordinates are found by solving
        # n . (lambda * v + d2) - d1 . n == 0, for lambda=lambda_0
        # and then the intersection is: i = lambda_0 * v + d2
        #
        # putting int in another form:
        # solve:
        # lambda * n.v + n.d2 - n.d1 == 0
        # or
        # lambda_0 = n.(d1-d2) / n.v
        
        # FIXME: if one of the two dot-products is very small,
        # we shuold take special care maybe
        # if n.(d1-d2) is very small, this means that the starting point of 
        #   the photon is already nearly in the plane, so lambda_0 is expected to
        #   be very small ... erm .. maybe this is actually not a special case
        #   but very good.
        # of n.v is very small, this means the patch of the photon is nearly
        #   parallel to the plane, so the error ob lambda_0 might be very large.
        #   in addition, this might just tell us, that the mirror is hit under 
        #   strange circumstances ... so its not a good candidate and we can already go on.
        lambda_0 = (np.dot(n,(d1-d2)) / np.dot(n,v))
        
        #intersection between line and plane
        i = lambda_0 * v + d2
        
        # I want the distance beween i and d1 so I can already from the distance find 
        # out if this is our candidate.
        distance = np.sqrt(((i-d1)*(i-d1)).sum())
        
        #print "photon pos:", d2, "\t dir:", v/length(v)
        #print "mirror pos:", d1, "\t dir:", n/length(n)
        
        #print "lambda_0", lambda_0
        #print "intersection :", i
        #print "distance:",distance
        
        if distance <= mirror.hex_size/2.:
            break
        else: 
            mirror = None
    
    if  not mirror is None:
        photon.mirror_index = mirror.index
        photon.mirror_intersection = i
        photon.mirror_center_distance = distance
        #print "mirror found:", mirror.index , 
        #print "distance", distance
        # now I have to find out, if the photon is not only in the 
        # right distance but actually has hit the mirror.
        # this I do like this
        # i-d1    is a vector in the mirror plane pointing from d1 to the intersection point i.
        # if I know turn the entire mirror plane so it lies withing the x-y-plane
        # by applying a simple turning-matrix, then each vector inside the plane with turn into
        # a nice x,x vector. 
        # now I assume, that the hexagon is "pointing" lets say to into y direction
        # so I can e.g. say:
        # x has to be between -30.3 and +30.3 and y has to be
        # between 35 - m * |x| and -35 + m * |x| ... pretty simple.
        # maybe one can leave the turning aside, but I like that I can imagine it very nicely
        # 
        #
        # I don't do this yet .. since I don't know by heart how a turning matrix looks :-)
        # so I just simulate round mirrors
        ######################################################################
        
        
        # next step, since I know the intersection point, is the new direction.
        # So I need the normal of the mirror in the intersection point.
        # since the normal of every mirror is alway pointing to the camera center
        # this is not difficult.
        
        normal_at_intersection = (mirror_alignmen_point.pos - i) / sh.length(mirror_alignmen_point.pos - i)
        #print "normal_at_intersection",normal_at_intersection
        
        angle = np.arccos(np.dot( v, normal_at_intersection) / (sh.length(v) * sh.length(normal_at_intersection)))
        photon.angle_to_mirror_normal = angle
        #print "angle:", angle/np.pi*180., "deg"
        
        
        # okay, now I have the intersection *i*, 
        # the old direction of the photon *v*
        # and the normalvector at the intersection.
        ######################################################################
        ######################################################################
        # I do this now differently.
        # I will mirror the "point" at the tip of *v* at the line created by
        # the normalvector at the intersection and the intersection.
        # this will gibe me a mirrored_point *mp* and the vector from *i* to *mp*
        # is the *new_direction* it should even be normalized.
        
        # 1. step: create plane through the "tip" of *v* and the normal_at_intersection.
        # 2. step: find crossingpoint on line through *i* and the normal_at_intersection,
        # 3. step: vector from "tip" of *v* to crossingpoint times 2 points to 
        #           the "tip" of *mirrored_v*
        
        # plane: n_plane_3 . r = p_plane_3 . n_plane_3
        # p_plane_3 = i+v
        # n_plane_3 = normal_at_intersection
        
        # line: r = lambda_3 * v_line_3 + p_line_3
        # p_line_3 = i
        # v_line_3 = normal_at_intersection
        
        # create crossing: n_plane_3 . (lambda_3 * v_line_3 + p_line_3) = p_plane_3 . n_plane_3
        #   <=> lambda_3 = (p_plane_3 - p_line_3 ).n_plane_3  / n_plane_3 . v_line_3
        #   <=> lambda_3 = (i+v - i).normal_at_intersection  / normal_at_intersection . normal_at_intersection
        #   <=> lambda_3 = v.normal_at_intersection
        
        lambda_3 = np.dot(v, normal_at_intersection)
        #print "lambda_3", lambda_3
        crossing_point_3 = lambda_3 * normal_at_intersection + i
        #print "crossing_point_3", crossing_point_3
        
        from_tip_of_v_to_crossing_point_3 = crossing_point_3 - (i+v)
        
        tip_of_mirrored_v = i+v+ 2*from_tip_of_v_to_crossing_point_3
        
        new_direction = tip_of_mirrored_v - i
        
        #print "new_direction",new_direction
        #print "old direction", v
        photon.new_direction = new_direction
        ######################################################################
        ######################################################################
        """
        # both directions form a plane, and when I turn the old *v* by
        # twice the angle between *v* and *normal_at_intersection* 
        # inside this plane then I get the new direction of the photon.
        
        # so lets first get the normal of the reflection plane
        normal_of_reflection_plane =np.cross( v, normal_at_intersection)
        
        print length(normal_of_reflection_plane), "should be one"
        print length(v), "should be one"
        print length(normal_at_intersection), "should be one"
        print np.dot(v, normal_at_intersection), "should *NOT* be zero"
        print np.dot(v, normal_of_reflection_plane), "should be zero"
        print np.dot(normal_at_intersection, normal_of_reflection_plane), "should be zero"
        
        angle = np.arccos(np.dot( v, normal_at_intersection) / (length(v) * length(normal_at_intersection)))
        photon.angle_to_mirror_normal = angle
        print "angle:", angle/np.pi*180., "deg"
        
        # the rotation matrix for the rotation of *v* around normal_of_reflection_plane is
        R = make_rotation_matrix( normal_of_reflection_plane, 2*angle )
        
        print "R"
        pprint(R)
        
        new_direction = np.dot( R, v)
        photon.new_direction = new_direction
        
        print "old direction", v, length(v)
        print "new direction", new_direction, length(new_direction)
        print "mirror center", mirror.pos
        print "interception point", i
        print "center of focal plane", focal_plane.pos
        """

        # new the photon has a new direction *new_direction* and is starting
        # from the intersection point *i*
        # now I want to find out where there focal plane is hit.
        # So I have to repeat the stuff from up there
        
        #print "np.dot(focal_plane.dir,new_direction))", np.dot(focal_plane.dir,new_direction)
        
        lambda_1 = (np.dot(focal_plane.dir ,(focal_plane.pos - i)) / np.dot(focal_plane.dir,new_direction))
        
        #print "lambda_1", lambda_1
        focal_plane_spot = lambda_1 * new_direction + i
        #print "focal_plane_spot",focal_plane_spot
        photon.hit = True
        focal_plane_pos = focal_plane_spot - focal_plane.pos
        photon.focal_plane_pos =focal_plane_pos
        #photon.hit = True
        if sh.length(focal_plane_pos) <= focal_plane.size:
            photon.hit = True
        else:
            photon.hit = False
            return photon
        # now as a final step we have to find the coordinates of the vector
        # from the center of the focal plane to the spot where the photon 
        # actually hit the focal plane, as if the plane was not turned.
        # so if we turn the plane back into the x-y-plane
        # our *focal_plane_pos* vector has only two coordinates x,y, which are non-zero.
        # so lets do that.
        # in order to do so, we need the angles, by which the telescope
        # was turned .. we hae made them global variables !!ugly i know!!
        
        R = make_rotation_matrix( np.array([0,1,0]), -1.*telescope_theta/ 180. *np.pi )
        turned_focal_plane_pos = np.dot( R, focal_plane_pos)
        R = make_rotation_matrix( np.array([-1,0,0]), -1.*telescope_phi/ 180. *np.pi )
        turned_focal_plane_pos = np.dot( R, turned_focal_plane_pos)
        photon.turned_focal_plane_pos = turned_focal_plane_pos
        #if np.abs(turned_focal_plane_pos[2] ) > 1e-12:
        #    print turned_focal_plane_pos[2]
        #    raise Exception("the z-coordinate should be zero but is larger than 1e-12")
        
        
        
        #print "distance from focal plane center=",  length(focal_plane_spot-focal_plane.pos)
    else:
        photon.hit = False
    return photon
    

if __name__ == '__main__':
    
    # these three things define my telescope today:
    #   * a set off mirrors, read from a file
    #   * a focal plane, which has a postion, a direction and a size
    #   * and a mirror alignmen point, which is needed to construct the mirror 
    #       normal vectors.
    #
    mirrors = read_reflector_definition_file( "030/fact-reflector.txt" )
    focal_plane = Focal_Plane(  pos=np.array([0.,0.,978.132/2.]),   # center of focal_plane
                                dir=np.array([0., 0., 1.]),         # direction of view
                                size=20 )                           # radius in cm 

    mirror_alignmen_point = Point(  pos=np.array([0.,0.,978.132]) )


    # Now we read the corsika file, which will give us a few thousand photons
    # to work with. But in order to work with these photons,
    # we need to turn our telescope into the right direction. 
    # In order to find the right direction, we simply use the mean
    # direction of the photons in the corsika file.
    #
    # In addition we change a little bit in the output format of
    # the cosika files ... we move all the photons so, they hit a 5m circle
    # I can't explain that all here, please ask me or wait for the docu :-(
    # 
    print "working on corsika file: ", sys.argv[1]
    corsika = readcorsika.read_corsika_file(sys.argv[1])
    
    # so first we want to loop over all events in corsika
    # and move the photons of each event. 
    # in addition we want to find the mean direction of *all* photons in corsika.
    uv_event_means = []
    for event in corsika.events:
        # jump over empty events... I wonder why they exist...
        if event.info_dict['num_photons'] == 0:
            continue
        core_loc = np.array(event.info_dict['core_loc'])
        # subtract the core location of this event from the x and y coordinates
        event.photons[:,1:3] -= core_loc
        uv_mean = event.photons[:,3:5].mean(axis=0)
        uv_event_means.append(uv_mean)
    uv_event_means = np.array(uv_event_means)
    
    u,v = uv_event_means.mean(axis=0).tolist()
    print "u,v mean =", u,v
    theta, phi = sh.uv_to_theta_phi(u,v)
    theta = theta
    phi = phi /2.
    print "theta, phi mean =", theta, phi
    
    # turn the telescope 
    # ALARM ... the axis has a minus here .. it works ... but I don't know why.
    # I turn the telescope here. Around the negative x-axis
    print mirrors[0].dir
    turning_axis = np.array([-1,0,0])
    for mirror in mirrors:
        mirror.turn( turning_axis, phi)
    focal_plane.turn( turning_axis, phi)
    mirror_alignmen_point.turn( turning_axis, phi)
    # ... and around the y-axis..
    turning_axis = np.array([0,1,0])
    for mirror in mirrors:
        mirror.turn( turning_axis, theta)
    focal_plane.turn( turning_axis, theta)
    mirror_alignmen_point.turn( turning_axis, theta)
    print mirrors[0].dir
    #
    # the axes here were found out by trial and error... I still have to find out
    # which axis is which, in which program and so on. This is pretty confusing still 
    # for me, but on the other hand ... left right up down ... where is the difference :-)
    
    global telescope_phi
    global telescope_theta
    telescope_phi = phi
    telescope_theta = theta
    
    
    for event_counter, event in enumerate(corsika.events):
        print event_counter
        event.photons_who_hit = []
        if event.photons is None:
            continue
        for photon in event.photons:
            photon = Photon(photon)
            photon = reflect_photon( photon, mirrors )
            if photon.hit:
                event.photons_who_hit.append(photon)
        if event_counter > 100:
            break

    g = TGraph2D()
#    g2 = TGraph2D()
    h = TH2F("h","title",196,-22,22,196,-22,22)
    
    graph_point_counter = 0
    for ev in corsika.events:
        if not hasattr(ev, "photons_who_hit"):
            continue
        for ph in ev.photons_who_hit:
            tfpp = ph.turned_focal_plane_pos 
            h.Fill(tfpp[0], tfpp[1])

            fpp = ph.mirror_intersection
            #fpp = ph.focal_plane_pos
            g.SetPoint(graph_point_counter, fpp[0],fpp[1],fpp[2])
            graph_point_counter += 1
        
    c1 = TCanvas("c1","c1",0,0,500,500)
    g.SetMarkerStyle(20)
    g.Draw("pcol")
    c1.Update()
    
    c2 = TCanvas("c2","c2",0,500,500,500)
    h.Draw("colz")
    c2.Update()
