"""
## Polygenic predictor parameters ##
Polygenic general factor (predictor) for IQ
- Assortative constant = 0.94  (assumed to be the same as for height and fluid intelligence)
- Between-family variance = 1 / 2
- Within-family variance = 1 / 2  (sibling variance)
- Population variance = 1 = between-family + within-family

"""

from tqdm import tqdm
import numpy as np
import pandas as pd
from scipy.stats import kstest, norm
import matplotlib.pyplot as plt
import seaborn as sns
import time
import json
from IPython.display import clear_output

# Multiplier applied to the within-family PGS standard deviation in
# sibling_pgs_sd() (value and assumption documented in the module docstring).
ASSORTATIVE_CONSTANT = 0.94

# Within-family (sibling) share of the unit population PGS variance.
WITHIN_FAMILY_VARIANCE = 1 / 2

# Parameter grids swept by compute_parameters(): embryos per family,
# number of top-scoring embryos kept, and predictor/phenotype correlation.
N_EMBRYOS_V = [5, 10, 15, 20, 25]
N_SELECT_V = [2, 3, 4, 5, 6, 7, 8, 9, 10]
PRED_CORR_V = [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65]

# Default number of simulated families (rows of draws) per parameter combination.
N_DRAWS = 200_000


def sibling_pgs_sd():
    """Return the standard deviation of the PGS among siblings of one family.

    Computed as ``sqrt(WITHIN_FAMILY_VARIANCE)`` scaled by
    ``ASSORTATIVE_CONSTANT``.
    """
    within_family_sd = np.sqrt(WITHIN_FAMILY_VARIANCE)
    return within_family_sd * ASSORTATIVE_CONSTANT


def mad_to_sd(mad):
    """Convert a median absolute deviation into a normal-equivalent SD.

    For a normal distribution the MAD equals ``sd * Phi^{-1}(0.75)``, so
    dividing by the upper-quartile z-score (about 0.6745) recovers ``sd``.
    """
    upper_quartile_z = norm.ppf(0.75)
    return mad / upper_quartile_z


def get_mad_normal(x):
    """Robustly estimate the location and scale of ``x`` assuming normality.

    Returns:
        Tuple ``(median, sd)``: the sample median and the median absolute
        deviation about it, rescaled to a standard deviation by ``mad_to_sd``.
    """
    center = np.median(x)
    abs_deviation = np.abs(x - center)
    return center, mad_to_sd(np.median(abs_deviation))

def iterate_over_params(n_embryos_vector, n_select_vector, pred_corr_vector):
    """Lazily yield every ``(n_embryos, n_select, pred_corr)`` grid combination.

    ``n_embryos`` varies slowest and ``pred_corr`` fastest. The outer
    ``n_embryos`` loop is wrapped in a tqdm progress bar.
    """
    for embryo_count in tqdm(n_embryos_vector):
        yield from (
            (embryo_count, select_count, corr)
            for select_count in n_select_vector
            for corr in pred_corr_vector
        )


def get_pheno(n_embryos, n_select, pred_corr, n_draws=N_DRAWS, rng=None):
    """Simulate phenotypes of the ``n_select`` embryos with the highest PGS.

    For each of ``n_draws`` simulated families, ``n_embryos`` PGS z-scores are
    drawn from a standard normal. The ``n_select`` highest are kept and scaled
    by ``sibling_pgs_sd()``. For each kept embryo, a phenotype is then drawn
    from a normal with mean ``pred_corr * pgs`` and standard deviation
    ``sqrt(1 - pred_corr**2)``.

    Args:
        n_embryos: Number of embryos per family.
        n_select: Number of top-scoring embryos to keep. Values larger than
            ``n_embryos`` are clamped to ``n_embryos``; 0 keeps none.
        pred_corr: Predictor/phenotype correlation. Presumably in [-1, 1] so
            the residual variance is non-negative.
        n_draws: Number of simulated families (rows of the result).
        rng: Optional random source exposing ``normal`` (e.g.
            ``np.random.default_rng(seed)``) for reproducible runs. Defaults
            to NumPy's global ``np.random`` state, as before.

    Returns:
        Array of shape ``(n_draws, min(n_select, n_embryos))``. Columns are
        ordered by ascending PGS within each row.

    Raises:
        ValueError: If ``n_select`` is negative.
    """
    if n_select < 0:
        raise ValueError(f"n_select must be non-negative, got {n_select}")
    rng = np.random if rng is None else rng
    n_select = min(n_select, n_embryos)

    draws_pgs = rng.normal(size=[n_draws, n_embryos])
    draws_pgs_sorted = np.sort(draws_pgs, axis=1)
    # Slice from an explicit start column: ``[:, -n_select:]`` would return
    # every column when n_select == 0, because -0 == 0.
    top_draws_pgs_zscore = draws_pgs_sorted[:, n_embryos - n_select:]

    # Rescale z-scores to the sibling (within-family) PGS spread.
    top_draws_pgs = top_draws_pgs_zscore * sibling_pgs_sd()

    top_draws_pheno_mean = top_draws_pgs * pred_corr
    # Residual SD of the phenotype around the PGS-based mean.
    top_draws_pheno_sd = np.sqrt(1 - pred_corr**2)

    top_draws_pheno = rng.normal(top_draws_pheno_mean, top_draws_pheno_sd)

    return top_draws_pheno


