
####### PERIMETER TRACING AND PLOTTING FUNCTIONS #######
# Special thanks to Thomas for most of this code

import numpy as np
import scipy.ndimage as ndi
from numba import njit
from numba.core.types.misc import IterableType, NoneType
# from skimage.segmentation import clear_border
from scipy.ndimage.measurements import label
from scipy.ndimage import binary_erosion, binary_dilation
from scipy.spatial import KDTree
from scipy.stats import binned_statistic
import matplotlib.pyplot as plt
from collections.abc import Iterable


@njit()
def get_perimeter_avg(labelled_mask, var):
    '''
    Calculates the average value of `var` along the perimeter of every labelled cloud.

    Every pixel side separating a cloud pixel (label != 0) from a 4-neighbour with a
    different label is an edge. Each edge contributes the value of `var` on both pixels
    it separates. The result per cloud is therefore the mean, over its edges, of the
    two-pixel average across the edge. The domain is treated as periodic: the padding
    below copies the opposite row/column.

    INPUT
    labelled_mask - 2D array of integer-valued cloud labels, 0 = no cloud
                    (presumably the output of label_mask; confirm at call site)
    var - 2D array, same shape as labelled_mask

    RETURNS
    p_avg - 1D float64 array. Before filtering, slot k corresponds to label k+1
            (labels 1..labelled_mask.max()). Slots with no edges (e.g. label numbers
            that do not occur) are dropped. After that, entries are in ascending label
            order but the index no longer necessarily equals label-1.
    '''
    # Initalized storage arrays: slot k accumulates label k+1
    n_clouds = labelled_mask.max()
    p_sum = np.zeros(n_clouds).astype(np.float64)
    p_count = np.zeros(n_clouds).astype(np.float64)

    # Get iterators
    xl = labelled_mask.shape[0]
    yl = labelled_mask.shape[1]

    # First create new padded array with cyclic boundaries.
    # Corner cells of the pad stay 0 but are never read (only 4-neighbours are used).
    pad_m = np.zeros((xl+2,yl+2))
    pad_v = np.zeros((xl+2,yl+2))

    pad_m[1:-1,1:-1] = labelled_mask
    pad_m[0 ,1:-1] = labelled_mask[-1,:]
    pad_m[-1,1:-1] = labelled_mask[0,:]
    pad_m[1:-1, 0] = labelled_mask[:,-1]
    pad_m[1:-1,-1] = labelled_mask[:,0]

    pad_v[1:-1,1:-1] = var
    pad_v[0 ,1:-1] = var[-1,:]
    pad_v[-1,1:-1] = var[0,:]
    pad_v[1:-1, 0] = var[:,-1]
    pad_v[1:-1,-1] = var[:,0]


    # Now, loop through center of padded array, calculating 
    # average across perimeter, for all labelled clouds, as you go
    for i in range(xl):
        i += 1 # shift index into padded array
        for j in range(yl):
            j += 1 # shift index into padded array

            # The label # will determine which
            # index in p_sum/p_count we are operating on
            pixel = pad_m[i,j]
            idx = int(pixel) - 1

            # If the pixel is a cloud
            if pixel != 0:
                
                # Neighbour above has a different label -> EDGE!
                # Add the var on each side of the edge, add 2 to the count
                # (two values were added, so the final division gives their mean)
                if pad_m[i-1,j] != pixel:
                    p_sum[idx] += pad_v[i-1,j] + pad_v[i,j]
                    p_count[idx] += 2

                # Neighbour below
                if pad_m[i+1,j] != pixel:
                    p_sum[idx] += pad_v[i+1,j] + pad_v[i,j]
                    p_count[idx] += 2

                # Neighbour to the left
                if pad_m[i,j-1] != pixel:
                    p_sum[idx] += pad_v[i,j-1] + pad_v[i,j]
                    p_count[idx] += 2

                # Neighbour to the right
                if pad_m[i,j+1] != pixel:
                    p_sum[idx] += pad_v[i,j+1] + pad_v[i,j]
                    p_count[idx] += 2

    # Remove slots with zero edge count (label absent, or a cloud with no
    # differing neighbour); dividing by 0 would produce NaNs.
    p_sum = p_sum[~(p_count == 0)]
    p_count = p_count[~(p_count == 0)]
    p_avg = p_sum / p_count # and that's your average

    return p_avg

@njit()
def get_perimeter_edge_values(labelled_mask, var):
    '''
    Return the value of every perimeter edge segment of a 2D labelled slice.

    An edge segment is the pixel side between a cloud pixel (label != 0) and a
    4-neighbour with a different label. Its value is the average of `var` on the
    inner and outer pixel on either side of that side. The domain is treated as
    periodic: the padding below copies the opposite row/column.

    INPUT
    labelled_mask - 2D array of cloud labels, 0 = no cloud
    var - 2D array, same shape as labelled_mask

    RETURNS
    edge_v - list of floats, one entry per edge segment found while scanning
             cloud pixels in row-major order (a segment between two different
             clouds is visited once from each side and so appears twice)
    '''
    edge_v = []

    xl = labelled_mask.shape[0]
    yl = labelled_mask.shape[1]

    # Pad by one cell on every side with the opposite row/column so that the
    # neighbour lookups below wrap around. The four corner cells stay 0 but are
    # never read (only 4-neighbours are inspected).
    pad_m = np.zeros((xl+2, yl+2))
    pad_v = np.zeros((xl+2, yl+2))

    pad_m[1:-1, 1:-1] = labelled_mask
    pad_m[0, 1:-1] = labelled_mask[-1, :]
    pad_m[-1, 1:-1] = labelled_mask[0, :]
    pad_m[1:-1, 0] = labelled_mask[:, -1]
    pad_m[1:-1, -1] = labelled_mask[:, 0]

    pad_v[1:-1, 1:-1] = var
    pad_v[0, 1:-1] = var[-1, :]
    pad_v[-1, 1:-1] = var[0, :]
    pad_v[1:-1, 0] = var[:, -1]
    pad_v[1:-1, -1] = var[:, 0]

    # i, j index the padded arrays directly; the original data occupies 1..xl, 1..yl
    for i in range(1, xl + 1):
        for j in range(1, yl + 1):
            pixel = pad_m[i, j]
            if pixel == 0:
                continue  # not a cloud pixel

            # A neighbour with a different label means this side is an edge
            # segment; record the mean of var across it.
            if pad_m[i-1, j] != pixel:
                edge_v.append((pad_v[i-1, j] + pad_v[i, j])/2)

            if pad_m[i+1, j] != pixel:
                edge_v.append((pad_v[i+1, j] + pad_v[i, j])/2)

            if pad_m[i, j-1] != pixel:
                edge_v.append((pad_v[i, j-1] + pad_v[i, j])/2)

            if pad_m[i, j+1] != pixel:
                edge_v.append((pad_v[i, j+1] + pad_v[i, j])/2)

    return edge_v


