import numpy as np
from numpy.linalg import *
import raytrace
import math
import cPickle as pickle


def init_4v():
   '''Return a fresh zero-filled float32 vector of length 4.'''
   return np.zeros(4, dtype=np.float32)

def g_init():
   '''Return a fresh zero-filled 4x4 float32 matrix (metric storage).'''
   return np.zeros((4, 4), dtype=np.float32)

def e_init():
   '''Return an UNINITIALIZED 4x4x4x4 float64 array; callers must fill it.'''
   return np.empty((4, 4, 4, 4), dtype=float)

def init_mat(res):
   '''Return an UNINITIALIZED (res, res, 4) float64 array.'''
   return np.empty((res, res, 4), dtype=float)

def llambda(alpha, mu_o):
   '''Calculates lambda from impact parameters.

   lambda = -alpha * sqrt(1 - mu_o**2)
   '''
   sin_o = math.sqrt(1.0 - mu_o**2)
   return -alpha * sin_o

def qsquared(alpha, beta, a, mu_o):
   '''Calculates the Carter constant from impact parameters.

   Q = beta**2 + mu_o**2 * (alpha**2 - a**2)

   This is algebraically identical to the previous form
   beta**2 - a**2 mu_o**2 + lambda**2 mu_o**2 / (1 - mu_o**2) with
   lambda = llambda(alpha, mu_o), because lambda**2 = alpha**2 (1 - mu_o**2).
   Writing it directly avoids the 0/0 division at mu_o = +-1.
   '''
   cos2 = mu_o**2
   return beta**2 + cos2*(alpha**2 - a**2)

def kerr_Sigma(r, mu, a):
   '''Kerr metric function Sigma = r**2 + a**2 * mu**2 (mu = cos(theta)).'''
   return r**2 + a**2 * (mu**2)

def kerr_Delta(r, a):
   '''Kerr metric function Delta = r**2 + a**2 - 2 r (units of BH mass).'''
   r2_plus_a2 = r**2 + a**2
   return r2_plus_a2 - 2.0*r

def kerr_A(r, mu, a):
   '''Kerr metric function A = (r**2 + a**2)**2 - a**2 (1 - mu**2) Delta.

   mu is cos(theta). Delta is r**2 + a**2 - 2 r (the same quantity as
   kerr_Delta), computed here from the already-formed r**2 + a**2.
   '''
   x = r**2 + a**2
   delta = x - 2.0*r
   return x**2 - a*a*(1-mu*mu)*delta
   

def g_down(x, a):
   '''Covariant Kerr metric g_{mu nu} at position x for spin a.

   Reads r = x[1] and the polar angle x[2] (passed through cos). Returns a
   4x4 float32 matrix from g_init. The entries filled are the diagonal and
   the symmetric t-phi cross term; every other entry stays 0.
   NOTE(review): float32 storage limits precision near the horizon.
   '''
   r = x[1]
   mu = math.cos(x[2])
   sin2 = 1.0 - mu*mu
   Sigma = kerr_Sigma(r, mu, a)
   Delta = kerr_Delta(r, a)
   A = kerr_A(r, mu, a)

   g = g_init()
   g[0, 0] = -(1.0 - 2.0*r/Sigma)
   g[1, 1] = Sigma/Delta
   g[2, 2] = Sigma
   g[3, 3] = A*sin2/Sigma
   g[0, 3] = g[3, 0] = -2.0*a*r*sin2/Sigma
   return g

def g_up(x, a):
   '''Contravariant Kerr metric g^{mu nu} at position x for spin a.

   Uses the same inputs as g_down: r = x[1] and polar angle x[2]. Returns a
   4x4 float32 matrix holding the diagonal and the symmetric t-phi term.
   '''
   r = x[1]
   mu = math.cos(x[2])
   sin2 = 1.0 - mu**2
   Sigma = kerr_Sigma(r, mu, a)
   Delta = kerr_Delta(r, a)
   A = kerr_A(r, mu, a)
   sigma_delta = Sigma*Delta

   g = g_init()
   g[0, 0] = -A/sigma_delta
   g[1, 1] = Delta/Sigma
   g[2, 2] = 1.0/Sigma
   # Kept as two steps: the first term is stored (float32) before the
   # subtraction, exactly as before.
   g[3, 3] = Sigma/(A*sin2)
   g[3, 3] -= (4.0*r*r*a*a)/(A*Sigma*Delta)
   g[0, 3] = g[3, 0] = -2.0*r*a/sigma_delta
   return g
