import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pylab import *
import cPickle as pickle
import colorsys

# Figure styling shared by every plot in this script.
# Embed TrueType (Type 42) fonts in PDF output instead of Type 3 fonts.
plt.rcParams['pdf.fonttype'] = 42
# Times serif text, rendered through LaTeX (text.usetex) at 8 pt.
plt.rcParams['font.family'] = "serif"
plt.rcParams['font.serif'] = "times"
# Use a plain hyphen for minus signs instead of the unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 8
# A real bool, not the string 'True'.
plt.rcParams['text.usetex'] = True
# Short major tick marks on both axes.
plt.rcParams['xtick.major.size'] = 2
plt.rcParams['ytick.major.size'] = 2

def scale_g(mat):
   '''Return mat multiplied by 1 / (max - min), so the result spans a range
   of exactly 1. No offset is subtracted: entries are only rescaled.'''
   inv_span = 1.0 / (mat.max() - mat.min())
   return inv_span * mat
   
def scale_I(mat):
   ''' Scales the matrix such that the maximum value is 1.

   Nonzero entries are mapped to mat*(0.95/mat.max()) + 0.05. Scaling the
   maximum to 0.95 and then adding 0.05 brightens the image while keeping
   the largest value at exactly 1, avoiding oversaturation. Entries that
   are exactly 0 stay 0.

   Returns a float array with the same shape as mat. If mat is all zeros
   (max == 0), an all-zero array is returned instead of dividing by zero.'''
   mat = np.asarray(mat, dtype=float)
   maximum = mat.max()
   if maximum == 0:
      return np.zeros(mat.shape, dtype=float)

   scale = .95/maximum
   # Zero pixels are the ones gen_I/simple_I mark as non-emitting (outside
   # the disk or not hit); keep them black rather than adding the offset.
   return np.where(mat == 0, 0.0, mat*scale + 0.05)

def gen_I(x, g, hit, p, m, w, k, delt, r_in=None, r_out=None):
   ''' Generates the intensity matrix using a simple wave perturbation.

   For every pixel with hit == 1 and r_in < r < r_out:

       I = g**3 * (15 + cos(m*phi - w*t - k*r + delt))**(15/4) * r**(-p)

   and I = 0 for every other pixel.

   x     -- per-pixel coordinates; x[i][j][0], [1] and [3] are read as
            t, r and phi respectively
   g     -- per-pixel redshift factors (same grid as hit)
   hit   -- per-pixel flag; only pixels equal to 1 can emit
   p     -- radial power-law index of the emissivity
   m, w, k, delt -- azimuthal number, angular frequency, radial wavenumber
            and phase offset of the perturbation
   r_in, r_out -- disk edges; default to the module-level risco and rout

   Returns a float array with the grid shape of x[..., 1].'''
   if r_in is None:
      r_in = risco
   if r_out is None:
      r_out = rout

   coords = np.asarray(x, dtype=float)
   t = coords[..., 0]
   r = coords[..., 1]
   phi = coords[..., 3]
   g = np.asarray(g, dtype=float)
   emitting = (r > r_in) & (r < r_out) & (np.asarray(hit) == 1)

   I = np.zeros(r.shape, dtype=float)
   wave = np.cos(m*phi[emitting] - w*t[emitting] - k*r[emitting] + delt)
   # 15.0/4.0 rather than 15/4: under Python 2 the latter is integer
   # division and evaluates to 3.
   I[emitting] = (g[emitting]**3 * (15.0 + wave)**(15.0/4.0)
                  * r[emitting]**(-p))
   return I
 

