import bisect
import colorsys
import cPickle as pickle
import math

import numpy as np

import gen_rtr as gen

def get_params(filename):
   '''Read (risco, rivr, rilr, omega) from a whitespace-separated text file.

   The first four lines are used, in that order; on each one the third
   token (index 2) is parsed as a float, e.g. a line "risco = 6.0".
   The file is closed even if parsing fails.
   '''
   with open(filename, 'r') as infile:
      lines = [line.split() for line in infile]
   risco, rivr, rilr, omega = [float(lines[k][2]) for k in range(4)]
   return risco, rivr, rilr, omega

def get_raytrace(filename):
   '''Load the ray-trace output stored as nine consecutive pickles.

   Returns, in file order: res, ipmax, a, mu_o, x, kvec, hit, g, emcos.
   The file is opened in binary mode (pickle streams are bytes) and is
   closed even if a load fails.
   '''
   # NOTE: pickle.load must only be used on trusted, locally produced files.
   with open(filename, 'rb') as infile:
      res, ipmax, a, mu_o, x, kvec, hit, g, emcos = \
         [pickle.load(infile) for _ in range(9)]
   return res, ipmax, a, mu_o, x, kvec, hit, g, emcos

def get_xi(xi_file, xip_file, x):
   '''Interpolate the displacement profiles onto every image pixel.

   xi_file: text file; column 0 is read as the radius grid r_z and
     column 1 as xi_z.
   xip_file: text file whose column 1 is read as xi_zp. Its rows are
     assumed to lie on the same radius grid as xi_file (its own column 0
     is ignored) -- NOTE(review): confirm both files always share a grid.
   x: per-pixel positions; x[i][j][1] is the radius looked up for pixel
     (i, j), and len(x) is used as the (square) image size.

   Returns (xiz_interp, xizp_interp, r_last, xiz_last, xizp_last): the two
   (res x res) arrays of values from interpolate(), then the last radius
   of the grid and the last xi_z / xi_zp values.
   '''
   r_z = []
   xi_z = []
   with open(xi_file, 'r') as xi_infile:
      for line in xi_infile:
         cols = line.split()
         r_z.append(float(cols[0]))
         xi_z.append(float(cols[1]))
   with open(xip_file, 'r') as xip_infile:
      xi_zp = [float(line.split()[1]) for line in xip_infile]

   res = len(x)
   xiz_interp = np.empty(shape=(res, res), dtype=float)
   xizp_interp = np.empty(shape=(res, res), dtype=float)
   for i in range(res):
      for j in range(res):
         r = x[i][j][1]
         xiz_interp[i][j] = interpolate(r, r_z, xi_z)
         xizp_interp[i][j] = interpolate(r, r_z, xi_zp)
   return xiz_interp, xizp_interp, r_z[-1], xi_z[-1], xi_zp[-1]

def params(a, mu, n, ipm, lowres = False):
   '''Load every input needed to render a perturbed-disk image.

   Reads the disk parameter file, the ray-trace pickle and the xi profiles
   for spin a, inclination mu, mode n. The ray-trace file is chosen as:
   the '...lowres.dat' file when lowres is true, else the plain file when
   ipm == 0, else the '...ipm<ipm>.dat' file (whose name is printed).

   Returns (risco, w, res, ipmax, x, kvec, hit, g, xiz_interp,
   xizp_interp, dout, xiz_dout, xizp_dout).
   '''
   suffix = '_h0.01_n'+str(n)+'_thetai1.57079.dat'
   # NOTE(review): these names look swapped relative to get_xi(xi_file,
   # xip_file, ...): get_xi takes xi_z from its first argument, which here
   # is the 'Xizprime' file. Left unchanged; confirm against the data files.
   xip_file = 'Xiz_cmode'+str(a)+suffix
   xi_file = 'Xizprime_cmode'+str(a)+suffix
   param_file = 'parameters_cmode'+str(a)+suffix

   base = 'outputs/perturb_a'+str(a)+'_mu'+str(mu)
   # elif: previously the ipm branch always overwrote the lowres choice.
   if lowres:
      read_in_file = base+'lowres.dat'
   elif ipm == 0:
      read_in_file = base+'.dat'
   else:
      read_in_file = base+'ipm'+str(ipm)+'.dat'
      print(read_in_file)

   risco, rivr, rilr, w = get_params(param_file)
   res, ipmax, a1, mu1, x, kvec, hit, g, emcos = get_raytrace(read_in_file)

   xiz_interp, xizp_interp, dout, xiz_dout, xizp_dout = \
      get_xi(xi_file, xip_file, x)

   return risco, w, res, ipmax, x, kvec, hit, g, xiz_interp, xizp_interp, \
          dout, xiz_dout, xizp_dout