def Kepler4v(a, x):
   '''Four-velocity (u^t, 0, 0, u^phi) of a circular orbit for spin a.

   The orbit radius is the cylindrical radius rad = x[1]*sin(x[2]). With
   D = rad**2 - 3 rad + 2 a sqrt(rad):
      u^t   = (rad**2 + a sqrt(rad)) / (rad sqrt(D))
      u^phi = 1 / sqrt(rad D)
   The radial and polar components stay 0. Returns a float32 vector
   (init_4v).
   '''
   rad = x[1]*math.sin(x[2])
   sqrt_r = math.sqrt(rad)
   denom = rad**2 - 3.0*rad + 2.0*a*sqrt_r

   u = init_4v()
   u[0] = (rad*rad + a*sqrt_r)/(rad*math.sqrt(denom))
   u[3] = 1.0/math.sqrt(rad*denom)
   return u


def discNormal(a, x):
   '''Return the vector (0, 0, -1/r, 0) with r = x[1], as float32.

   Presumably the normal to the (equatorial) disc; only the theta component
   is set. The spin argument a is not used.
   '''
   n = init_4v()
   n[2] = -1.0/x[1]
   return n

def matdot(mat, vec):
   ''' Returns the dot product between a matrix and vector.

   Uses the first four rows of mat; the result is a float32 4-vector.
   '''
   rows = [np.dot(mat[i], vec) for i in range(4)]
   return np.array(rows, dtype=np.float32)


def lchelper(a, b, c, d):
   '''Permutation symbol of (a, b, c, d).

   Returns 0.0 if any two arguments are equal. Otherwise it returns +1.0
   for an even permutation (even number of inversions) and -1.0 for an odd
   one.
   '''
   idx = (a, b, c, d)
   inversions = 0
   for p in range(len(idx)):
      for q in range(p + 1, len(idx)):
         if idx[p] == idx[q]:
            return 0.0
         if idx[p] > idx[q]:
            inversions += 1
   return -1.0 if inversions % 2 else 1.0


def levi_civita(x, a):
   '''Levi-Civita tensor e[i][j][k][l] = sqrt(-det g) * lchelper(i, j, k, l).

   g is the covariant metric g_down(x, a). If det(g) > 0 an error message is
   printed and None is returned. Otherwise a 4x4x4x4 float64 array is
   returned.
   '''
   e = e_init()
   det_g = np.linalg.det(g_down(x, a))
   if (det_g > 0.0):
      print("Error: Det(g_mu_nu > 0, e uninitiated)\n")
      return
   # sqrt(-g) is the same for every component: compute it once, not 256 times.
   root_g = math.sqrt(-det_g)
   for i in range(4):
      for j in range(4):
         for k in range(4):
            for l in range(4):
               e[i][j][k][l] = root_g*lchelper(i, j, k, l)
   return e

def Calc_emcos(a, x_em, k_em, u_em, norm):
   ''' Calculates the emission angle cosine at position x_em.

   emcos = -(k . g . n) / (k . g . u), where g = g_down(x_em, a) and the
   matrix-vector products use matdot.
   '''
   metric = g_down(x_em, a)
   k_dot_n = np.dot(k_em, matdot(metric, norm))
   k_dot_u = np.dot(k_em, matdot(metric, u_em))
   return -k_dot_n/k_dot_u


def Calc_g(a, x_em, k_em, u_em):
   ''' Calculates the redshift factor g (a scalar).

   g = (k_obs . eta . u_obs) / (k_em . g_em . u_em), where:
     - eta = diag(-1, 1, 1, 1) is the flat observer metric,
     - k_obs = (1, 1, 0, 0) and u_obs = (1, 0, 0, 0),
     - g_em = g_down(x_em, a).
   All fixed vectors are float32, as before.
   '''
   eta = np.diag(np.array([-1.0, 1.0, 1.0, 1.0], np.float32))
   k_obs = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
   u_obs = np.array([1.0, 0.0, 0.0, 0.0], np.float32)

   numer = np.dot(k_obs, np.dot(eta, u_obs))
   denom = np.dot(k_em, np.dot(g_down(x_em, a), u_em))
   return numer/denom