def label_mask(matrix, structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])):
    '''
    Identify individual clouds in a periodic 2D domain and give each one an integer label.

    Clouds are first labelled with scipy.ndimage.label using `structure`. Clouds that
    continue across the top/bottom or left/right boundary (same column / same row on
    opposite edges both non-zero) are then merged into a single label.

    NOTE(review): only side-to-side wrapping is merged. With an 8-connected `structure`,
    clouds touching only diagonally across a boundary are not joined.

    INPUT
    matrix - 2D binary (0,1) array
    structure - connectivity structure passed to scipy.ndimage.label

    RETURNS
    labelled_matrix - int32 array, same shape as matrix; 0 = no cloud, clouds are
                      numbered consecutively 1..n_clouds (order of the initial labelling
                      is preserved)
    n_clouds - int, number of distinct clouds after merging (equals labelled_matrix.max())
    '''
    labelled_matrix, _ = label(matrix, structure, output=np.int32)
    last_row = labelled_matrix.shape[0] - 1
    last_col = labelled_matrix.shape[1] - 1

    # Merge clouds that wrap across the top/bottom boundary. Values are re-read on
    # each iteration so merges made earlier in the loop are taken into account.
    for i in range(labelled_matrix.shape[1]):
        top = labelled_matrix[0, i]
        bottom = labelled_matrix[last_row, i]
        if top != 0 and bottom != 0 and bottom != top:
            labelled_matrix[labelled_matrix == bottom] = top

    # Merge clouds that wrap across the left/right boundary.
    for j in range(labelled_matrix.shape[0]):
        left = labelled_matrix[j, 0]
        right = labelled_matrix[j, last_col]
        if left != 0 and right != 0 and right != left:
            labelled_matrix[labelled_matrix == right] = left

    # Merging leaves gaps in the label numbers, so max() would overstate the
    # cloud count. Renumber the surviving labels to 1..n, preserving their order.
    present = np.unique(labelled_matrix)
    present = present[present != 0]
    lookup = np.zeros(int(labelled_matrix.max()) + 1, dtype=labelled_matrix.dtype)
    lookup[present] = np.arange(1, present.size + 1, dtype=labelled_matrix.dtype)
    labelled_matrix = lookup[labelled_matrix]
    n_clouds = int(present.size)

    return labelled_matrix, n_clouds

def _label_edge_counts(labelled_matrix, count_wrapped_edges=True):
    '''
    Count, for every label value, the pixel sides it shares with a 4-neighbour of a
    different label. Returns an int64 array indexed by label value (index 0 is the
    background and is meant to be discarded by the caller).

    count_wrapped_edges - if False, comparisons between the first and last row/column
                          (created by np.roll wrapping around) are ignored.
    '''
    n_labels = int(labelled_matrix.max()) + 1
    counts = np.zeros(n_labels, dtype=np.int64)
    for axis in (0, 1):
        # shifted[i] holds the pixel preceding i along `axis`; index 0 wraps to the end
        shifted = np.roll(labelled_matrix, shift=1, axis=axis)
        edge = shifted != labelled_matrix
        if not count_wrapped_edges:
            if axis == 0:
                edge[0, :] = False
            else:
                edge[:, 0] = False
        # An edge separates two pixels: credit it to the label on each side
        counts += np.bincount(labelled_matrix[edge].ravel(), minlength=n_labels)
        counts += np.bincount(shifted[edge].ravel(), minlength=n_labels)
    return counts


