import numpy as np
from lib.legit_sample_generate import generate_signals_legitimate
from lib.feature_calculation import calculate_features
from joblib import Parallel, delayed

def experiment_mse(K, T, L, M, snr_db, p_k, f_cfo_max, window_size, f_dfo_max, user_idx=0):
    """Run one Monte-Carlo trial and return the ARP and CFO squared errors for one user.

    A CFO is drawn uniformly from ``[-f_cfo_max, f_cfo_max)`` for each of the
    ``K`` users. Signals are generated with ``generate_signals_legitimate`` and
    windowed ARP/CFO features are obtained from ``calculate_features``. Only
    user ``user_idx`` is scored.

    Args:
        K: Number of users (length of the CFO draw).
        T: Number of time slots passed to ``generate_signals_legitimate``.
        L: Passed through to ``generate_signals_legitimate``; also scales the
            ARP error normalization.
        M: Passed through to ``generate_signals_legitimate``.
        snr_db: SNR in dB; converted with ``int`` before being passed on.
        p_k: Passed through to ``generate_signals_legitimate``; also scales the
            ARP error normalization.
        f_cfo_max: Half-width of the uniform CFO draw.
        window_size: Window length passed to ``calculate_features``.
        f_dfo_max: Passed through to ``generate_signals_legitimate``.
        user_idx: Index of the user to score. Defaults to 0, the first user,
            which is the original behavior.

    Returns:
        Tuple ``(mse_arp, mse_cfo)``:
          * ``mse_arp``: mean over windows of ``(arp_gt - arp_est)**2 / (L * p_k)``.
          * ``mse_cfo``: mean over windows of ``(cfo_est - f_gt)**2``
            (not normalized).
    """
    # One CFO per user, uniform in [-f_cfo_max, f_cfo_max).
    f_k_set_legit = np.random.rand(K) * f_cfo_max * 2 - f_cfo_max
    r_k_t_groundtruth, r_k_t_reconstructed = generate_signals_legitimate(
        K=K, T=T, L=L, M=M, SNR_dB=int(snr_db), p_k=p_k,
        f_k_set=f_k_set_legit, f_dfo_max=f_dfo_max,
    )

    # Groundtruth ARP per window plus estimated features. The last axis of
    # ``features`` is indexed as [0] = ARP estimate, [1] = CFO estimate.
    arp_gt_window, features = calculate_features(
        r_k_t_groundtruth, r_k_t_reconstructed, window_size
    )
    arp_est_window = features[..., 0]  # assumed shape (K, N) — TODO confirm
    cfo_est_window = features[..., 1]  # assumed shape (K, N) — TODO confirm

    # Restrict to the selected user.
    arp_gt_user = arp_gt_window[user_idx]
    arp_est_user = arp_est_window[user_idx]
    cfo_est_user = cfo_est_window[user_idx]
    # The true CFO is constant across windows, so broadcasting the scalar is
    # equivalent to tiling it to length N.
    f_gt_user = f_k_set_legit[user_idx]

    # ARP error, normalized by L * p_k (not by the groundtruth ARP).
    mse_arp = np.mean((arp_gt_user - arp_est_user) ** 2 / (L * p_k))

    # CFO error, not normalized. Dividing by f_gt**2 would be unstable because
    # the uniform draw can produce CFOs arbitrarily close to zero.
    mse_cfo = np.mean((cfo_est_user - f_gt_user) ** 2)

    return mse_arp, mse_cfo

def main():
    """Sweep K and SNR, average per-trial MSEs over many trials, and save to .npz.

    For every (trial, K, SNR) combination, ``experiment_mse`` runs in parallel
    via joblib. The ARP and CFO errors are averaged over trials and written to
    ``arp_cfo_mse_vs_snr_results_windowsize{W}_N{N}_exp_{E}_v{V}.npz`` in the
    current working directory.
    """
    K_values = [1, 2, 4, 8]
    L = 150
    M = 9
    p_k = 5.0
    snr_db_range = np.arange(0, 30, 5)  # 0, 5, ..., 25 dB
    # Max CFO for the uniform draw. An earlier variant tied it to the block
    # length: f_cfo_max = 1 / (L * 10).
    f_cfo_max = 0.003

    # Max Doppler shift v * f_c / c with f_c = 6e9 and c = 3e8 m/s, normalized
    # by 30.72e6 (presumably the sampling rate in Hz; confirm).
    velocity = 500  # km/h
    v_ms = velocity / 3.6  # m/s
    f_dfo_max = (v_ms * 6e9 / 3e8) / 30.72e6

    window_size = 10  # Sliding window size for feature calculation
    num_sample_per_sim = 10  # Number of samples to extract per simulation
    num_experiments = 1000000

    # Presumably enough slots for num_sample_per_sim windows of length
    # window_size at stride 1 — TODO confirm against calculate_features.
    required_T = window_size + num_sample_per_sim - 1

    # Yield tasks lazily (trial-major, then K, then SNR) rather than building a
    # list of num_experiments * len(K_values) * len(snr_db_range) tuples.
    # joblib returns results in input order, which the reshape below relies on.
    tasks = (
        delayed(experiment_mse)(
            K, required_T, L, M, snr_db, p_k, f_cfo_max, window_size, f_dfo_max
        )
        for _ in range(num_experiments)
        for K in K_values
        for snr_db in snr_db_range
    )
    results = Parallel(n_jobs=-1)(tasks)

    # Shape (num_experiments, len(K_values), len(snr_db_range), 2); the last
    # axis holds (mse_arp, mse_cfo).
    results_array = np.asarray(results).reshape(
        (num_experiments, len(K_values), len(snr_db_range), 2)
    )
    del results  # release the large list of tuples

    # Average over trials -> (len(K_values), len(snr_db_range), 2).
    avg_results = np.mean(results_array, axis=0)

    results_arp = {str(K): avg_results[i, :, 0] for i, K in enumerate(K_values)}
    results_cfo = {str(K): avg_results[i, :, 1] for i, K in enumerate(K_values)}

    output_name = f"arp_cfo_mse_vs_snr_results_windowsize{window_size}_N{num_sample_per_sim}_exp_{num_experiments}_v{velocity}.npz"
    # np.savez stores the arp_mse / cfo_mse dicts as pickled object arrays, so
    # reading them needs np.load(..., allow_pickle=True). They are kept for
    # existing readers. The *_array keys hold the same values as plain float
    # arrays indexed [K index, SNR index], with K_values labelling the rows.
    np.savez(output_name,
             snr_db_range=snr_db_range,
             window_size=window_size, L=L, M=M, p_k=p_k, f_cfo_max=f_cfo_max, f_dfo_max=f_dfo_max, num_experiments=num_experiments,
             arp_mse=results_arp, cfo_mse=results_cfo,
             K_values=np.asarray(K_values),
             arp_mse_array=avg_results[..., 0],
             cfo_mse_array=avg_results[..., 1])

if __name__ == "__main__":
    # Launch the sweep only when run as a script, not when imported.
    main()