def ab_boundary(ipmax, res, a, mu_o):
   '''Trace the image-plane grid and return the outer edge of the region
   where raytrace.raytrace returns 0.

   Uses the same (alpha, beta) grid as Calc_rtr. For every alpha column:
     - left  gets (alpha, most negative beta < 0 that returned 0),
     - right gets (alpha, largest beta >= 0 that returned 0).
   Despite the names, "left"/"right" split on the sign of beta. The list
   order follows dict iteration order.
   '''
   xtemp = raytrace.init_4v()
   ktemp = raytrace.init_4v()
   # alpha depends only on the column index, so build the column list once.
   alphas = [-ipmax*(1.0 + 1.0/float(res)) + 2.0*ipmax*(float(jj)/float(res))
             for jj in range(res)]
   bhl = {}   # alpha -> betas (< 0) whose ray returned 0
   bhr = {}   # alpha -> betas (>= 0) whose ray returned 0
   for ii in range(res):
      beta = -(-ipmax*(1.0 + 1.0/float(res)) + 2.0*ipmax*(float(ii)/float(res)))
      for alpha in alphas:
         ll = llambda(alpha, mu_o)
         q2 = qsquared(alpha, beta, a, mu_o)
         # raytrace returns 0 when the photon hits inside risco
         curr_hit = raytrace.raytrace(ll, q2, a, mu_o, alpha, beta, 0.0,
                                      xtemp, ktemp)
         if curr_hit == 0:
            bucket = bhl if beta < 0 else bhr
            bucket.setdefault(alpha, []).append(beta)
   left = [(alpha, min(betas)) for alpha, betas in bhl.items()]
   right = [(alpha, max(betas)) for alpha, betas in bhr.items()]
   return left, right

def Calc_rtr_highres(ipmax, res, a, mu_o):
   ''' Performs the raytrace and returns the hit matrix, and position and
       momentum vectors.

   A faster variant of Calc_rtr. Each image row is traced inward from both
   edges, and each sweep stops at the first ray for which raytrace.raytrace
   returns 0. Untraced pixels keep hit = 0 and zero x/k vectors.

   The left sweep covers columns 0 .. res//2 - 1. The right sweep covers
   res - 1 down to res//2, so every column is reachable for both even and
   odd res.

   Returns (hit, x, k):
     - hit has shape (res, res);
     - x and k have shape (res, res, 4) and hold raytrace.getx(xtemp, 0..3)
       and raytrace.getx(ktemp, 0..3) for each traced ray.
   '''
   hit = np.zeros(shape=(res,res), dtype= float)
   x = np.zeros(shape=(res,res,4), dtype= float)
   k = np.zeros(shape=(res,res,4), dtype= float)
   xtemp = raytrace.init_4v()
   ktemp = raytrace.init_4v()
   # Floor division: `res/2` is a float on Python 3 and range() rejects it.
   half = res // 2

   def sweep(ii, beta, columns):
      '''Trace row ii over `columns` in order; stop at the first 0 result.'''
      for jj in columns:
         alpha = -ipmax*(1.0 + 1.0/float(res)) + 2.0*ipmax*(float(jj)/float(res))
         ll = llambda(alpha, mu_o)
         q2 = qsquared(alpha, beta, a, mu_o)
         # hit matrix holds 0 if the photon hits inside risco
         status = raytrace.raytrace(ll, q2, a, mu_o, alpha, beta, 0.0,
                                    xtemp, ktemp)
         if status == 0:
            return
         hit[ii][jj] = status
         x[ii][jj] = [raytrace.getx(xtemp, n) for n in range(4)]
         k[ii][jj] = [raytrace.getx(ktemp, n) for n in range(4)]

   for ii in range(res):
      beta = -(-ipmax*(1.0 + 1.0/float(res)) + 2.0*ipmax*(float(ii)/float(res)))
      sweep(ii, beta, range(half))
      # Run down to and including `half`. The old bound (res-res/2-1) skipped
      # the middle column whenever res was odd.
      sweep(ii, beta, range(res - 1, half - 1, -1))

   return(hit, x, k)

  
def Calc_rtr(ipmax, res, a, mu_o):
   ''' Performs the raytrace and returns the hit matrix, and position and
       momentum vectors.

   Every pixel of the res x res image plane is traced. Returns (hit, x, k):
     - hit[row][col] is the value returned by raytrace.raytrace;
     - x[row][col] and k[row][col] hold raytrace.getx(xtemp, 0..3) and
       raytrace.getx(ktemp, 0..3) (the comment in the original notes x is
       [dt, r, theta, phi]).
   '''
   hit = np.empty((res, res), dtype=float)
   x = np.empty((res, res, 4), dtype=float)
   k = np.empty((res, res, 4), dtype=float)
   xtemp = raytrace.init_4v()
   ktemp = raytrace.init_4v()
   origin = -ipmax*(1.0 + 1.0/float(res))

   for row in range(res):
      beta = -(origin + 2.0*ipmax*(float(row)/float(res)))
      for col in range(res):
         alpha = origin + 2.0*ipmax*(float(col)/float(res))
         hit[row][col] = raytrace.raytrace(llambda(alpha, mu_o),
                                           qsquared(alpha, beta, a, mu_o),
                                           a, mu_o, alpha, beta, 0.0,
                                           xtemp, ktemp)
         x[row][col] = [raytrace.getx(xtemp, n) for n in range(4)]
         k[row][col] = [raytrace.getx(ktemp, n) for n in range(4)]

   return (hit, x, k)