def get_perimeter_area(matrix, edge_behavior, print_info=True, structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]), dtype=np.uint32):
    '''
        Given a 2-D binary matrix of 0s and 1s, calculate perimeters and areas of individual
        connected regions ("clouds").

        Input:
            matrix:         binary matrix; 1 is a cloud 0 is not
            edge_behavior:  str: How to treat clouds that come in contact with the edge. Options:
                                'wrap':      clouds on opposite sides of matrix are connected
                                             (labelled with label_mask).
                                'truncate':  Include perimeter and area of the portion of the cloud that
                                             is inside the matrix
                                             ** does not include perimeter along edge of matrix **
                                'truncate with edges':   Same as above, but include perimeter along edge.
                                'remove':    Remove any cloud that touches the edge.

            print_info:     T/F: Print number of clouds found and edge behavior

            structure:      2D np.array (3,3): Defines connectivity. Passed to scipy.ndimage.label
                            (and to label_mask for 'wrap').

            dtype:          dtype of the returned arrays.

        Output:
            (perimeters, areas)
                Each is a 1D np.ndarray with one entry per cloud. Both are ordered by ascending
                cloud label, so perimeters[k] and areas[k] always describe the same cloud.
                (None, None) if no clouds are found.

        Perimeter = number of pixel sides a cloud shares with a pixel of a different label
        (4-neighbourhood); area = number of pixels.

        Raises:
            ValueError: unknown edge_behavior.
    '''

    count_wrapped_edges = True

    if edge_behavior == 'wrap':
        labelled_matrix, _ = label_mask(matrix, structure)
    elif edge_behavior == 'truncate with edges':
        # Prepend a zero row and a zero column. np.roll then compares the last
        # row/column with these zeros, so sides on the domain border count as perimeter.
        matrix = np.append(np.zeros((1, matrix.shape[1])), matrix, axis=0)
        matrix = np.append(np.zeros((matrix.shape[0], 1)), matrix, axis=1)
        labelled_matrix, _ = label(matrix, structure)
    elif edge_behavior == 'truncate':
        labelled_matrix, _ = label(matrix, structure)
        # Ignore the wrap-around comparisons so the domain border adds no perimeter
        count_wrapped_edges = False
    elif edge_behavior == 'remove':
        labelled_matrix, _ = label(matrix, structure)
        # Drop every cloud with at least one pixel on the domain border
        border = np.concatenate((labelled_matrix[0, :], labelled_matrix[-1, :],
                                 labelled_matrix[:, 0], labelled_matrix[:, -1]))
        labelled_matrix[np.isin(labelled_matrix, border[border != 0])] = 0
    else: raise ValueError('Invalid input for edge behavior: {}'.format(edge_behavior))

    # Labels may be non-consecutive (wrap merging, removed clouds), so work from the
    # labels actually present and index both outputs with the same label set.
    cloud_labels = np.unique(labelled_matrix)
    cloud_labels = cloud_labels[cloud_labels != 0]
    n_clouds = cloud_labels.size

    if n_clouds == 0: return None, None
    if print_info: print('Getting perimeter and area for {} clouds, edge_behavior: {}'.format(n_clouds, edge_behavior))

    edge_counts = _label_edge_counts(labelled_matrix, count_wrapped_edges)
    perimeters = edge_counts[cloud_labels].astype(dtype)
    areas = np.bincount(labelled_matrix.ravel())[cloud_labels].astype(dtype)

    return perimeters, areas



##### LINEAR REGRESSION POWER LAW FITTING #####
# Adapted from Thomas's Class-based code for scenes

def create_bins(log_bin_min, log_bin_max, n_bins):
    '''
    Creates the logarithmic bins by which perimeter/area amounts will be bucketed.

    INPUT
    log_bin_min - int, lowest logarithmic bin power (10**X)
    log_bin_max - int, highest logarithmic bin power (10**Y)
    n_bins - int, number of bin EDGES to generate, evenly spaced in log10 between
             log_bin_min and log_bin_max inclusive (this yields n_bins-1 bins)

    RETURNS
    log_bin_edges - array of the log10 edge values: length = n_bins
    log_bin_middles - array of the log10 middle value of each bin for plotting: length = n_bins-1
    '''

    # Evenly spaced exponents are exactly the log10 of np.logspace's output,
    # without the float round-trip through 10**x and log10.
    log_bin_edges = np.linspace(log_bin_min, log_bin_max, n_bins)
    log_bin_middles = (log_bin_edges[1:] + log_bin_edges[:-1])/2

    return log_bin_edges, log_bin_middles

# def find_edge_clouds(matrix, structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])):
#     """
#         Input:
#             matrix:         2D np.ndarray binary matrix; 1 is a cloud 0 is not

#             structure:      2D np.array (3,3): Defines connectivity. Passed to scipy.ndimage.measurements.label
#         Output:
#             matrix:         2D np.ndarray binary matrix; 1 is a cloud 0 is not with only clouds touching the edge included

#         Remove any cloud not touching the edge.
#     """
#     border_cleared = clear_border(matrix)
#     edge_clouds = (matrix - border_cleared).astype(np.int8) # The difference between the original and border_cleared will be exactly the edge clouds.

#     return edge_clouds  


def find_edge_clouds_OLD(labelled_mask):
    '''
    Return the label numbers of the clouds that touch the domain boundary.

    INPUT
    labelled_mask - 2D labelled area or perimeter mask (0 = no cloud, 1, 2, 3... = cloud ids)

    RETURNS
    Sorted 1D array of the distinct non-zero labels found in the first/last row or
    first/last column, or None when no cloud touches any edge.
    '''
    border_values = np.concatenate((
        labelled_mask[0, :],    # top row
        labelled_mask[-1, :],   # bottom row
        labelled_mask[:, 0],    # left column
        labelled_mask[:, -1],   # right column
    ))
    edge_numbers = np.unique(border_values[border_values != 0])  # 0 means no cloud

    return edge_numbers if edge_numbers.size else None


def find_min_thresh_index(log_bin_edges, resolution, kernel=1, min_thresh=10, parameter='perimeter'):
    '''
    Because of the resolution, we need to make a cut-off point for the lower bins that reside
    close to the resolution length. This function determines the first bin edge above the
    resolution threshold and the indices of all bin edges below it.

    INPUT
    log_bin_edges  - 1D array of log10 bin edge values (len +1 greater than counts array)
    resolution - length of pixel/dx in simulation (must be in comparable physical units)
    kernel - multiplicative factor applied to the threshold
    min_thresh - number of resolution lengths (areas) for the cut-off; apparently 10-20 works best
    parameter - "area" (threshold = min_thresh * resolution**2 * kernel) or
                "perimeter" (threshold = min_thresh * resolution * kernel)

    RETURNS
    min_thresh_value - float, 10**edge of the first bin edge >= threshold
    min_thresh_index - 1D int array np.arange(0, i) where i is that edge's index,
                       i.e. the indices of all smaller edges
    (None, None) if no edge reaches the threshold.

    For instance to filter out smaller bins you would do counts[min_thresh_index] = np.nan

    Raises:
        ValueError: `parameter` is neither 'area' nor 'perimeter'.
    '''

    # Set cut-off size
    if parameter == 'area':
        thresh = min_thresh * resolution**2 * kernel
    elif parameter == 'perimeter':
        thresh = min_thresh * resolution * kernel
    else:
        raise ValueError("parameter must be 'area' or 'perimeter', got {!r}".format(parameter))

    # find smallest bin edge larger than the resolution cutoff size
    for i, bin_edge in enumerate(log_bin_edges):
        edge_value = 10**bin_edge
        if edge_value >= thresh:
            return edge_value, np.arange(0, i, 1)

    return None, None


