
from samIO import simulations as sim
import numpy as np

# This script creates a 2D histogram from SAM output, using variables of your choosing.
# Binning the mass flux by MSE and height provides the input needed to compute the convective streamfunction.
# It loops through each 3D snapshot and, at every height, bins either the raw variable or a quantity calculated from it.

#############################################################
####################### Script Parameters ###################
#############################################################

# Output directory for the histogram file. It is joined to the file name by
# plain string concatenation when saving, so it must end with '/'.
dwrite = '../STATISTICS/'
nl = 'namelists/nlRCE300HR.json' # Namelist describing the SAM simulation output to read
dim_labels = ['time', 'xbins', 'ybins'] # Labels for the (time, x-bin, y-bin) axes of the saved arrays

times = np.arange(2880,172801,2880) # Snapshot times to loop through (2880 ... 172800, every 2880)

varName = 'W' # Variable to read out of SAM
varCalc = True # If True, turn varName into a mass flux (interpolated, times RHO); otherwise bin it as-is
varCalcName = 'MF' # Name used for the calculated variable in the output file name
xVarName = 'MSE' # Binning variable X (the Y binning variable is always height)
xVarCalc = False # If True, apply xVarFunc to the X-bin variable before binning
xVarFunc = np.log10 # Transformation applied when xVarCalc is True
append = '_300K_FR' # Suffix appended to the output file name

# Creation of bins.
# np.arange excludes its stop value, so each stop is set slightly past the last
# edge that should be included (arguments are: start, just-past-last-edge, step).
# For MSE (keep exactly one h_bin_edges line active, matching the simulation):
dx = 0.5 # K
h_bin_edges = np.arange(315.95, 350.06, dx) # 300K: edges 315.95 ... 349.95
# h_bin_edges = np.arange(289.95, 340.06, dx) # 295K
# h_bin_edges = np.arange(320.95, 360.06, dx) # 305K
# For height:
dz = 500
# NOTE(review): the last edge produced here is 19550, so points above 19550 fall
# outside the histogram and are dropped — confirm this is intended.
z_bin_edges = np.arange(50, 20_001, dz)

#############################################################
#############################################################

# Derived quantities used by calcVarByHeight.
fname = f'HIST_{varCalcName}_{xVarName}{append}.npz'
nt = len(times)
nhbins = len(h_bin_edges)-1 # number of bins = number of edges - 1
nzbins = len(z_bin_edges)-1

def calcVarByHeight():
    """Bin a SAM 3D variable by ``xVarName`` and height, then save sums and counts.

    For every snapshot in ``times`` this loads ``varName`` and ``xVarName``,
    optionally transforms them (``xVarCalc`` / ``varCalc``), and accumulates two
    2D histograms over the bins ``(h_bin_edges, z_bin_edges)``:

    * ``sums`` -- the variable summed within each bin (histogram weighted by it)
    * ``cnts`` -- the number of grid points falling in each bin

    Both arrays have shape ``(nt, nhbins, nzbins)``. They are written with
    ``np.savez`` to ``dwrite + fname``, together with a ``dims`` dict (keyed by
    ``dim_labels``) holding the times and the bin edges.
    """
    print(f'Calculating {varCalcName} across {xVarName}')

    sam = sim.EGP(nl)
    # Read coordinates from the first requested snapshot instead of a
    # hard-coded time, so editing ``times`` can't point at a missing file.
    sam.get_coordinates(times[0])
    z = sam.coordinates[-1]
    # Give every grid point its height. Assumes load_3D returns arrays shaped
    # (nx, ny, nz), i.e. height on the last axis (consistent with the axis=2
    # interpolation below) — TODO confirm against samIO.
    z_flat = np.tile(z, (sam.nx, sam.ny, 1)).flatten()

    # The density profile does not depend on the snapshot: load it once.
    # NOTE(review): presumably row 0 of the stat output is the (z,) profile; an
    # earlier note mentioned removing a top level, which is not done here — confirm.
    rho = sam.load_stat('RHO')[0] if varCalc else None

    bins = [h_bin_edges, z_bin_edges]
    sums = np.zeros(shape=(nt, nhbins, nzbins))
    counts = np.zeros_like(sums, dtype=np.int64)

    for t, time in enumerate(times):
        print(f'Calculating time: {time}')
        print('Loading Data...')
        # Keep the variable of interest 3D until any calculation is done.
        var = sam.load_3D(varName, time)
        xvar = sam.load_3D(xVarName, time).flatten()

        # Apply functional shift to the x-binning variable (log10(0) -> -inf
        # is tolerated, those points simply fall outside the bins).
        if xVarCalc:
            with np.errstate(divide='ignore'):
                xvar = xVarFunc(xvar)

        if varCalc:
            ###### Put Calculations here ######
            # Both steps need the 3D field: axis=2 does not exist on a
            # flattened array, and the 1D rho profile only broadcasts
            # against the trailing (height) axis.
            var = sim.interpolate_3D_velocities(var, axis=2, periodic=False)
            var = rho * var
        var = var.flatten()

        sums[t], _, _ = np.histogram2d(xvar, z_flat, bins=bins, weights=var)
        counts[t], _, _ = np.histogram2d(xvar, z_flat, bins=bins)

    # With explicit edge arrays, histogram2d returns these same edges, so use
    # them directly (this also works when ``times`` is empty).
    dims = {dim_labels[0]: times, dim_labels[1]: h_bin_edges, dim_labels[2]: z_bin_edges}

    np.savez(f'{dwrite}{fname}', sums=sums, cnts=counts, dims=dims)


if __name__ == "__main__":
    # Only run the (expensive, file-writing) binning when executed as a script.
    calcVarByHeight()