def interpolate(point, x, y):
   '''Piecewise-linear interpolation of the table (x, y) at point.

   x must be sorted ascending. Returns y[i] on an exact grid hit (the
   first one if x has duplicates), the linear interpolant between the
   bracketing samples otherwise, and 0 when point lies outside
   [x[0], x[-1]] (or is NaN).

   Raises ValueError for an empty table.
   '''
   if len(x) == 0:
      raise ValueError('interpolate() needs a non-empty table')
   if not (x[0] <= point <= x[-1]):
      return 0
   # O(log n) lookup: this is called once per pixel per profile.
   i = bisect.bisect_left(x, point)
   if x[i] == point:
      return y[i]
   x0, x1 = x[i - 1], x[i]
   y0, y1 = y[i - 1], y[i]
   # float() avoids Python 2 integer division on integer tables.
   return (point - x0)*(float(y1 - y0)/(x1 - x0)) + y0

def nvec(xi_z, xi_zp, u, a, x, m, w, ksi_o, delt, tilt):
   '''Build the (unnormalized) surface-normal 4-vector at position x.

   x[0], x[1], x[3] are read as t, r, phi (presumably Boyer-Lindquist
   ordering (t, r, theta, phi) -- confirm against gen_rtr). The pattern
   phase is m*phi - w*t + delt. With tilt the components follow the tilt
   model; otherwise they are built from the local displacement xi_z and
   its radial derivative xi_zp, scaled by ksi_o.

   Returns n + (u . n) * (g_down(x, a) . u), with the dot products taken
   component-wise via numpy.
   '''
   t = x[0]
   r = x[1]
   phi = x[3]
   phase = m*phi - w*t + delt
   if tilt:
      comps = [ksi_o * w * r * math.sin(phase),
               ksi_o * math.cos(phase),
               -r,
               -ksi_o * m * r * math.cos(phase)]
   else:
      comps = [-w * ksi_o * xi_z * math.sin(phase),
               xi_zp * ksi_o * math.cos(phase),
               -r,
               m * xi_z * ksi_o * math.sin(phase)]
   # float64: the previous float32 cast dropped precision before the result
   # was combined with the double-precision metric terms below.
   n = np.array(comps, dtype=float)
   u_dot_n = np.dot(u, n)
   u_lower = gen.matdot(gen.g_down(x, a), u)
   return n + np.dot(u_lower, u_dot_n)


def norm_n(nvec, x, a):
   '''Scale nvec to unit length under the inverse metric gen.g_up(x, a).

   The squared norm is (g_up . nvec) . nvec; math.sqrt raises ValueError
   if it is negative.
   '''
   metric_up = gen.g_up(x, a)
   norm_sq = np.dot(gen.matdot(metric_up, nvec), nvec)
   return nvec / math.sqrt(norm_sq)
 
def Calc_mu_em(n, kvec, u_em):
   '''Return the emission cosine (kvec . n) / (kvec . u_em).'''
   return np.dot(kvec, n) / np.dot(kvec, u_em)

def f(mu_em):
   '''Linear limb-darkening weight 1 + 2.06*mu_em.'''
   darkening_coeff = 2.06
   return 1 + darkening_coeff*mu_em
 