def find_edge_thresh_index(total_counts, edge_counts, edge_thresh, first=True):  # also min thresh
    '''
    To account for border/edge effects, a bin where too many clouds touch the edges should not
    be used in the powerlaw derivation. This function compares, per bin, the number of
    edge-touching clouds with a fraction of all clouds and returns the indices that should be
    removed from the counts. (Use find_edge_clouds to determine how to sort the clouds for
    each 2D image.)

    INPUT
    total_counts - binned frequencies of all clouds
    edge_counts - binned frequencies of all clouds that cross the edge
    edge_thresh - float, fractional threshold (e.g. 0.5); a bin is flagged when its edge
                  count exceeds edge_thresh * total count
    first - boolean. If True, find the smallest bin where edge_counts > cutoff and return the
            indices of that bin and all larger bins (empty if no bin exceeds the cutoff).
            If False, return every bin where edge_counts >= cutoff.

    RETURNS
    edge_thresh_index - 1D int array of indices, e.g. total_counts[edge_thresh_index] = np.nan
    '''
    total_counts = np.asarray(total_counts)
    edge_counts = np.asarray(edge_counts)

    edge_total_cutoff = edge_thresh * total_counts
    if first:
        # strict '>' so empty bins (0 > 0 is False) don't trigger the cut
        exceeding = np.flatnonzero(edge_counts > edge_total_cutoff)
        if exceeding.size == 0:
            # No bin is dominated by edge clouds: nothing to remove
            return np.arange(0)
        return np.arange(exceeding[0], len(total_counts))

    # flatnonzero always returns a 1-D array (squeeze() gave 0-d for a single hit,
    # which np.concatenate in filter_regression_bins rejects)
    return np.flatnonzero(edge_counts >= edge_total_cutoff)


def _as_index_array(indices):
    '''Return `indices` as a flat integer index array; None becomes an empty array.'''
    if indices is None:
        return np.empty(0, dtype=np.intp)
    return np.asarray(indices, dtype=np.intp).ravel()


def filter_regression_bins(total_counts, edge_thresh_index, min_thresh_index):
    '''
    Given the bin indices which should not be used in the calculation for the powerlaw, will create 2
    arrays which contain the "good"/"bad" bins for calculation and plotting and NaN elsewhere.
    Note: The upside to using this masking NaN approach is that the regression/non-regression arrays 
    are the same length, and are easier to plot when given the same bin arrays.

    INPUT
    total_counts - Binned frequencies for a powerlaw, 1D Array
    edge_thresh_index - The edge threshold indices to be removed from regression plot; 1D array,
                        scalar/0-d array, or None (nothing to remove)
    min_thresh_index - The min threshold indices to be removed from regression plot; 1D array,
                       scalar/0-d array, or None (as returned by find_min_thresh_index when
                       no bin edge reaches the threshold)

    RETURNS
    regression_counts - float32 copy of total_counts with the removed bins set to NaN
    noreg_counts - float32 copy of total_counts with the kept bins set to NaN; removed bins
                   with a zero count are also NaN
    '''
    total_counts = np.asarray(total_counts)

    # Determine all indices to be made NaN, create a boolean mask for those indices
    no_reg_indices = np.unique(np.concatenate((_as_index_array(edge_thresh_index),
                                               _as_index_array(min_thresh_index))))
    mask = np.zeros(total_counts.size, dtype=bool)
    mask[no_reg_indices] = True

    # Using the "bad" indices mask, create count arrays for regression/non-regression bins
    # (astype returns a new array, so total_counts is never modified)
    regression_counts = total_counts.astype(np.float32)
    regression_counts[mask] = np.nan
    noreg_counts = total_counts.astype(np.float32)
    noreg_counts[~mask] = np.nan
    noreg_counts[noreg_counts == 0] = np.nan  # empty bins are not worth plotting

    return regression_counts, noreg_counts


def linear_regression_np(x, y, print_label='', print_results=False):
    '''
    Least-squares straight-line fit y = slope * x + intercept, ignoring non-finite points.

    INPUT
    x - 1D x-axis array/values, corresponds with bin_middles
    y - 1D y-axis array/values, corresponds with counts/frequencies
    print_label - str, label used when print_results is True
    print_results - bool, print slope, number of fitted points and intercept

    RETURNS
    (slope, y-int) - array of 2 floats, slope of powerlaw relationship, or beta,
                        intercept of powerlaw relationship, or alpha
    error - array of 2 floats, 2 * standard error of each coefficient (~95% confidence)
    (None, None) if fewer than 3 finite points remain or the fit fails.
    '''
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    valid = np.isfinite(x) & np.isfinite(y)  # Mask to remove NaNs/infs
    x_valid = x[valid]
    y_valid = y[valid]

    # polyfit needs more points than coefficients (2) to scale the covariance matrix
    if x_valid.size < 3:
        return None, None

    try:
        coefficients, cov = np.polyfit(x_valid, y_valid, 1, cov=True)
    except (ValueError, np.linalg.LinAlgError):
        return None, None
    error = np.sqrt(np.diag(cov))

    if print_results:
        print('Slope for', print_label + ':', str(round(coefficients[0], 3)), 'N = ', x_valid.size, 'Y-int: ', round(coefficients[1], 3))

    return coefficients, 2*error  # 95% conf interval is 2 times standard error