def simple_I(x, g, hit, p, r_in=None, r_out=None):
   ''' Generates the intensity matrix without any perturbations.

   For every pixel with hit == 1 and r_in < r < r_out, I = g**3 * r**(-p).
   Every other pixel gets I = 0.

   x     -- per-pixel coordinates; x[i][j][1] is read as the radius r
   g     -- per-pixel redshift factors (same grid as hit)
   hit   -- per-pixel flag; only pixels equal to 1 can emit
   p     -- radial power-law index of the emissivity
   r_in, r_out -- disk edges; default to the module-level risco and rout

   Returns a float array with the grid shape of x[..., 1].'''
   if r_in is None:
      r_in = risco
   if r_out is None:
      r_out = rout

   r = np.asarray(x, dtype=float)[..., 1]
   g = np.asarray(g, dtype=float)
   emitting = (r > r_in) & (r < r_out) & (np.asarray(hit) == 1)

   I = np.zeros(r.shape, dtype=float)
   # Only evaluate the power law where it is used, so r <= 0 outside the
   # disk cannot raise warnings.
   I[emitting] = g[emitting]**3 * r[emitting]**(-p)
   return I


def bin_I(g, h, I, nbins):
   ''' Bins intensity values by their corresponding redshift (g) values.

   The bin left edges are arange(g.min(), g.max(), step) with
   step = (g.max() - g.min())/nbins. These are the same x values plot_I
   uses, so the result always lines up with them. Each bin sums I over
   the pixels with h == 1 whose g lies in [edge, next_edge). The last bin
   is closed on the right, so pixels at g == g.max() are counted rather
   than dropped.

   Returns a 1-D float32 array with one entry per left edge.'''
   g = np.asarray(g, dtype=float)
   step = (g.max() - g.min())/nbins
   g_range = np.arange(g.min(), g.max(), step)

   hits = np.asarray(h) == 1
   g_hit = g[hits]
   I_hit = np.asarray(I, dtype=float)[hits]

   # Index of the largest left edge <= g, i.e.
   # g_range[idx] <= g < g_range[idx+1]. Values above the last edge land
   # in the final bin.
   idx = np.searchsorted(g_range, g_hit, side='right') - 1
   valid = idx >= 0
   spect = np.bincount(idx[valid], weights=I_hit[valid],
                       minlength=len(g_range))
   return spect.astype(np.float32)

def plot_I(g, h, I, nbins, outfile):
   '''Plots the emission spectrum of I as a function of redshift, g.

   Intensities are summed per redshift bin by bin_I, and only pixels with
   h == 1 contribute. The line plot is saved as a PDF to outfile.'''
   # Pass the caller's mask through rather than the module-level `hit`.
   spect = bin_I(g, h, I, nbins)
   # Bin left edges, computed exactly as bin_I computes them, so x and
   # spect have the same length.
   x = np.arange(g.min(), g.max(), (g.max()-g.min())/nbins)

   plt.plot(x, spect)
   plt.xlabel("Redshift")
   plt.ylabel("Intensity")
   plt.savefig(outfile, format='pdf')
   

def plot_image(x, g, I, p, outfile):
   '''Plots an image of the accretion disk, as well as any perturbations.

   Each pixel gets an HLS colour:
     - hue from the scaled redshift g,
     - lightness from the scaled intensity I,
     - saturation 1.
   colorsys converts each HLS triple to RGB. The image is drawn over
   [-ipmax, ipmax] on both axes and saved as a PDF to outfile.

   x and p are not used; they are kept so existing callers keep working.'''
   g_scaled = scale_g(g)
   I_scaled = scale_I(I)

   rows, cols = np.shape(g_scaled)[:2]
   rgb = np.ndarray(shape=(rows, cols, 3), dtype=float)
   for i in range(rows):
      for j in range(cols):
         # The constants in front of g_scaled choose which part of the hue
         # circle the redshift range is mapped onto.
         rgb[i][j] = colorsys.hls_to_rgb(-.6 + 1.3*g_scaled[i][j],
                                         I_scaled[i][j], 1)

   # Raw strings: "\h" is not a valid escape sequence in a normal string.
   plt.xlabel(r"$\hat{\alpha}$")
   plt.ylabel(r"$\hat{\beta}$")
   plot = plt.imshow(rgb, extent=[-ipmax, ipmax, -ipmax, ipmax])

   # NOTE(review): rgb is a precomputed RGB array, so this colorbar shows
   # the image's default colormap rather than the HLS hue scale above.
   # Confirm the labels correspond to the hues actually drawn.
   g_max = g.max()
   g_min = g.min()
   cbar = plt.colorbar(plot, ticks=[0, .5, 1])
   # The middle tick is the midpoint of the redshift range, (max + min)/2,
   # not half the range.
   cbar.ax.set_yticklabels([round(g_max, 2),
                            round((g_max + g_min)/2.0, 2),
                            round(g_min, 2)])
   cbar.set_label(r'redshift')

   plt.savefig(outfile, format='pdf')
   