def f2(mu_em):
   '''Limb-brightening weight log(1 + 1/mu_em).

   Returns 0 for mu_em <= 0. mu_em == 0 used to raise ZeroDivisionError
   (or give inf for numpy scalars, poisoning later sums).
   '''
   if mu_em <= 0:
      return 0
   # 1.0/ keeps the division in floating point under Python 2 as well.
   return math.log(1 + 1.0/mu_em)

def simple_I(ptype, a, x, kvec, g, hit, p, res, risco, ksi_o, rmax=25):
   '''Intensity image of the unperturbed (flat) disk.

   Pixel (i, j) is non-zero only when hit[i][j] == 1 and
   risco < r < rmax, with r = x[i][j][1]. Its value is
   ksi_o * g**3 * r**(-p) * f(mu_em) for ptype 1, or the same with f2
   for ptype 2. Any other ptype leaves the image at 0. mu_em uses the
   flat-disk normal [0, 0, -r, 0] normalized by norm_n.

   rmax generalizes the previously hard-coded outer radius of 25.
   The image starts zeroed and every column is visited. Previously the
   middle column of an odd-sized image, and all pixels for an unknown
   ptype, were left as uninitialized memory.
   '''
   I = np.zeros(shape=(res, res), dtype=float)
   for i in range(res):
      for j in range(res):
         r = x[i][j][1]
         if not (risco < r < rmax and hit[i][j] == 1):
            continue
         u_em = gen.Kepler4v(a, x[i][j])
         n = norm_n(np.array([0, 0, -r, 0]), x[i][j], a)
         mu_em = Calc_mu_em(n, kvec[i][j], u_em)
         if ptype == 1:
            I[i][j] = ksi_o*g[i][j]**3 * r**(-p) * f(mu_em)
         elif ptype == 2:
            I[i][j] = ksi_o*g[i][j]**3 * r**(-p) * f2(mu_em)
   return I

def Calc_I_highres(ptype, xi_z, xi_zp, a, x, kvec, g, hit, p, m, w, risco, dout,
                   xiz_dout, xizp_dout, ksi_o, delt, tilt, rout):
   '''Intensity image of the perturbed disk.

   Each row of the (res x res) image, res = len(x), is scanned from the
   left edge toward the centre and from the right edge toward the centre.
   Each scan stops at the first pixel with hit == 0, leaving the rest of
   that half-row at 0. For odd res the middle column is never visited.

   A visited pixel is non-zero only when hit == 1 and risco < r < rout,
   with r = x[i][j][1]. Its value is g**3 * r**(-p) * f(mu_em) for ptype 1
   or g**3 * r**(-p) * f2(mu_em) for ptype 2; other ptypes give 0. mu_em
   comes from the normal built by nvec() from xi_z[i][j], xi_zp[i][j].

   dout, xiz_dout and xizp_dout are kept for interface compatibility but
   do not influence the result (see NOTE).
   '''
   # NOTE(review): this function used to compute an exponential
   # continuation xiz_dout*exp((xizp_dout/xiz_dout)*(r - dout)) for
   # r > dout and then discard it. It never affected the image and could
   # only raise ZeroDivisionError when xiz_dout == 0, so it was removed.
   # If a decay beyond dout is wanted, it must be passed into nvec().
   res = len(x)
   I = np.zeros(shape=(res, res), dtype=float)
   half = res // 2

   def pixel(i, j):
      # Intensity of one pixel, or 0.0 outside the emitting annulus.
      r = x[i][j][1]
      if not (risco < r < rout and hit[i][j] == 1):
         return 0.0
      u_em = gen.Kepler4v(a, x[i][j])
      nvect = nvec(xi_z[i][j], xi_zp[i][j], u_em, a, x[i][j], m, w,
                   ksi_o, delt, tilt)
      n = norm_n(nvect, x[i][j], a)
      mu_em = Calc_mu_em(n, kvec[i][j], u_em)
      if ptype == 1:
         return g[i][j]**3 * r**(-p) * f(mu_em)
      if ptype == 2:
         return g[i][j]**3 * r**(-p) * f2(mu_em)
      return 0.0

   for i in range(res):
      for j in range(half):
         if hit[i][j] == 0:
            break
         I[i][j] = pixel(i, j)
      for k in range(res - 1, res - half - 1, -1):
         if hit[i][k] == 0:
            break
         I[i][k] = pixel(i, k)
   return I
   

