import numpy as np
from samIO import simulations as sim, perimeters as prm
from netCDF4 import Dataset
from scipy.stats import mode

def weighted_avg_and_std(values, weights):
    """
    Compute the weighted mean and weighted (population) standard deviation.

    The weights are normalized by ``np.average`` so that they sum to 1;
    they therefore must not sum to 0. Passing ``weights=None`` gives every
    value equal weight.

    values, weights -- NumPy ndarrays with the same shape.

    Returns a ``(mean, std)`` tuple.
    """
    mean = np.average(values, weights=weights)
    # Two-pass form: subtract the mean first, then average the squared
    # deviations. This avoids the cancellation error of E[x^2] - E[x]^2.
    squared_dev = (values - mean) ** 2
    std = np.sqrt(np.average(squared_dev, weights=weights))
    return mean, std

def _level_weights_for_edges(mse_z, level_weight):
    """Expand per-level weights into one weight per perimeter edge value.

    mse_z        -- ragged sequence of 1D arrays, one entry per vertical level.
    level_weight -- 1D array with one weight per vertical level.

    Returns a 1D array aligned element-for-element with
    ``np.concatenate(mse_z)``: every value found on level k gets
    ``level_weight[k]``. Levels without edges contribute nothing.
    """
    counts = np.array([len(perim_mse) for perim_mse in mse_z], dtype=np.intp)
    # Only levels that actually have edges index into level_weight, so an
    # IndexError is raised only if a level beyond level_weight has edges.
    levels = np.flatnonzero(counts)
    # np.repeat walks the levels in order, matching np.concatenate(mse_z).
    return np.repeat(level_weight[levels], counts[levels])


def _weighted_stats_or_nan(values, weights):
    """Return weighted (mean, std) of values, or (nan, nan) if values is empty.

    An empty sample (e.g. a time step with no edge values) would otherwise
    make the weighted path raise and the unweighted path emit warnings;
    both now yield NaN explicitly.
    """
    if np.size(values) == 0:
        return np.nan, np.nan
    return weighted_avg_and_std(values, weights)


def calcPerimStats(*, namelists=None, masks=None, times=None, calc_weights=True,
                   file_name_append='3sims_1mask_weighted_fullrad'):
    """Compute mean/std of cloud-perimeter edge values per simulation and mask.

    For every namelist a simulation instance is built with ``sim.EGP``. For
    every mask the matching ``OUT_EDGE`` netCDF file is opened, and the
    ragged ``edge_mse`` variable is reduced to a (mean, std) pair:

    * per time index              -> ``perim_stats[s, q, t, :]``
    * pooled over all time indices -> ``perim_stats_all[s, q, :]``

    Index 0 of the last axis is the mean; index 1 is the standard deviation.
    When ``calc_weights`` is true, each edge value is weighted by the ``RHO``
    value of its vertical level. Time steps with no edge values are stored
    as NaN.

    Both arrays are saved to
    ``{sam.dWrite}RCE_stats_perim_{file_name_append}.npz``, where ``sam`` is
    the instance for the LAST namelist.

    Keyword arguments (all optional; the defaults reproduce the original run):
        namelists -- namelist JSON file names, resolved under ``namelists/``.
        masks -- QN thresholds for the cloud mask; any string mask selects
            the file with the 'QS' suffix.
        times -- output times. Only ``len(times)`` is used, as the number of
            leading time indices read from each file.
        calc_weights -- weight by per-level RHO (True) or weight equally.
        file_name_append -- suffix of the output file name.

    Raises:
        ValueError -- if ``namelists`` or ``times`` is empty.
    """
    ###############################
    ######### PARAMETERS ##########
    ###############################
    if namelists is None:
        namelists = ['nlRCE295FR.json', 'nlRCE300FR.json', 'nlRCE305FR.json']
    if masks is None:
        masks = [0.01]  # What QN threshold should be used for the cloud mask
    if times is None:
        times = np.arange(2880, 172801, 2880)  # Match output times of 3D files
    ###############################

    nt = len(times)
    nsim = len(namelists)
    nmask = len(masks)
    if nsim == 0:
        # The output directory comes from a simulation instance, so one is needed.
        raise ValueError('calcPerimStats: at least one namelist is required')
    if nt == 0:
        raise ValueError('calcPerimStats: at least one time is required')

    # Statistics collection; last axis: 0 = mean, 1 = standard deviation.
    perim_stats = np.zeros((nsim, nmask, nt, 2), dtype=np.float32)
    perim_stats_all = np.zeros((nsim, nmask, 2), dtype=np.float32)

    for s, nl in enumerate(namelists):
        print(f'{nl}')
        # Initiate simulation instance (provides paths and nz).
        sam = sim.EGP(f'namelists/{nl}')

        for q, mask in enumerate(masks):
            m_str = 'QS' if isinstance(mask, str) else f'{mask:.3f}'
            print(f'Mask: {m_str}')
            nc_path = f'{sam.dSAM}OUT_EDGE/{sam.fname}_EDGE_M{m_str}.nc'

            mse_all = []  # edge values for every time index
            weights_all = [] if calc_weights else None

            # `with` guarantees the netCDF file is closed even if a read or a
            # statistic below raises.
            with Dataset(nc_path, 'r') as edges:
                if calc_weights:
                    # One weight per level from RHO. The dz factor is not
                    # applied -- TODO confirm whether rho*dz was intended.
                    # Masked RHO entries become NaN so bad data stays visible.
                    level_weight = np.asarray(
                        np.ma.filled(edges['RHO'][:sam.nz], np.nan),
                        dtype=np.float32,
                    )

                for t in range(nt):
                    print(f'{t+1} of {nt}', end='\r')

                    mse_z = edges['edge_mse'][t, :]  # ragged: one array per level
                    mse = np.concatenate(mse_z)  # 1D

                    if calc_weights:
                        weights = _level_weights_for_edges(mse_z, level_weight)
                    else:
                        weights = None

                    perim_stats[s, q, t] = _weighted_stats_or_nan(mse, weights)

                    # Kept for the statistics pooled over all time indices.
                    mse_all.append(mse)
                    if calc_weights:
                        weights_all.append(weights)

            # Pool every time index into one sample.
            mse_all = np.concatenate(mse_all)
            if calc_weights:
                weights_all = np.concatenate(weights_all)
            perim_stats_all[s, q] = _weighted_stats_or_nan(mse_all, weights_all)

    np.savez(f'{sam.dWrite}RCE_stats_perim_{file_name_append}.npz',
             perim_stats=perim_stats, perim_stats_all=perim_stats_all)

if __name__ == '__main__':
    # Run the perimeter-statistics calculation only when this file is
    # executed as a script, not when it is imported.
    calcPerimStats()