#!/usr/bin/env python3
"""Plots Figure S4, showing the change in OH and H2O2 concentrations at 
different distances from the towers in the 3-Tower simulation. Adjusting 
`conc_file_nm` allows for plotting of the winter data, to produce Figure S5. """
#imports
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from geopy.distance import geodesic

# Directories holding the two GEOS-Chem runs being compared: the reference
# run and the run with the H2O2 towers added.
ref_gc_outpath = "data/Base/"
tower_gc_outpath = "data/3-Tower/"

# Name of the species-concentration file; the same file name is read from
# both run directories, so both runs must provide it at the same (high)
# time resolution.
conc_file_nm = "reduced_GEOSChem.SpeciesConc.20160601_0000z.nc4"

# Species-concentration variables of interest.
# NOTE(review): this list is not referenced anywhere else in the visible part
# of the script - confirm whether it is still needed.
plot_conc_specs = [
    "SpeciesConcVV_H2O2",
    "SpeciesConcVV_O3",
    "SpeciesConcVV_OH",
]

# Directory the figure is written into (the figure is saved as a PNG).
plot_outdir = "plots"

# Prefix prepended to the file name of the saved figure.
file_prefix = "Summer"

# (latitude, longitude) of each tower.
tower_latlons = dict(
    Alberta=(50.5, -110.5),
    Alabama=(31.4, -87.7),
    California=(35.3, -119.1),
)

# Whole-hour offset added to the UTC timestamps to obtain each tower's local
# time.
# NOTE(review): these look like fixed standard-time offsets with no
# daylight-saving adjustment - confirm that is what is wanted.
tower_utc_offsets = dict(
    Alberta=-7,
    Alabama=-6,
    California=-8,
)

# Radii (km) around each tower used for the 'local air quality' metrics.
close_thresh = [50, 100, 250]

# Line colours, indexed by the position of the radius being plotted.
cols = [
    "firebrick",
    "rebeccapurple",
    "seagreen",
    "goldenrod",
]
###############################################################################
def plot_diurnal(ts_df, ax, c, label):
    """Draw the median diurnal cycle of a time series with its interquartile band.

    The series is grouped by the hour of its index; the per-hour median is
    drawn as a solid line and the 25th-75th percentile range as a shaded band.

    Parameters
    ----------
    ts_df : pandas.Series
        Time series whose index exposes ``.hour`` (e.g. a ``DatetimeIndex``).
        The series may be named or unnamed.
    ax : matplotlib axes (or any object with ``fill_between``, ``plot`` and
        ``set_xlim``)
        Axes to draw on.
    c : colour
        Colour used for both the line and the shaded band.
    label : str
        Legend label attached to the median line.
    """
    # Group the series itself by hour of day. This avoids going through a
    # DataFrame and looking the data up under column ``0``, which only exists
    # when the series is unnamed.
    by_hour = ts_df.groupby(ts_df.index.hour)

    di_med = by_hour.median()
    di_25 = by_hour.quantile(0.25)
    di_75 = by_hour.quantile(0.75)

    ax.fill_between(di_med.index, di_75, di_25,
                    alpha = 0.25, color = c)
    ax.plot(di_med, "-", color = c, label = label)
    ax.set_xlim(0,24)

###############################################################################
#read in data
print("Reading GEOS-Chem outputs...")


def _open_trimmed_run(gc_outpath):
    """Open the concentration file of one run and cut it down for analysis.

    Keeps the first eight entries along ``lev`` (the bottom few vertical
    levels, per the original author's note) and trims the outer grid boxes,
    which are not correct because of the nested grid.
    """
    run_ds = xr.open_dataset(f"{gc_outpath}/{conc_file_nm}")
    run_ds = run_ds.isel(lev=slice(0, 8))
    # NOTE(review): .sel slices by label and includes both end labels, so this
    # appears to drop five boxes at the low end but only four at the high end
    # of each axis - confirm that the asymmetry is intended.
    lat_window = slice(run_ds.lat[5], run_ds.lat[-5])
    lon_window = slice(run_ds.lon[5], run_ds.lon[-5])
    return run_ds.sel(lat=lat_window, lon=lon_window)


#reference run and tower run, reduced in exactly the same way
ref_concs = _open_trimmed_run(ref_gc_outpath)
tower_concs = _open_trimmed_run(tower_gc_outpath)
print("...finished reading data!")
###############################################################################
#go through each lat-lon coordinate and assign it to each tower if it is within
#the threshold distance to be classed as 'close'
def get_close_latlons(target_latlon):
    """Collect the model grid points lying within each radius of a location.

    Every (lat, lon) combination of the ``ref_concs`` grid is compared with
    ``target_latlon`` using the geodesic distance in km.

    Parameters
    ----------
    target_latlon : tuple
        (latitude, longitude) of the point of interest.

    Returns
    -------
    dict
        Maps each radius in ``close_thresh`` to the list of (lat, lon) tuples
        whose distance from ``target_latlon`` is less than or equal to that
        radius. A grid point may appear under several radii.
    """
    points_within = {radius_km : [] for radius_km in close_thresh}
    for lat_cell in ref_concs.lat:
        lat_val = lat_cell.values
        for lon_cell in ref_concs.lon:
            grid_point = (lat_val, lon_cell.values)
            separation_km = geodesic(target_latlon, grid_point).km
            for radius_km in close_thresh:
                if separation_km > radius_km:
                    continue
                points_within[radius_km].append(grid_point)
    return points_within
        