def scale_g(mat):
   '''Divide mat by its range (max - min), so the scaled range is 1.

   Values are not shifted: the result spans [min/range, max/range].
   For a constant matrix (range 0) a float copy of mat is returned
   unchanged instead of producing inf/NaN.
   '''
   mat_range = mat.max() - mat.min()
   if mat_range == 0:
      return np.array(mat, dtype=float)
   return np.multiply(mat, (1.0 / mat_range))


def scale_I(mat):
   '''Map the non-zero entries of a 2-D intensity array onto (0.05, 1.0].

   Non-zero values become value * (0.95 / mat.max()) + 0.05. Exact zeros
   (pixels with no emission) stay 0.
   '''
   factor = .95/mat.max()
   values = np.asarray(mat)
   out = np.zeros(values.shape, dtype=float)
   lit = values != 0
   out[lit] = values[lit]*factor + 0.05
   return out


def bin_I(g, h, I, nbins):
   '''Histogram intensity I by redshift g over pixels with h == 1.

   Bin edges are np.arange(g.min(), g.max(), (g.max()-g.min())/nbins),
   the same grid the callers build as x_range, so the length is nbins or
   possibly nbins + 1 due to floating-point arange. Bin k sums I over
   g in [edge[k], edge[k+1]). The last bin runs from edge[-1] up to and
   including g.max(). Previously it was never accumulated, so the
   highest-g pixels were silently dropped.
   g must not be constant (the step would be zero).

   Returns a float32 array with one entry per edge.
   '''
   g = np.asarray(g)
   gmin = g.min()
   gmax = g.max()
   step = (gmax - gmin)/nbins
   g_range = np.arange(gmin, gmax, step)

   hit_mask = (np.asarray(h) == 1) & np.isfinite(g)
   # Largest k with g_range[k] <= value; values >= g_range[-1] land in the
   # last bin.
   idx = np.searchsorted(g_range, g[hit_mask], side='right') - 1
   # bincount accumulates in row-major pixel order, in double precision.
   sums = np.bincount(idx, weights=np.asarray(I, dtype=float)[hit_mask],
                      minlength=len(g_range))
   return np.array(sums[:len(g_range)], np.float32)

def gen_simple_I(ptype, a, mu, p, ksi_o, nbins):
   '''Compute and pickle the binned spectrum of the unperturbed disk.

   Inputs: 'outputs/perturb_a<a>_mu<mu>ipm30.dat' (ray trace) and the
   n0 parameter file for spin a.
   Output: 'perturb/data/spec_a<a>_mu<mu>_<ptype>.dat', containing in
   order the spectrum from bin_I, g.max(), and the bin grid x_range.
   '''
   read_in_file = 'outputs/perturb_a'+str(a)+'_mu'+str(mu)+'ipm30.dat'
   param_file = 'parameters_cmode'+str(a)+'_h0.01_n0_thetai1.57079.dat'
   res, _, _, _, x, kvec, hit, g, _ = get_raytrace(read_in_file)
   risco = get_params(param_file)[0]

   I = simple_I(ptype, a, x, kvec, g, hit, p, res, risco, ksi_o)
   spect = bin_I(g, hit, I, nbins)
   x_range = np.arange(g.min(), g.max(), (g.max()-g.min())/nbins)

   outfile = 'perturb/data/spec_a'+str(a)+'_mu'+str(mu)+'_'+str(ptype)+'.dat'
   with open(outfile, "wb") as out:
      pickle.dump(spect, out)
      pickle.dump(g.max(), out)
      pickle.dump(x_range, out)