def calculate_fit_line(coefficients, bins, regression_counts, linear=True):
    '''
    Return the two end points of the fitted power law y = 10**alpha * x**beta over the
    range of bins that were used in the regression.

    INPUT
    coefficients - length-2 sequence (beta, alpha) of powerlaw coefficients
    bins - array of bin midpoints, same length as regression_counts
    regression_counts - array with NaN at every bin that was NOT used in the regression
    linear - bool; if False, `bins` are log10 values and are converted with 10**bins first

    RETURNS
    X - array [x0, x1]: first and last bin used in the regression (linear space)
    Y - array [y0, y1]: fitted powerlaw evaluated at X
    '''
    beta, alpha = coefficients[0], coefficients[1]

    used_bins = bins[~np.isnan(regression_counts)]
    if not linear:
        used_bins = 10**used_bins

    endpoints = np.array([used_bins[0], used_bins[-1]])
    fitted = []
    for point in endpoints:
        fitted.append((10**alpha)*(point)**beta)

    return endpoints, np.array(fitted)



#### MASK MANIPULATION FUNCTIONS ####

def binary_manipulate_wrap(matrix, iteration, erode=True):
    '''
    Binary erode/dilate a mask array of 0/1's on a periodic domain. Scipy's ndimage package
    has no 'wrapping' mode for these operations, so the mask is padded with periodic copies
    of itself (np.pad, mode='wrap', which also fills the corners), eroded/dilated, and the
    pad is cut off again.

    INPUT
    matrix - 2D Array, array of any dtype but must be a mask of 0/1's
    iteration - int, number of times to apply the erosion/dilation (see scipy documentation)
    erode - boolean, erosion if True, dilation if False

    RETURNS
    bin_manip_matrix - 2D array of same size/dtype as 'matrix' but with erosion/dilation applied
    '''

    # Each iteration propagates one cell, and erosion also eats in from the outer border
    # of the padded array (border_value=0), so the pad must be wider than `iteration`.
    # Keep at least the historical pad of 10 (relevant for iteration < 1, which means
    # "repeat until convergence" in scipy).
    pn = max(10, int(iteration) + 1)  # pad number

    # Periodic padding; mode='wrap' fills edges AND corners and copes with pads
    # larger than the array itself.
    pad_matrix = np.pad(matrix, pn, mode='wrap')

    # Now we can erode/expand the structures
    if erode:
        bin_manip_matrix = binary_erosion(pad_matrix, iterations=iteration).astype(matrix.dtype)
    else:
        bin_manip_matrix = binary_dilation(pad_matrix, iterations=iteration).astype(matrix.dtype)

    # Removal of padded layers
    return bin_manip_matrix[pn:-pn, pn:-pn]


def nearest_object_kdt(matrix, wrap=True, fluctuation=None, workers=16):
    '''
    An efficient way to calculate the euclidean distance (in grid steps) from every point
    of a 2D matrix/mask to the nearest nonzero point. Uses the KDTree class from the
    scipy.spatial package.

    INPUT
    matrix - 2D array, ideally a 0/1 mask, but floats/integers work as well
    wrap - boolean, whether input array has periodic conditions (tree built with
           boxsize=[nx, ny])
    fluctuation - positive float or None. If given, every query coordinate is shifted by
                   independent dx and dy drawn from a uniform distribution on
                   [-fluctuation, fluctuation) using numpy's global random state
    workers - int, number of parallel workers passed to KDTree.query (default 16, the
              previously hard-coded value)

    RETURNS
    distances - 2D array of floats, same shape as input matrix with values
                corresponding to the euclidean distance to the nearest
                nonzero point
    '''

    shape = matrix.shape
    nx = shape[0]
    ny = shape[1]

    # Array of [x,y] coordinates for each object point
    obj_coords = np.transpose(np.nonzero(matrix))

    # Create KDTree class object
    # See scipy.spatial.KDTree documentation for more options
    if wrap:
        tree = KDTree(obj_coords, boxsize=[nx, ny])
    else:
        tree = KDTree(obj_coords)

    # [x,y] coords of every point in the domain, row-major so the query result
    # reshapes straight back onto the grid
    all_coords = np.indices((nx, ny)).reshape(2, -1).T

    # Add fluctuations
    if fluctuation is not None:
        dx = np.random.uniform(-fluctuation, fluctuation, size=all_coords.shape)
        all_coords = all_coords + dx

        # Fold jittered coordinates back into the periodic box
        if wrap:
            all_coords[:, 0] = all_coords[:, 0] % nx
            all_coords[:, 1] = all_coords[:, 1] % ny

    # Nearest-Neighbor, euclidean distance (p=2)
    distances, _ = tree.query(all_coords, k=1, p=2, workers=workers)
    distances = distances.reshape(shape)

    return distances