#grid points within each radius of each tower: {tower: {radius_km: [(lat, lon), ...]}}
close_latlons = {twr : get_close_latlons(latlon) for twr, latlon in tower_latlons.items()}


def _coord_arrays(component):
    """Pull one element (0 = lat, 1 = lon) out of every stored grid point.

    Returns {tower: {radius_km: DataArray}} where each DataArray has the single
    dimension "location".
    """
    arrays = {}
    for twr, pts_by_radius in close_latlons.items():
        arrays[twr] = {
            radius : xr.DataArray([pt[component] for pt in pts], dims=["location"])
            for radius, pts in pts_by_radius.items()
        }
    return arrays


#splitting the lat and lon values into dataarrays that share the "location"
#dimension allows for the right (point-by-point) slicing later
#(e.g. https://stackoverflow.com/questions/72179103/xarray-select-the-data-at-specific-x-and-y-coordinates)
close_lats = _coord_arrays(0)
close_lons = _coord_arrays(1)

print("Selecting data for each tower...")
#subsets of each run around each tower: {tower: {radius_km: Dataset}}
tower_conc_dict = {k : {} for k in tower_latlons.keys()}
ref_conc_dict = {k : {} for k in tower_latlons.keys()}
#NOTE: this script never opens an aerosol file, so there is no aerosol data to
#select or time-shift here. The previous tower_aer_dict/ref_aer_dict were only
#ever empty, and indexing them by radius raised a KeyError on the first
#iteration, so they have been removed.
for k in tower_latlons:
    print(f"...{k}...")
    #shift that converts this tower's timestamps from UTC to local time
    utc_shift = np.timedelta64(tower_utc_offsets[k], 'h')
    for c in close_thresh:
        for run_ds, run_dict in ((tower_concs, tower_conc_dict),
                                 (ref_concs, ref_conc_dict)):
            #the lat/lon indexers share the "location" dimension, so this
            #picks out individual grid points rather than a lat-lon rectangle
            subset = run_ds.sel(lat=close_lats[k][c], lon=close_lons[k][c])

            #convert time from UTC to local time; .sel returned a new object,
            #so the full datasets keep their UTC timestamps
            subset["time"] = subset["time"].values + utc_shift
            run_dict[k][c] = subset
###############################################################################
print("Beginning plotting...")
#Six-panel figure of the local OH and H2O2 differences for the paper SI:
#H2O2 on the top row, OH on the bottom row, one column per tower
fig = plt.Figure(figsize = (10,6))
h2o2_tags = ["(a)", "(b)", "(c)"]
oh_tags = ["(d)", "(e)", "(f)"]
for col_idx, twr_nm in enumerate(tower_latlons):
    h2o2_ax = fig.add_subplot(2, 3, col_idx + 1)
    oh_ax = fig.add_subplot(2, 3, col_idx + 4)

    #(variable name, scale factor, axes) for each species; the scale factors
    #go with the ppb / ppt axis labels set below
    panels = (
        ("SpeciesConcVV_H2O2", 1E9, h2o2_ax),
        ("SpeciesConcVV_OH", 1E12, oh_ax),
    )

    #one line per distance from the tower
    for i2, dist in enumerate(ref_conc_dict[twr_nm]):
        ref_ds = ref_conc_dict[twr_nm][dist]
        twr_ds = tower_conc_dict[twr_nm][dist]
        for spec_nm, scale, spec_ax in panels:
            #NOTE(review): isel(lev=1) takes the second retained level, not
            #the first - confirm this is the intended level.
            spec_twr = twr_ds[spec_nm].isel(lev=1) * scale
            spec_ref = ref_ds[spec_nm].isel(lev=1) * scale

            #tower-minus-reference difference, averaged over the grid points
            spec_avg_diff = (spec_twr - spec_ref).mean("location").to_pandas()
            plot_diurnal(spec_avg_diff, spec_ax, cols[i2], f"{dist} km")

    h2o2_ax.set_ylabel("Difference (ppb)")
    oh_ax.set_ylabel("Difference (ppt)")

    #settings shared by both panels of this column
    for panel_ax in (h2o2_ax, oh_ax):
        panel_ax.set_xlabel("Hour of Day (Local Time)")
        panel_ax.legend(loc="upper left")
        panel_ax.set_xlim((0,24))

    h2o2_ax.set_title(f"{h2o2_tags[col_idx]} {twr_nm} H$_2$O$_2$")
    oh_ax.set_title(f"{oh_tags[col_idx]} {twr_nm} OH")

fig.tight_layout()
fig.savefig(f"{plot_outdir}/{file_prefix}_Local_OH_H2O2_6Plot.png", dpi=500)
###############################################################################