def gen__I(ptype, a, mu, n, ksi_o, nbins, tilt, delt, ext):
   '''Compute one binned perturbed-disk spectrum and pickle it.

   Besides its arguments this reads the module-level globals ipm, p, m
   and rout (assigned by the script section of this file).
   Output: 'perturb/data/int<a>_mu<mu>_<ptype>_<ext>.dat', containing in
   order: spect, x_range, g.max(), ksi_o, nbins.
   '''
   risco, w, res, ipmax, x, kvec, hit, g, \
   xiz_interp, xizp_interp, dout, xiz_dout, xizp_dout = params(a, mu, n, ipm)

   # rout was previously not passed, so every call raised TypeError.
   # NOTE(review): confirm the module-level rout is the intended outer radius.
   I = Calc_I_highres(ptype, xiz_interp, xizp_interp, a, x, kvec, g, hit,
                      p, m, w, risco, dout, xiz_dout, xizp_dout, ksi_o,
                      delt, tilt, rout)
   spect = bin_I(g, hit, I, nbins)
   x_range = np.arange(g.min(), g.max(), (g.max()-g.min())/nbins)

   outfile = 'perturb/data/int'+str(a)+'_mu'+str(mu)+'_'+str(ptype)+'_'+str(ext)+'.dat'
   with open(outfile, "wb") as out:
      pickle.dump(spect, out)
      pickle.dump(x_range, out)
      pickle.dump(g.max(), out)
      pickle.dump(ksi_o, out)
      pickle.dump(nbins, out)

def bin_I_g_r(a, mu, n, ipm, ksi_o, nbins, ptype, gmax, lowres):
   '''Per-redshift-bin fraction of flat-disk intensity emitted inside rivr.

   Bins are np.arange(0, gmax, gmax/nbins). Bin k covers
   [edge[k], edge[k+1]), and the last bin covers [edge[-1], gmax).
   Previously the last bin was never filled. Only pixels with h == 1
   count. spect[k] = I_in / (I_in + I_out), where "in" means
   r = x[i][j][1] < rivr, or 0 when the bin is empty.

   Also reads the module-level globals p (power-law index for simple_I)
   and rout (only pickled). Output:
   'ref/intensity<a>_mu<mu>_n<n>_<ptype>.dat' containing, in order: spect,
   g_range, a, mu, n, ksi_o, nbins, rout, ptype, x, g, I, h.
   '''
   param_file = 'parameters_cmode'+str(a)+'_h0.01_n'+str(n)+'_thetai1.57079.dat'

   base = 'outputs/perturb_a'+str(a)+'_mu'+str(mu)
   if lowres:
      read_in_file = base+'smallres.dat'
   elif ipm == 0:
      read_in_file = base+'.dat'
   else:
      read_in_file = base+'ipm'+str(ipm)+'.dat'
   print(read_in_file)

   risco, rivr, rilr, w = get_params(param_file)
   res, ipmax, a1, mu1, x, kvec, h, g, emcos = get_raytrace(read_in_file)

   step = gmax/nbins  # the grid starts at 0
   g_range = np.arange(0, gmax, step)
   nb = len(g_range)
   spect = np.array(nb*[0], np.float32)

   I = simple_I(ptype, a, x, kvec, g, h, p, res, risco, ksi_o)

   # One pass over the image (was one full pass per bin); pixels are still
   # accumulated in row-major order, so each bin's sum is unchanged.
   temp_in = [0.0]*nb
   temp_out = [0.0]*nb
   for i in range(res):
      for j in range(res):
         gij = g[i][j]
         if h[i][j] != 1 or not (0 <= gij < gmax):
            continue
         k = bisect.bisect_right(g_range, gij) - 1
         if x[i][j][1] < rivr:
            temp_in[k] += I[i][j]
         else:
            temp_out[k] += I[i][j]
   for k in range(nb):
      if temp_in[k] == 0 and temp_out[k] == 0:
         spect[k] = 0
      else:
         spect[k] = temp_in[k]/(temp_in[k] + temp_out[k])

   outfile = 'ref/intensity'+str(a)+'_mu'+str(mu)+'_n'+str(n)+'_' +str(ptype)+'.dat'
   with open(outfile, "wb") as out:
      pickle.dump(spect, out)
      pickle.dump(g_range, out)
      pickle.dump(a, out)
      pickle.dump(mu, out)
      pickle.dump(n, out)
      pickle.dump(ksi_o, out)
      pickle.dump(nbins, out)
      pickle.dump(rout, out)
      pickle.dump(ptype, out)
      pickle.dump(x, out)
      pickle.dump(g, out)
      pickle.dump(I, out)
      pickle.dump(h, out)


   