def compute_parameters(
    n_draws=N_DRAWS,
    plot=False,
    *,
    n_embryos_vector=N_EMBRYOS_V,
    n_select_vector=N_SELECT_V,
    pred_corr_vector=PRED_CORR_V,
    output_path="parameters.csv",
):
    """Fit normal approximations to simulated selected and unselected phenotypes.

    For every ``(n_embryos, n_select, pred_corr)`` combination of the grids,
    ``get_pheno`` simulates phenotypes twice: once for the top ``n_select``
    embryos and once for all ``n_embryos`` embryos. Each flattened sample is
    summarised with a median/MAD normal fit (``get_mad_normal``) and the
    Kolmogorov-Smirnov distance to that fit.

    Args:
        n_draws: Number of simulated families per combination.
        plot: If true, about one combination in ten briefly shows the ECDF
            of the selected phenotypes against the fitted normal CDF.
        n_embryos_vector: Grid of embryo counts. Defaults to ``N_EMBRYOS_V``.
        n_select_vector: Grid of selection counts. Defaults to ``N_SELECT_V``.
        pred_corr_vector: Grid of correlations. Defaults to ``PRED_CORR_V``.
        output_path: CSV path the results are written to. ``None`` skips
            writing.

    Returns:
        DataFrame with one row per combination and columns ``n_embryos``,
        ``n_select`` (the requested value, even when ``get_pheno`` clamps it
        to ``n_embryos``), ``pred_corr``, ``top_loc``, ``top_scale``,
        ``all_scale``, ``top_d`` and ``all_d``.
    """
    data = []
    params = iterate_over_params(n_embryos_vector, n_select_vector, pred_corr_vector)

    for n_embryos, n_select, pred_corr in params:
        entry = {"n_embryos": n_embryos, "n_select": n_select, "pred_corr": pred_corr}

        top_draws = get_pheno(n_embryos, n_select, pred_corr, n_draws=n_draws)
        all_draws = get_pheno(n_embryos, n_embryos, pred_corr, n_draws=n_draws)

        top_draws = top_draws.flatten()
        all_draws = all_draws.flatten()

        top_loc, top_sd = get_mad_normal(top_draws)
        # Only the scale is kept for the unselected sample; its KS test below
        # is centred at 0 rather than at the sample median.
        _, all_sd = get_mad_normal(all_draws)

        entry["top_loc"] = top_loc
        entry["top_scale"] = top_sd
        entry["all_scale"] = all_sd
        entry["top_d"] = kstest(top_draws, "norm", args=(top_loc, top_sd)).statistic
        entry["all_d"] = kstest(all_draws, "norm", args=(0, all_sd)).statistic

        # The 1-in-10 lottery is evaluated before ``plot``, so the global RNG
        # advances the same way whether or not plotting is enabled.
        if np.random.choice(10) == 0 and plot:
            title = f"n_embryos={n_embryos}, n_select={n_select}, pred_corr={pred_corr}"
            _plot_top_fit(top_draws, top_loc, top_sd, title)

        data.append(entry)

    df = pd.DataFrame(data)
    if output_path is not None:
        df.to_csv(output_path, index=False)
    return df


def _plot_top_fit(top_draws, top_loc, top_sd, title):
    """Briefly show the ECDF of ``top_draws`` against the fitted normal CDF."""
    fig, ax = plt.subplots()
    sns.ecdfplot(top_draws, ax=ax, linewidth=3, color='red', alpha=0.2)
    x = np.linspace(-3, 3, 100)
    ax.set_title(title)
    ax.plot(x, norm.cdf(x, top_loc, top_sd), color='black', linewidth=1, linestyle='--')
    fig.show()
    # Leave the figure visible for a second before clearing the cell output.
    time.sleep(1)
    plt.close(fig)
    clear_output(wait=True)






def write_to_json(*, csv_path="parameters.csv", json_path="embryo_parameters.json"):
    """Convert the parameter CSV into a nested JSON lookup table.

    The output is structured as
    ``{n_embryos: {n_select: {pred_corr: {"top_loc", "top_scale", "all_scale"}}}}``.
    All keys are strings, and values are rounded to 3 decimals. When several
    rows share the same ``(n_embryos, n_select, pred_corr)`` combination, only
    the first one in file order is used. Keys appear in order of first
    appearance in the CSV.

    Args:
        csv_path: CSV to read, as written by ``compute_parameters``.
        json_path: Destination JSON file (overwritten).
    """
    df = pd.read_csv(csv_path)
    key_columns = ["n_embryos", "n_select", "pred_corr"]

    # A single pass over the first row of each combination replaces repeated
    # per-level boolean filtering. drop_duplicates also treats NaN keys as
    # equal, where ``== NaN`` filtering would match nothing.
    first_rows = df.drop_duplicates(subset=key_columns, keep="first")

    json_data = {}
    for row in first_rows.itertuples(index=False):
        by_select = json_data.setdefault(str(row.n_embryos), {})
        by_corr = by_select.setdefault(str(row.n_select), {})
        by_corr[str(row.pred_corr)] = {
            "top_loc": round(row.top_loc, 3),
            "top_scale": round(row.top_scale, 3),
            "all_scale": round(row.all_scale, 3),
        }

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=4)