def distance_to_nearest_object(matrix, wrap=True):
    '''
    MEMORY LIMITED (builds an n_objects x n_points distance matrix)
    PREFER NEAREST_OBJECT_KDT FOR LARGE DOMAINS
    Given a 2D array of 0/1s, will determine the distance (in gridsteps) to the nearest nonzero element,
    for every element in the array. If element is nonzero, distance will be 0. If you want to determine
    the distance of a point within a structure of 1's, then pass the perimeter of the structure.

    INPUT
    matrix - 2D array of 0/1s (can also be nonzero, but masks are nice and clean)
    wrap - Boolean, if true, distances are measured on a periodic domain (the shorter way
           around in each direction)

    RETURNS
    distances - 2D float array of same shape as matrix with 'lengths' to nearest nonzero
                element; all np.inf if the matrix has no nonzero element
    '''

    shape = matrix.shape
    nx = shape[0]
    ny = shape[1]

    obj_x, obj_y = np.nonzero(matrix)  # Indices of non-zero locations
    if obj_x.size == 0:
        return np.full(shape, np.inf)
    all_x, all_y = np.indices(shape).reshape(2, -1)  # All indices, row-major

    # Absolute x, y offsets between every object point (rows) and every grid point (cols)
    x_dist = np.abs(np.subtract.outer(obj_x, all_x))
    y_dist = np.abs(np.subtract.outer(obj_y, all_y))

    # Periodic boundary: going the other way around may be shorter
    if wrap:
        x_dist = np.minimum(x_dist, nx - x_dist)
        y_dist = np.minimum(y_dist, ny - y_dist)

    distances = np.sqrt(x_dist**2 + y_dist**2)  # Transform to Euclidean distance
    return distances.min(axis=0).reshape(shape)  # Nearest object for every grid point


def extract_perimeter(matrix, wrap=True):
    '''
    Given a 2D array mask of 0/1's will determine the outermost gridpoints which make up the
    perimeter of a cluster of 1's. This essentially creates a new mask of only the "edge" 
    pixels: the mask minus its one-step binary erosion. Uses the function
    binary_manipulate_wrap if the mask is periodic.

    INPUT
    matrix - 2D array of 0/1's (or booleans) with 1's corresponding to objects of interest
    wrap - boolean, calculates perimeters for periodic domain

    RETURNS
    perimeters - 2D array of same shape as matrix, with 1's corresponding to perimeter
    interior - 2D array of same shape as matrix, with 1's corresponding to interior points
    '''

    if wrap:
        interior = binary_manipulate_wrap(matrix, 1, True)
    else:
        # Note structures touching edge will not create a "false" perimeter w/
        # border_value = 1
        interior = binary_erosion(matrix, border_value=1)

    # numpy forbids '-' on booleans, so use a logical difference for boolean masks
    if matrix.dtype == bool:
        perimeters = matrix & ~interior.astype(bool)
    else:
        perimeters = matrix - interior

    return perimeters, interior

################################

##### Binning and Analysis Functions ####

def bin_by_dimensions(bins, variable, weights=None):
    '''
    WORKS OKAY, SEE BIN_BY_DIMENSION_STAT FOR A MORE VERSATILE FUNCTION

    Bin a variable level by level and return, for every bin, the sum (of the variable
    itself, or of `weights` when given) and the number of values falling into it.
    Binning is done separately for each index along the first axis of `variable`; a
    1D `variable` (and `weights`) is treated as a single level.

    A value v falls in bin k when bins[k] < v <= bins[k+1] (np.searchsorted, side='left').
    Values <= bins[0] or > bins[-1] (and NaNs) are discarded.

    INPUT
    bins - 1D array of bin edges
    variable - ND array of variable to be binned into bins array (dim, ...)
    weights - ND array of different value to sum bins by, same shape as 'variable'

    RETURNS
    sum_bins - 2D float array (dim, len(bins)-1) containing sum of variable or weights
                in each bin
    cnt_bins - 2D float array (dim, len(bins)-1) containing the counts of the variables
                in each bin
    bin_mids - 1D array of len(bins)-1 containing midpoints of bins (for plotting)
    '''

    if variable.ndim == 1:
        variable = variable[np.newaxis, ...]
        if weights is not None:
            weights = weights[np.newaxis, ...]

    n_levels = variable.shape[0]
    # searchsorted yields slots 0..len(bins); slot 0 is below the first edge and
    # slot len(bins) is above the last one -- both are trimmed before returning.
    n_slots = len(bins) + 1

    sums = np.zeros((n_levels, n_slots))
    counts = np.zeros((n_levels, n_slots))

    for level in range(n_levels):
        flat = variable[level].ravel()
        slot = np.searchsorted(bins, flat, side='left')
        to_sum = flat if weights is None else weights[level].ravel()
        sums[level] = np.bincount(slot, to_sum, minlength=n_slots)
        counts[level] = np.bincount(slot, minlength=n_slots)

    mids = (bins[1:] + bins[:-1]) / 2
    return sums[:, 1:-1], counts[:, 1:-1], mids

def bin_by_dimensions_stat(bins, variable, weights=None, statistics=('sum', 'count'), percentiles=()):
    '''
    Compute several binned statistics of a variable, level by level.

    For every index along the first axis of `variable`, the values are binned by
    their own value into `bins`. Each requested statistic is then computed over
    `variable` itself, or over `weights` when given. A 1D `variable` (and
    `weights`) is treated as a single level.

    See scipy.stats.binned_statistic documentation for more information:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binned_statistic.html

    INPUT
    bins - 1D array of bin edges
    variable - ND array of variable to be binned into bins array (dim, ...)
    weights - ND array of different value to compute statistics on, same shape as 'variable'
    statistics - sequence of statistic names (or callables) accepted by binned_statistic
    percentiles - sequence of percentiles (0-100, fractional values allowed, e.g. 2.5)

    RETURNS
    stat_bins - 3D array of shape (len(statistics)+len(percentiles), dim, len(bins)-1).
                stat_bins[s, d] is statistic s for level d; the named statistics come
                first, in order, followed by the percentiles, in order.
    '''

    # Ensure there is a dimension to loop binning process over
    if len(variable.shape) == 1:
        variable = np.expand_dims(variable, axis=0)
        if weights is not None:
            weights = np.expand_dims(weights, axis=0)

    ndim = variable.shape[0]
    nstat = len(statistics) + len(percentiles)

    # Percentiles need a user-defined function; bind each percentile via a default
    # argument so every lambda keeps its own value (float, so 2.5 is not truncated).
    stat_funcs = list(statistics) + [
        (lambda values, q=float(q): np.percentile(values, q)) for q in percentiles
    ]

    # Collector arrays
    stat_bins = np.zeros((nstat, ndim, len(bins)-1))

    for dim, var in enumerate(variable):
        print(f'Determining stats for height {dim+1} of {ndim}...')
        x = var.ravel()
        values = x if weights is None else weights[dim].ravel()
        for s, stat in enumerate(stat_funcs):
            stat_bins[s, dim], _, _ = binned_statistic(x, values, statistic=stat, bins=bins)

    print('Done.')

    return stat_bins