def gen_g_stack(mu, n, nbins, ptype, spins=('0.001', '0.01', '0.1', '0.5', '0.9')):
   '''Stack the per-spin inner-fraction spectra written by bin_I_g_r.

   For each spin string in spins, the first pickle of
   'ref/intensity<spin>_mu<mu>_n<n>_<ptype>.dat' is read and appended
   twice. The bin grid (second pickle) is taken from the first spin's
   file only. The stack is printed and pickled, followed by that grid, to
   'ref/stack_mu<mu>_n<n>_<ptype>.dat'. nbins is unused.
   spins generalizes the previously hard-coded list (same default).
   '''
   suffix = '_mu'+str(mu)+'_n'+str(n)+'_' +str(ptype)+'.dat'
   final = []
   x = None
   for idx, spin in enumerate(spins):
      # Binary mode and a context manager: the five files were never closed.
      with open('ref/intensity'+spin+suffix, 'rb') as infile:
         spec = pickle.load(infile)
         if idx == 0:
            x = pickle.load(infile)
      final.append(spec)
      final.append(spec)

   print(final)

   with open('ref/stack_mu'+str(mu)+'_n'+str(n)+'_'+str(ptype)+'.dat', 'wb') as outfile:
      pickle.dump(final, outfile)
      pickle.dump(x, outfile)


   
def gen_image(ptype, a, mu, n, ipm, ksi_o, tilt, delt):
   '''Render the perturbed disk as an HLS-coloured RGB image and pickle it.

   Hue is -.6 + 1.3*scale_g(g) and lightness is scale_I(I), at saturation 1.
   Also reads the module-level globals p and m. Output goes to
   'perturb/data/im_tilt_a_...' when tilt, else
   'perturb/data/test3_im_a_...', containing in order: rgb, g, res, ipmax.
   '''
   risco, w, res, ipmax, x, kvec, hit, g, \
   xiz_interp, xizp_interp, dout, xiz_dout, xizp_dout = params(a, mu, n, ipm)

   # NOTE(review): ipm is passed as Calc_I_highres's rout (outer emission
   # radius). The script section uses rout == ipm for most runs, but
   # confirm this is intended.
   I = Calc_I_highres(ptype, xiz_interp, xizp_interp, a, x, kvec, g, hit,
                      p, m, w, risco, dout, xiz_dout, xizp_dout, ksi_o,
                      delt, tilt, ipm)
   g_scaled = scale_g(g)
   I_scaled = scale_I(I)

   rgb = np.empty(shape=(res, res, 3), dtype=float)
   for i in range(res):
      for j in range(res):
         # Alternative colour map tried earlier: -.65 + 2.3*g_scaled.
         rgb[i][j] = colorsys.hls_to_rgb(-.6+1.3*g_scaled[i][j],
                                         I_scaled[i][j], 1)
   if tilt:
      out = 'perturb/data/im_tilt_a_'+str(a)+'_mu'+str(mu)+'_n'+str(n)+'_h0.01_'+str(ptype)+'.dat'
   else:
      out = 'perturb/data/test3_im_a_'+str(a)+'_mu'+str(mu)+'_n'+str(n)+'_h0.01_'+str(ptype)+'.dat'
   with open(out, "wb") as outfile:
      pickle.dump(rgb, outfile)
      pickle.dump(g, outfile)
      pickle.dump(res, outfile)
      pickle.dump(ipmax, outfile)
     
def return_index(num, array):
   '''Return array[num] without modifying array.

   The previous version "copied" by aliasing (copy = array) and then
   popped, which removed the element from the caller's list.
   Raises IndexError when num is out of range.
   '''
   return array[num]