def Calc_ImageArrays(res, a, mu_o, x, k, hit):
   '''Compute per-pixel redshift g and emission-angle cosine emcos.

   A pixel (i, j) is processed only if hit[i][j] == 1 and its final radius
   x[i][j][1] lies outside the ISCO. For such pixels the emitter gets the
   Kepler4v four-velocity and the discNormal vector, and g / emcos come from
   Calc_g / Calc_emcos. Every other pixel gets 0.0.

   x and k are indexed as x[i][j][0..3]. mu_o is accepted for interface
   compatibility and is not used here.

   Returns (g, emcos), both float arrays of shape (res, res).
   '''
   # ISCO radius (Bardeen, Press & Teukolsky 1972), units of BH mass:
   #   r = 3 + Z2 -/+ sqrt((3 - Z1)(3 + Z1 + 2 Z2))
   # The minus sign applies to prograde orbits. Kepler4v uses the signed
   # spin, so a < 0 describes a retrograde disc and needs the plus sign.
   # The previous code always used the minus sign.
   z1 = 1.0 + math.pow(1.0 - a**2, 1.0/3.0)*(math.pow(1.0+a, 1.0/3.0)
           + math.pow(1.0-a, 1.0/3.0))
   z2 = math.sqrt(3.0*a**2 + z1**2)
   root = math.sqrt((3.0-z1)*(3.0+z1 + 2.0*z2))
   if a >= 0.0:
      risco = 3.0 + z2 - root
   else:
      risco = 3.0 + z2 + root

   g = np.zeros(shape=(res,res), dtype=float)
   emcos = np.zeros(shape=(res,res), dtype=float)
   # Reused float32 buffers (init_4v); copying into them narrows x/k to
   # float32, as before.
   x_em = init_4v()
   p_em = init_4v()

   for i in range(res):
      for j in range(res):
         x_em[:] = x[i][j][:4]
         p_em[:] = k[i][j][:4]
         if x_em[1] > risco and hit[i][j] == 1:
            u_em = Kepler4v(a, x_em)
            norm = discNormal(a, x_em)
            emcos[i][j] = Calc_emcos(a, x_em, p_em, u_em, norm)
            g[i][j] = Calc_g(a, x_em, p_em, u_em)

   return g, emcos
   
	

def Output_rtr(filename, res, ipmax, a, mu_o, x, k, hit, g, emcos):
   '''Pickle the raytrace parameters and results to `filename`.

   Python's pickle module stores data in a "black box" fashion. It can only be
   retrieved in the same order that it was put in:
      res, ipmax, a, mu_o, x, k, hit, g, emcos
   that is, nine successive pickle.load calls on the file.

   The file is opened in a `with` block so it is closed even if a dump fails
   part-way through.
   '''
   with open(filename, "wb") as outfile:
      for item in (res, ipmax, a, mu_o, x, k, hit, g, emcos):
         pickle.dump(item, outfile)

def gen_rtr(out_filename, res, ipmax, a, mu_o):
   ''' Runs the raytrace, produces ImageArrays, and outputs information in the
       designated output file.

   The pipeline is Calc_rtr_highres -> Calc_ImageArrays -> Output_rtr, with a
   progress line printed after each stage. Returns None; the large arrays
   are released when the function returns, so no manual cleanup is needed.
   '''
   hit, x, k = Calc_rtr_highres(ipmax, res, a, mu_o)
   print("Calc_rtr done!\n")

   g, emcos = Calc_ImageArrays(res, a, mu_o, x, k, hit)
   print("Calc_ImageArrays done!\n")

   Output_rtr(out_filename, res, ipmax, a, mu_o, x, k, hit, g, emcos)
   print("Output_rtr done!\n")



########################## RUNNING THE CODE ####################################
#                                                                              #
# To run the code, change the parameters below and run this file from the     #
# terminal.                                                                    #
#                                                                              #
################################################################################


# resolution (number of pixels along each side of the image plane)
res = 200
# half-width of the image plane (maximum |alpha|, |beta|) in units of BH mass
ipmax = 30
# BH spin in units of BH mass
a = 0.001
# cosine of the observer's inclination angle (mu_o = cos(theta_o))
mu_o = .5
# the name of the generated file with the data of the raytrace
outfile = 'raytrace_a'+str(a)+'_mu'+str(mu_o)+'_ipm'+str(ipmax)+'.dat'

# Only run the (expensive) raytrace when this file is executed as a script.
# This way the functions above can be imported without side effects.
if __name__ == "__main__":
   # This will call all the relevant methods, run the raytrace, and produce
   # the output file
   gen_rtr(outfile, res, ipmax, a, mu_o)


