import numpy as np
from numba import jit
from samIO import simulations as sim
from netCDF4 import Dataset

def calc3DVariables():
    """Append MSE, SSE and QVSAT to SAM's 3D netCDF output and save mean profiles.

    For each namelist and each output time this loads TABS and QV, computes
    moist static energy (MSE), saturation static energy (SSE) and the
    saturation mixing ratio (QVSAT), and writes them into the existing
    ``OUT_3D`` netCDF file for that time. Variables are created on first use
    and overwritten in place on re-runs.

    Horizontal means of MSE and SSE are collected into arrays of shape
    (simulation, time, z) and saved as ``mse_prof`` / ``sse_prof`` in
    ``mean_profiles.npz`` inside the last simulation's ``dWrite`` directory,
    to be used for plotting and testing against theory.
    """
    times = np.arange(2880, 172801, 2880)
    nl_directory = 'namelists/'
    namelists = ['nlRCE295FR.json', 'nlRCE300FR.json', 'nlRCE305FR.json']
    g_kg = 1000  # grams per kilogram

    # One (time, z) array per simulation, stacked to (sim, time, z) at the end.
    # These must live outside the namelist loop, otherwise only the last
    # simulation's profiles survive. np.stack fails loudly if nz differs.
    mse_profiles = []
    sse_profiles = []

    for nl in namelists:
        print('Namelist: ', nl)
        sam = sim.EGP(f'{nl_directory}{nl}')
        sam.get_coordinates(times[0])
        sam.Z = sam.coordinates[2]

        mse_profile = np.zeros((len(times), sam.nz))  # (time, z)
        sse_profile = np.zeros_like(mse_profile)

        for i, t in enumerate(times):
            # Bring in data
            print(f'Working on time: {i+1} of {len(times)}', end='\r')
            t0 = f'{t:0{sam.ntimedig}d}'
            ncPath = f'{sam.dSAM}OUT_3D/{sam.fname}_{t0}.nc'
            sam.TABS, sam.QV = sam.load_3D(['TABS', 'QV'], t)

            # MSE: QV is converted from g/kg to kg/kg before use.
            sam.MSE = calc_mse(sam.TABS, sam.Z, sam.QV/g_kg)
            # Transpose to the (z, y, x) layout of the netCDF variables.
            # Assumes load_3D returns (x, y, z) arrays (TODO confirm).
            sam.MSE = sam.MSE.transpose((2, 1, 0))
            mse_profile[i, :] = sam.MSE.mean(axis=(1, 2))

            # SSE must use QVSAT in kg/kg, exactly like QV above. The g/kg
            # copy is only for storage (netCDF units 'g kg^-1').
            # NOTE(review): calc_qvsat expects pressure in mb. This assumes
            # sam.pres is in mb and broadcasts against TABS; confirm.
            qvsat = calc_qvsat(sam.TABS, sam.pres)  # kg/kg
            sam.SSE = calc_mse(sam.TABS, sam.Z, qvsat).transpose((2, 1, 0))
            sam.QVSAT = (qvsat*g_kg).transpose((2, 1, 0))  # g/kg
            sse_profile[i, :] = sam.SSE.mean(axis=(1, 2))

            # Write all three fields to the existing netCDF in a single open.
            with Dataset(ncPath, mode='a') as ncf:
                _write_3d_variable(ncf, 'MSE', sam.MSE, 'K',
                                   'Moist Static Energy', 2)
                _write_3d_variable(ncf, 'SSE', sam.SSE, 'K',
                                   'Saturation Static Energy', 2)
                _write_3d_variable(ncf, 'QVSAT', sam.QVSAT, 'g kg^-1',
                                   'Saturation Mixing Ratio', 4)

        print()  # move past the '\r' progress line
        mse_profiles.append(mse_profile)
        sse_profiles.append(sse_profile)

    # Save the profile data, shaped (sim, time, z)
    np.savez(f'{sam.dWrite}mean_profiles.npz',
             mse_prof=np.stack(mse_profiles),
             sse_prof=np.stack(sse_profiles))


def _write_3d_variable(ncf, name, data, units, long_name,
                       least_significant_digit):
    """Write ``data`` into time index 0 of variable ``name`` in open dataset ``ncf``.

    If the variable does not exist yet, it is created as float32 with
    dimensions ('time', 'z', 'y', 'x'), zlib compression, quantisation to
    ``least_significant_digit``, and the given ``units`` and ``long_name``.
    If it already exists (e.g. on a re-run), it is overwritten in place.
    """
    if name in ncf.variables:
        var = ncf.variables[name]
    else:
        var = ncf.createVariable(name, np.float32, ('time', 'z', 'y', 'x'),
                                 zlib=True,
                                 least_significant_digit=least_significant_digit)
        var.units = units
        var.long_name = long_name
    var[0, :, :, :] = data


###########################################################################
# Functions replicating SAM's saturation vapor pressure, saturation mixing
# ratio and static energy calculations
###########################################################################

@jit(forceobj=True)
def esatw(T):
    """Saturation vapor pressure over liquid water, using SAM's fit.

    Evaluates SAM's 8th-order polynomial approximation in ``T - 273.16``
    with Horner's scheme. The temperature offset is floored at -80, so any
    temperature below 193.16 K is evaluated as if it were 193.16 K.

    Parameters
    ----------
    T : array_like
        Temperature in K.

    Returns
    -------
    numpy.ndarray
        Saturation vapor pressure in mb (hPa); about 6.11 at 273.16 K.
    """
    coeffs = np.array([6.11239921, 0.443987641, 0.142986287e-1,
                       0.264847430e-3, 0.302950461e-5, 0.206739458e-7,
                       0.640689451e-10, -0.952447341e-13, -0.976195544e-15])

    offset = np.maximum(-80., T - 273.16)

    # Horner's rule: start from the highest-order coefficient and fold the
    # remaining ones in from order 7 down to order 0.
    result = coeffs[-1] * np.ones_like(T)
    for c in coeffs[-2::-1]:
        result = result * offset + c

    return result