def gen_int(ptype, a, mu, ipm, p, m, ksi_o, nbins, num, n, tilt):
   '''Phase-resolved differential spectra of the perturbed disk.

   For each phase in np.arange(0, 2.2, 2.0/num) (passed as phase*2*pi),
   builds the intensity image and bins it with bin_I. The phase-averaged
   spectrum is subtracted, and the result is scaled so the largest
   |value| is 1 (left unscaled if everything is 0).

   Also reads the module-level globals rout (outer radius for
   Calc_I_highres) and gmax (pickled only).
   Writes 'perturb/data/tilt_a...' when tilt, otherwise
   'perturb/data/int_diff_a...' plus 'perturb/data/max_param_a...'. The
   second file holds: real_max, max xi_z, max xi_z', ksi_o, the raw
   (un-subtracted) spectra, the scaled differential spectra, and gmax.
   '''
   risco, w, res, ipmax, x, kvec, hit, g, \
   xiz_interp, xizp_interp, dout, xiz_dout, xizp_dout = params(a, mu, n, ipm)

   phase = np.arange(0, 2.2, 2.0/num)
   xiz_max = xiz_interp.max()
   xizp_max = xizp_interp.max()

   spectra = []
   ave = None
   for counter, ph in enumerate(phase):
      # rout was previously not passed, so every call raised TypeError.
      I = Calc_I_highres(ptype, xiz_interp, xizp_interp, a, x, kvec, g, hit,
                         p, m, w, risco, dout, xiz_dout, xizp_dout, ksi_o,
                         ph*2*math.pi, tilt, rout)
      spect = bin_I(g, hit, I, nbins)
      # Size the running sum from the spectrum: bin_I may return nbins + 1
      # bins, which broke a fixed nbins-long accumulator.
      if ave is None:
         ave = np.zeros(len(spect), np.float32)
      ave = np.add(ave, spect)
      spectra.append(spect)
      print(counter)

   intensity = np.array(spectra)
   # A real copy: previously "original" aliased intensity and so was saved
   # after the mean had already been subtracted.
   original = intensity.copy()
   intensity -= ave/len(phase)

   pos_max = intensity.max()
   neg_max = math.fabs(intensity.min())
   if neg_max > pos_max:
      real_max = neg_max
   else:
      real_max = pos_max
   scale = 1.0/real_max if real_max != 0 else 1.0
   intensity = np.multiply(intensity, scale)

   if tilt:
      out = 'perturb/data/tilt_a'+str(a)+'_mu'+str(mu)+'_'+str(ptype)+'.dat'
   else:
      out = 'perturb/data/int_diff_a'+str(a)+'_mu'+str(mu)+'_n'+str(n)+'_h0.01_'+str(ptype)+'.dat'

   with open(out, "wb") as outfile:
      pickle.dump(intensity, outfile)
      pickle.dump(g.max(), outfile)
      pickle.dump(ksi_o, outfile)

   # The tilt case used to create an empty 'temp.dat'; nothing was ever
   # written to it, so it is no longer opened.
   if not tilt:
      out2 = 'perturb/data/max_param_a'+str(a)+'_mu'+str(mu)+'_n'+str(n)+'_h0.01_'+str(ptype)+'.dat'
      with open(out2, "wb") as outfile2:
         pickle.dump(real_max, outfile2)
         pickle.dump(xiz_max, outfile2)
         pickle.dump(xizp_max, outfile2)
         pickle.dump(ksi_o, outfile2)
         pickle.dump(original, outfile2)
         pickle.dump(intensity, outfile2)
         # NOTE(review): this is the module-level gmax, not g.max().
         pickle.dump(gmax, outfile2)
  

# ---------------------------------------------------------------------------
# Script section: module-level run configuration.
#
# Several functions above read these names as globals rather than taking them
# as arguments: gen__I (ipm, p, m), gen_image (p, m), bin_I_g_r (p, rout) and
# gen_int (gmax). The assignments run top to bottom at import time, so a
# function called afterwards sees only the LAST value bound to each name
# (e.g. rout = 8, nbins = 30, ipm = 30, gmax = 1.4). The commented-out calls
# record earlier runs and their settings.
# ---------------------------------------------------------------------------
rout = 20