def gen_movie(x, g, hit, nbins,p, m, w, k, num_pic, plot_type):
   ''' Generates num_pic frames that follow the evolution of the wave
   perturbation.

   gen_I is run with num_pic evenly spaced phases delt in [0, 2*pi). For
   each phase the result is plotted with:
     - plot_image ("image"), saved to plots/im<n>, or
     - plot_I ("intensity"), saved to plots/int<n>.
   The figure is cleared after each frame, and the running frame count
   is printed. Any other plot_type prints an error and returns without
   plotting.'''
   if plot_type == "image":
      prefix = "plots/im"
   elif plot_type == "intensity":
      prefix = "plots/int"
   else:
      print ("ERROR: invalid plot_type entered. OPTIONS: 'image' or 'intensity'")
      return

   # linspace with endpoint=False yields exactly num_pic phases in [0, 2*pi).
   # arange with a float step can round up to num_pic + 1 phases, which
   # adds a duplicate 2*pi frame.
   delt = np.linspace(0, 2*np.pi, num_pic, endpoint=False)
   for n, d in enumerate(delt):
      outfile = prefix + str(n)
      I = gen_I(x, g, hit, p, m, w, k, d)
      if plot_type == "image":
         plot_image(x, g, I, p, outfile)
      else:
         plot_I(g, hit, I, nbins, outfile)
      plt.clf()
      print(n + 1)


# Retrieving data from cPickle file. The objects are read back in the same
# order the producing script pickled them. Only load trusted files: pickle
# can execute arbitrary code while loading.

filename = "outputs/dexter1.txt"
# Pickle data must be read in binary mode ('rb'); text mode can corrupt it.
# The with-block closes the file even if a load fails.
with open(filename, 'rb') as infile:
   res = pickle.load(infile)
   ipmax = pickle.load(infile)
   a = pickle.load(infile)
   mu_o = pickle.load(infile)
   x = pickle.load(infile)
   kvec = pickle.load(infile)
   hit = pickle.load(infile)
   g = pickle.load(infile)
   emcos = pickle.load(infile)

# Radius of the innermost stable circular orbit (ISCO) for spin a, from the
# Bardeen-Press-Teukolsky formula (minus sign before the last square root).
# This is the inner disk edge, not the horizon radius: gen_I and simple_I
# only emit where r > risco.
# NOTE(review): math.pow with a 1/3 exponent raises ValueError if |a| > 1.
temp1 = 1.0 + math.pow(1.0 - a**2, 1.0/3.0)*(math.pow(1.0+a, 1.0/3.0)\
           + math.pow(1.0-a, 1.0/3.0))
temp2 = math.sqrt(3.0*a**2 + temp1**2)
risco = 3.0 + temp2 - math.sqrt((3.0-temp1)*(3.0+temp1 + 2.0*temp2))


# Outer disk edge: gen_I and simple_I only emit where r < rout.
rout = 29
# Radial emissivity index; intensity scales as r**(-p).
p = 13.0/8.0
# Number of redshift bins used by bin_I / plot_I.
nbins = 20
# Wave perturbation for gen_I: cos(m*phi - w*t - k*r + delt).
m = 2
w = .25
k = 1

#gen_movie(x, g, hit, nbins, p, m, w, k, 10, "intensity")

I = simple_I(x, g, hit, p)
#I = gen_I(x, g, hit, p, m, w, k, 0*math.pi/10)

#plot_I(g, hit, I, nbins, "plots/temp")
plot_image(x, g, I, p, "temp")