def bin_by_dimensions_stat_TEST(bins, variable, weights=None, statistics=('sum', 'count'), percentiles=()):
    '''
    Bin a variable level-by-level and compute several statistics in each bin.

    This started as an attempt to remove the loop over the first dimension of
    `variable` by handing a 2D array straight to scipy.stats.binned_statistic.
    That cannot work: binned_statistic requires its `x` argument to have shape
    (N,), so the vectorised call always raised. This version keeps the
    statistic-outer loop of the experiment but bins each level separately.

    Binning occurs along the first dimension of `variable`: each index of that
    dimension is one "level", and all remaining dimensions of that level are
    flattened and binned together. A 1D `variable` (and `weights`) is treated
    as a single level.

    See scipy.stats.binned_statistic documentation for more information:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binned_statistic.html

    INPUT
    bins - 1D array of bin edges
    variable - ND array of the variable to be binned, shape (dim, ...)
    weights - optional ND array with the same shape as `variable`. If given,
              the statistics are computed on `weights` (binned by `variable`);
              otherwise they are computed on `variable` itself.
    statistics - sequence of statistics accepted by binned_statistic
                 (e.g. 'sum', 'count', 'mean', 'median', or a callable)
    percentiles - sequence of percentiles (0-100) to compute in each bin;
                  these are appended after `statistics`

    RETURNS
    stat_bins - array of shape (len(statistics) + len(percentiles), dim, len(bins)-1).
                stat_bins[s, d] holds statistic s for level d. Empty bins are 0 for
                'sum'/'count' and NaN for percentiles (binned_statistic's fill value).
    '''
    variable = np.asarray(variable)

    # Ensure there is a leading "level" dimension to loop over
    if variable.ndim == 1:
        variable = variable[np.newaxis, :]

    ndim = variable.shape[0]
    variable = variable.reshape(ndim, -1)
    if weights is not None:
        weights = np.asarray(weights).reshape(ndim, -1)

    n_named = len(statistics)
    nstat = n_named + len(percentiles)

    # Collector array
    stat_bins = np.zeros((nstat, ndim, len(bins) - 1))

    for s in range(nstat):
        print(f'Working on statistic {s+1} of {nstat}...')

        if s < n_named:
            # Built-in (or user-supplied) binned_statistic statistic
            statistic = statistics[s]
        else:
            # Percentiles need a callable. Bind the percentile as a default
            # argument so the function does not depend on the loop variable.
            q = percentiles[s - n_named]

            def statistic(values, q=q):
                return np.percentile(values, q)

        # binned_statistic only accepts a 1D `x`, so bin one level at a time
        for dim in range(ndim):
            x = variable[dim]
            values = x if weights is None else weights[dim]
            stat_bins[s, dim], _, _ = binned_statistic(x, values, statistic=statistic, bins=bins)

    print('Done.')

    return stat_bins


#########################################


###### Plotting Functions ######

def plot_power_law(bins, rcounts, nrcounts, coeffs, labels, colors, fit=True, offset=True):
    '''
    Draw several binned power-law datasets on a single log-log figure.

    All datasets must share the same binning. rcounts, nrcounts, coeffs,
    labels and colors are parallel lists: element i of each describes
    dataset i. Values are supplied in linear space and plotted as-is; the
    axes are switched to log scale afterwards. The figure and axes are
    returned so callers can add further artists.

    INPUT
    bins - array, bin mid-points in linear space
    rcounts - list of arrays, counts that were used in the power-law regression
              (linear space)
    nrcounts - list of arrays, counts that were not used in the regression
    coeffs - list of (beta, alpha) tuples; each tuple is passed to
             calculate_fit_line and beta is shown in the fit's legend label
    labels - list of legend labels, one per dataset
    colors - list of colors (hex preferred), one per dataset
    fit - bool, if True draw the fitted line for each dataset
    offset - bool, if True lift dataset i (i >= 1) by a factor of
             0.5**i * 10**i (= 5**i) so the series do not overlap

    RETURNS
    fig - matplotlib Figure
    ax - matplotlib Axes (x limited to [1, 1e4], y to [1, 1e6])
    '''
    fig, ax = plt.subplots()

    for idx, reg_counts in enumerate(rcounts):
        color = colors[idx]

        # Vertical lift for this series: 1 for the first series (or when
        # offsetting is disabled), otherwise 0.5**idx * 10**idx.
        if offset and idx > 0:
            lift = (0.5**idx) * (10**idx)
        else:
            lift = 1

        # Triangles mark bins used in the regression; crosses mark bins left out.
        ax.scatter(bins, lift*reg_counts, s=35, marker='^', label=labels[idx],
                   color=color, edgecolors='black')
        ax.scatter(bins, lift*nrcounts[idx], s=15, marker='x', color=color)

        if not fit:
            continue

        fit_x, fit_y = calculate_fit_line(coeffs[idx], bins, reg_counts)
        ax.plot(fit_x, lift*fit_y, color=color,
                label=f'$\\beta$={coeffs[idx][0]:.2f}')

    # Data were plotted in linear units; display them on log-log axes.
    for set_scale in (ax.set_xscale, ax.set_yscale):
        set_scale('log')
    ax.set_xlim([10**0, 10**4])
    ax.set_ylim([10**0, 10**6])

    return fig, ax


