import numpy as np
import matplotlib.pyplot as plt
from lib.legit_sample_generate import generate_signals_legitimate
import matplotlib.font_manager as fm
import os
from joblib import Parallel, delayed

# Make the Noto Sans SC font shipped under lib/ known to matplotlib's font
# manager, then select it as the default family for all subsequent plots in
# this process.
font_path = os.path.join(
    os.path.dirname(__file__), 'lib', 'Noto_Sans_SC', 'static', 'NotoSansSC-Regular.ttf'
)
fm.fontManager.addfont(font_path)
plt.rcParams.update({'font.family': 'Noto Sans SC'})

def experiment_rmse(K, T, L, M, snr_db, p_k, f_k_set_legit):
    """Run one simulation trial and return each user's mean relative error.

    Calls ``generate_signals_legitimate`` once and compares the ground-truth
    signal with its reconstruction.

    Despite the function name, the returned quantity is a relative *mean
    squared* error: no square root is taken.

    ``snr_db`` is passed through ``int()``, so fractional SNR values are
    truncated toward zero before reaching the generator.

    Returns:
        ndarray of ``||gt - rec||^2 / ||gt||^2`` (norms taken over axis 2),
        averaged over axis 1. Presumably the generator returns arrays shaped
        (K, T, L); TODO confirm against ``lib.legit_sample_generate``. In that
        case this is one value per user, shape (K,). A ground-truth slice with
        zero norm makes the ratio inf/nan.
    """
    truth, estimate = generate_signals_legitimate(
        K=K, T=T, L=L, M=M, SNR_dB=int(snr_db), p_k=p_k, f_k_set=f_k_set_legit
    )

    # Squared L2 norms along the last axis: reference energy and error energy.
    signal_energy = np.linalg.norm(truth, axis=2) ** 2
    error_energy = np.linalg.norm(truth - estimate, axis=2) ** 2

    # Normalise each entry by its reference energy, then average over axis 1.
    return (error_energy / signal_energy).mean(axis=1)

def main(T, L_val, K, f_cfo_max, n_samples, snr_db_range, *, M=9, p_k=5.0, n_jobs=1):
    """Sweep SNR values and return the mean relative error at each one.

    For every value in ``snr_db_range`` this runs ``n_samples`` independent
    calls to ``experiment_rmse`` and averages the returned per-user values
    over both users and trials.

    Args:
        T: Forwarded as ``T``; converted with ``int``.
        L_val: Forwarded as ``L``; converted with ``int``.
        K: Number of users (length of the ``f_k_set`` vector); converted
            with ``int``.
        f_cfo_max: Maximum CFO setting. Converted with ``float``. It is not
            used yet, because ``f_k_set`` is currently all zeros.
        n_samples: Number of trials per SNR point; must be >= 1.
        snr_db_range: Iterable of SNR values in dB.
        M: Forwarded as ``M`` (previously hard-coded to 9).
        p_k: Forwarded as ``p_k`` (previously hard-coded to 5.0).
        n_jobs: joblib worker count. The default of 1 runs serially, as
            before.

    Returns:
        1-D ndarray with one averaged value per entry of ``snr_db_range``,
        in the same order.

    Raises:
        ValueError: If ``n_samples`` is less than 1. Previously this produced
            NaN from averaging an empty array.
    """
    T = int(T)
    L = int(L_val)
    K = int(K)
    # Previously int(), which silently truncated e.g. 0.003 to 0.
    f_cfo_max = float(f_cfo_max)
    n_samples = int(n_samples)
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # All-zero f_k_set (presumably zero CFO for every legitimate user).
    f_k_set_legit = np.zeros(K)

    avg_list = []
    for snr_db in snr_db_range:
        per_trial = Parallel(n_jobs=n_jobs)(
            delayed(experiment_rmse)(K, T, L, M, snr_db, p_k, f_k_set_legit)
            for _ in range(n_samples)
        )
        # np.mean over the stacked (n_samples, per-user) array already
        # averages over trials AND users. The former extra "/ K" divided by
        # the user count a second time.
        avg_list.append(np.mean(np.asarray(per_trial)))

    return np.array(avg_list)

    
    
if __name__ == '__main__':
    # Standalone run: sweep SNR from 0 to 25 dB in 5 dB steps and print the
    # array of per-SNR averages returned by main().
    settings = {
        'T': 10,
        'L_val': 150,
        'K': 4,
        'f_cfo_max': 0.003,
        'n_samples': 1000,
        'snr_db_range': [0, 5, 10, 15, 20, 25],
    }
    print(main(**settings))