@jit(forceobj=True)
def esati(T):
    """Saturation vapor pressure over ice (mb), following SAM's piecewise fit.

    Branches:
    - T > 273.15 K: the liquid-water fit, ``esatw(T)``.
    - 185 K < T <= 273.15 K: 8th-order polynomial in ``T - 273.16``.
    - T <= 185 K: quadratic in ``max(-100, T - 273.16)``.

    The thresholds (273.15 for the split, 273.16 for the offset) are kept
    exactly as in the original port.

    Parameters
    ----------
    T : array_like
        Temperature in K. Scalars and sequences are converted with
        ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        Same shape as ``T``. float32/float64 input keeps its dtype. Integer
        input is promoted to a floating dtype rather than having the
        result truncated. Elements matching no branch (e.g. NaN) stay 0.
    """
    a0 = 6.11147274
    a1 = 0.503160820
    a2 = 0.188439774e-1
    a3 = 0.420895665e-3
    a4 = 0.615021634e-5
    a5 = 0.602588177e-7
    a6 = 0.385852041e-9
    a7 = 0.146898966e-11
    a8 = 0.252751365e-14

    T = np.asarray(T)
    # A plain zeros_like(T) would silently truncate results for integer T.
    esat = np.zeros_like(T, dtype=np.promote_types(T.dtype, np.float32))

    # Compute each mask once so the three branches partition T consistently.
    cold = T <= 185.
    ice = (T > 185.) & (T <= 273.15)
    warm = T > 273.15

    # Very cold: quadratic extrapolation with the offset floored at -100.
    dt = np.maximum(-100, T[cold] - 273.16)
    esat[cold] = 0.00763685 + dt*(0.000151069 + dt*7.48215e-07)

    # Ice range: 8th-order polynomial (Horner form).
    dt = T[ice] - 273.16
    esat[ice] = a0 + dt*(a1+dt*(a2+dt*(a3+dt*(a4+dt*(a5+dt*(a6+dt*(a7+a8*dt)))))))

    # Above 273.15 K: fall back to the liquid-water fit.
    esat[warm] = esatw(T[warm])

    return esat

@jit(forceobj=True)
def qsatw(T,p):
    """Saturation mixing ratio over liquid water (kg/kg).

    Parameters
    ----------
    T : array_like
        Temperature in K.
    p : array_like
        Pressure in mb (hPa), NOT Pa; it must match the units of ``esatw``.

    Returns
    -------
    numpy.ndarray
        ``0.622 * e / max(e, p - e)`` with ``e = esatw(T)``. The ``max``
        caps the result at 0.622 wherever ``e >= p - e``.
    """
    vapor_pressure = esatw(T)
    denominator = np.maximum(vapor_pressure, p - vapor_pressure)
    return 0.622 * vapor_pressure / denominator

@jit(forceobj=True)
def qsati(T,p):
    """Saturation mixing ratio over ice (kg/kg).

    Parameters
    ----------
    T : array_like
        Temperature in K.
    p : array_like
        Pressure in mb (hPa), NOT Pa; it must match the units of ``esati``.

    Returns
    -------
    numpy.ndarray
        ``0.622 * e / max(e, p - e)`` with ``e = esati(T)``. The ``max``
        caps the result at 0.622 wherever ``e >= p - e``.
    """
    vapor_pressure = esati(T)
    denominator = np.maximum(vapor_pressure, p - vapor_pressure)
    return 0.622 * vapor_pressure / denominator

@jit(forceobj=True)
def calc_qvsat(T,p):
    """Saturation mixing ratio (kg/kg) blended between ice and liquid water.

    The weight on the liquid-water value rises linearly from 0 at
    253.16 K (pure ice saturation at or below) to 1 at 273.16 K (pure
    liquid saturation at or above).

    Parameters
    ----------
    T : array_like
        Temperature in K.
    p : array_like
        Pressure in mb (hPa), as required by ``qsatw``/``qsati``.

    Returns
    -------
    numpy.ndarray
        ``w * qsatw(T, p) + (1 - w) * qsati(T, p)``.
    """
    t_all_ice = 253.16     # at or below: weight 0 (pure ice saturation)
    t_all_liquid = 273.16  # at or above: weight 1 (pure liquid saturation)

    ramp = (T - t_all_ice)/(t_all_liquid - t_all_ice)
    liquid_weight = np.maximum(0, np.minimum(1, ramp))

    over_water = liquid_weight*qsatw(T,p)
    over_ice = (1-liquid_weight)*qsati(T,p)
    return over_water + over_ice

@jit(nopython=True)
def calc_mse(T,Z,QV):
    """Static energy divided by c_p: ``T + g*Z/c_p + L_v*QV/c_p``, in K.

    With QV set to the saturation mixing ratio, this gives the saturation
    static energy instead.

    Parameters
    ----------
    T : array
        Temperature in K.
    Z : array
        Height in m. Must broadcast against ``T``.
    QV : array
        Water vapor mixing ratio in kg/kg (not g/kg).

    Returns
    -------
    array
        Static energy expressed as a temperature (K).
    """
    g = 9.81        # m s^-2
    c_p = 1004      # J kg^-1 K^-1
    L_v = 2.5104e6  # J kg^-1

    geopotential_term = g/c_p * Z
    latent_term = L_v/c_p * QV
    return T + geopotential_term + latent_term

###########################################################################
###########################################################################


if __name__ == "__main__":
    # Script entry point: post-process every configured simulation.
    calc3DVariables()