def plot_bin_differences(bins, rcounts, nrcounts, labels, colors):
    '''
    Plot the binned counts of each dataset relative to the first dataset.

    rcounts and nrcounts are parallel lists (one array per dataset, all binned
    the same way) holding the counts that were / were not used in a power-law
    regression. Presumably each bin is NaN in one of the two (confirm with
    callers). The two are merged into one gap-free series per dataset. Every
    dataset from index 1 onward is expressed as
    (full[k] - full[0]) / full[0] + 1, i.e. its ratio to dataset 0. That ratio
    is then split back into the regressed / non-regressed bins for plotting.
    The x-axis is log-scaled. Dashed and dotted reference lines are drawn at
    y = 1 (same as dataset 0) and y = 0.

    INPUT
    bins - array, mid-points of bins in linear space
    rcounts - list of arrays, counts used in the regression (NaN elsewhere)
    nrcounts - list of arrays, same shape as rcounts, counts not used in the
               regression (NaN elsewhere)
    labels - list of legend labels; labels[i] is applied to dataset i+1
             (dataset 0 is the reference and is not plotted)
    colors - list of colors (hex preferred); colors[i] is applied to dataset i+1

    RETURNS
    fig - pyplot figure object
    ax - pyplot axes object
    '''

    rcounts = np.array(rcounts)
    nrcounts = np.array(nrcounts)

    # Merge into one series per dataset: fill rcounts' NaNs from nrcounts
    # wherever nrcounts has a value. The same mask selects both sides, so the
    # assignment stays aligned element-for-element; bins that are NaN in both
    # inputs stay NaN.
    full = rcounts.copy()
    fill = np.isnan(full) & ~np.isnan(nrcounts)
    full[fill] = nrcounts[fill]

    # Relative difference from dataset 0, shifted by +1 (a ratio to dataset 0)
    ratio = (full[1:] - full[0]) / full[0] + 1

    # Re-split the ratio so it mimics the regressed / non-regressed layout
    r_ratio = ratio.copy()
    nr_ratio = ratio.copy()
    r_ratio[np.isnan(rcounts[1:])] = np.nan
    nr_ratio[np.isnan(nrcounts[1:])] = np.nan

    fig, ax = plt.subplots()

    for i, series in enumerate(r_ratio):
        # Triangles: regressed bins; crosses: non-regressed bins
        ax.scatter(bins, series, s=35, marker='^', label=labels[i], color=colors[i], edgecolors='black')
        ax.scatter(bins, nr_ratio[i], s=15, marker='x', color=colors[i])

    ax.set_xscale('log')
    ax.set_xlim([10**0, 10**4])
    ax.plot(ax.get_xlim(), [1, 1], color='black', linestyle='--')
    ax.plot(ax.get_xlim(), [0, 0], color='black', linestyle=':')

    return fig, ax

def flatten_nested_list(xs):
    '''
    Generator that takes any arbitrary list of lists/arrays of varying
    irregular length/shape and returns each element sequentially.

    Strings and bytes are yielded whole rather than iterated. A one-character
    string iterates to itself, so recursing into strings would never
    terminate. Zero-dimensional numpy arrays are also yielded whole, because
    iterating them raises TypeError.

    use list(flatten_nested_list(xs)) for a list
    '''
    for x in xs:
        is_leaf = (
            isinstance(x, (str, bytes))
            or (isinstance(x, np.ndarray) and x.ndim == 0)
            or not isinstance(x, Iterable)
        )
        if is_leaf:
            yield x
        else:
            yield from flatten_nested_list(x)


@njit()
def calc_moment(A, center=0, moment=1, weights=None, standardize=False):
    '''
    Based on scipy.stats.moment: given a numpy array, calculate the nth moment
    as specified (default is 1). If a center is specified, moments about this
    center are calculated (otherwise about zero, i.e. raw moments; pass the
    mean as `center` for central moments).

    Compiled with numba (@njit), so all inputs must be numba-compatible.

    A MUST BE 1D ARRAY
    CENTER MUST BE A SCALAR
    WEIGHTS MUST BE A 1D ARRAY with length equal to A

    USE FLATTEN_NESTED_LIST TO PREP DATA FOR CALC_MOMENT

    INPUT
    A - 1D array of values
    center - scalar the moment is taken about (default 0)
    moment - order of the moment (default 1); moment == 0 returns 1 immediately
    weights - optional 1D array of weights. If given, the result is
              sum(weights*(A-center)**moment) / sum(weights); otherwise it is
              the plain average of (A-center)**moment.
    standardize - if True, divide the result by m2**(moment/2), where m2 is the
                  second moment about the SAME `center` with the same weights.
                  NOTE(review): this equals dividing by std**moment only when
                  `center` is the (weighted) mean of A.

    RETURNS
    mom - the requested moment (scalar)

    NOTE(review): an empty A (n == 0) divides by zero; this is not guarded.
    '''

    # The zeroth moment is 1 by definition
    if moment == 0:
        return 1

    n = len(A)

    if weights is None:
        mom = ((A - center)**moment).sum() / n

    else:
        # Weighted average, normalised by the total weight
        mom = (weights*(A - center)**moment).sum() / weights.sum()

    if standardize:
        # Recursive call: second moment about the same center, with the same weights
        sigma2 = calc_moment(A, center, 2, weights)
        mom = mom / sigma2**(moment/2)

    return mom