p = 3
m = 1
nbins = 50
num_phase = 20
tilt = False

a = 0.001
mu = 0.5
n = 2
nbins = 100
ptype = 1
ipm = 30
ksi_o = 0.1
gmax = 1.2

#gen_g_stack(mu, n, nbins, ptype)
# NOTE(review): with the calls below commented out, this loop only sets
# gmax = 1.4 on the first n == 0 pass and never resets it to 1.2, so gmax
# stays 1.4 from then on. It also leaves mu = 0.7 and n = 2 bound.
for mu in [0.1, 0.5, 0.7]:
   for n in [0, 1, 2]:
      if n == 0:
         gmax = 1.4
      #bin_I_g_r(0.001, mu, n, 30, 0.01, nbins, ptype, gmax, False)
      #bin_I_g_r(0.01, mu, n, 13, 0.007, nbins, ptype, gmax, False)
      #bin_I_g_r(0.1, mu, n, 0, 0.01, nbins, ptype, gmax,False)
      #bin_I_g_r(0.5, mu, n, 9, 0.0005, nbins, ptype, gmax,False)
      #bin_I_g_r(0.9, mu, n, 8, 0.0005, nbins, ptype, gmax,False)


ksi_o1 = 0.00005
ksi_o2 =0.000003


a = .9
mu = .1
n = 0 
ipm = 30



# Tilt runs (last argument True).
#gen_int(1, a, .1, ipm, p, m, ksi_o1, nbins, num_phase, 0, True)
#gen_int(1, a, .5, ipm, p, m, ksi_o1, nbins, num_phase, 0, True)
#gen_int(1, a, .7, ipm, p, m, ksi_o1, nbins, num_phase, 0, True)


a = .001
ksi_o1 = 0.01
ksi_o2 = .007

ipm = 20


mu = 0.1
#gen_int(1, a, mu, ipm, p, m, ksi_o1, nbins, num_phase, 0, False)
#gen_int(2, a, mu, ipm, p, m, ksi_o2, nbins, num_phase, 0, False)
#gen_int(1, a, mu, ipm, p, m, ksi_o1, nbins, num_phase, 1, False)
#gen_int(2, a, mu, ipm, p, m, ksi_o2, nbins, num_phase, 1, False)
#gen_int(1, a, mu, ipm, p, m, ksi_o1, nbins, num_phase, 2, False)
#gen_int(2, a, mu, ipm, p, m, ksi_o2, nbins, num_phase, 2, False)

# Each rout assignment precedes the run it appears to belong to (rout
# presumably being the outer radius for that spin -- confirm).
mu = 0.5
rout = 20
#gen_int(1, .001, mu, 20, p, m, .1, nbins, num_phase, 0, False)
rout = 13
#gen_int(1, .01, mu, 13, p, m, .07, nbins, num_phase, 0, False)
rout = 11
#gen_int(1, .1, mu, 0, p, m, .005, nbins, num_phase, 1, False)
rout = 9
#gen_int(1, .5, mu, 9, p, m, .0005, nbins, num_phase, 1, False)
rout = 8
#gen_int(1, .9, mu, 8, p, m, .0005, nbins, num_phase, 2, False)


a = 0.001
ipm = 30

mu = 0.7
n = 2
ksi_o1 = 0.01
ksi_o2 = 0.007
tilt  = False
nbins = 30


# Four perturbation phases (delt = 0, pi/2, pi, 3pi/2) tagged ext 0..3.
#gen__I(1, a, mu, n, ksi_o1, nbins, tilt, 0, 0)
#gen__I(1, a, mu, n, ksi_o1, nbins, tilt, math.pi/2.0, 1)
#gen__I(1, a, mu, n, ksi_o1, nbins, tilt, math.pi, 2)
#gen__I(1, a, mu, n, ksi_o1, nbins, tilt, 3.0*math.pi/2.0, 3)


