import numpy as np
import matlab
from joblib import Parallel, delayed
from lib.legit_sample_generate import generate_signals_legitimate
from lib.legit_Att_sample_generate import generate_signals_legitimate_and_attacker
from lib.feature_calculation import calculate_features

# Execution switch read by generate_feature_sets: when True the per-simulation
# workers are dispatched through joblib.Parallel, otherwise they run one after
# another in a plain loop.
multi_processing = False
# Passed to joblib as n_jobs; only effective if multi_processing is True.
cores = 2

def _run_simulations(worker, num_runs):
    """Call ``worker()`` ``num_runs`` times and return the results as a list.

    Dispatches through joblib with ``cores`` jobs when the module-level
    ``multi_processing`` flag is true, and runs a serial loop otherwise.
    """
    if multi_processing:
        return Parallel(n_jobs=cores)(
            delayed(worker)() for _ in range(num_runs)
        )
    return [worker() for _ in range(num_runs)]


def _fill_blocks(target, blocks, block_len):
    """Copy each block into its consecutive ``block_len``-wide slot on axis 1 of ``target``."""
    for i, block in enumerate(blocks):
        target[:, i * block_len:(i + 1) * block_len] = block
    return target


def generate_feature_sets(K, T, L, M, SNR_dB, p_k, p_e, f_k_set_legit, f_cfo_max,
                          num_of_leg_sample, num_of_att_sample, state, mode,
                          window_size, num_sample_per_sim=None):
    """
    Generate a legitimate feature set and one attacker-scenario feature set.

    Runs ``num_of_leg_sample`` legitimate-only simulations and
    ``num_of_att_sample`` legitimate-plus-attacker simulations, keeps the first
    ``num_sample_per_sim`` feature columns of each simulation, and lays the
    simulations out back-to-back along axis 1 of the returned arrays.

    Parameters:
    -----------
    K : int
        Number of users; also the size of the first axis of every returned array
    T : int
        Only used as the default for ``num_sample_per_sim``
    L : int
        Parameter L, forwarded to the signal generators
    M : int
        Parameter M, forwarded to the signal generators
    SNR_dB : float
        Signal-to-Noise Ratio in dB
    p_k : float or array
        Power parameter for legitimate users
    p_e : float
        Power parameter for attacker
    f_k_set_legit : array-like
        Forwarded unchanged to both signal generators (presumably the CFO
        values of the legitimate users -- confirm against the generators)
    f_cfo_max : float
        Maximum CFO value, forwarded to the attacker signal generator
    num_of_leg_sample : int
        Number of legitimate simulations to run
    num_of_att_sample : int
        Number of attacker simulations to run
    state : str
        State parameter, either 'ear' or 'auth'; forwarded to the attacker
        signal generator
    mode : str
        Mode parameter, either 'both' or 'pure'; selects which attacker-scenario
        feature set is computed and returned
    window_size : int
        Size of the sliding window for feature calculation
    num_sample_per_sim : int, optional
        Number of samples to extract per simulation. If None, defaults to T.

    Returns:
    --------
    features_legit : ndarray
        Features of legitimate signals,
        shape (K, num_of_leg_sample * num_sample_per_sim, 2)
    features_attack : ndarray
        Shape (K, num_of_att_sample * num_sample_per_sim, 2). With mode 'both'
        these are the features of the reconstructed legitimate signal produced
        by the attacker simulation; with mode 'pure' they are the features of
        the attacker's reconstructed signal.
    arp_groundtruth_mean : ndarray
        Mean over all legitimate samples of the first output of
        ``calculate_features``, shape (K,). NaN (with a NumPy RuntimeWarning)
        when no legitimate samples are generated.

    Raises:
    -------
    ValueError
        If ``mode`` is neither 'both' nor 'pure'.
    """
    # Fail fast: reject a bad mode before spending time on any simulation.
    if mode not in ('both', 'pure'):
        raise ValueError("Invalid mode. Choose either 'both' or 'pure'.")

    if num_sample_per_sim is None:
        num_sample_per_sim = T

    # A sliding window of window_size needs window_size + n - 1 time slots to
    # yield n features.
    required_T = window_size + num_sample_per_sim - 1

    total_leg_samples = num_of_leg_sample * num_sample_per_sim
    total_att_samples = num_of_att_sample * num_sample_per_sim

    def legit_worker():
        # One legitimate-only simulation -> (arp ground truth, features).
        r_k_t_groundtruth, r_k_t_reconstructed = generate_signals_legitimate(
            K, required_T, L, M, SNR_dB, p_k, f_k_set_legit
        )
        arp_groundtruth, features = calculate_features(
            r_k_t_groundtruth, r_k_t_reconstructed, window_size
        )
        return (arp_groundtruth[:, :num_sample_per_sim],
                features[:, :num_sample_per_sim, :])

    legit_results = _run_simulations(legit_worker, num_of_leg_sample)

    arp_groundtruth_all = _fill_blocks(
        np.zeros((K, total_leg_samples)),
        (arp for arp, _ in legit_results),
        num_sample_per_sim,
    )
    features_legit = _fill_blocks(
        np.zeros((K, total_leg_samples, 2)),
        (features for _, features in legit_results),
        num_sample_per_sim,
    )
    arp_groundtruth_mean = np.mean(arp_groundtruth_all, axis=1)

    def att_worker():
        # One legitimate-plus-attacker simulation. Only the feature set that
        # `mode` asks for is computed; the other one would be discarded anyway.
        # NOTE(review): this assumes calculate_features is a pure function.
        (r_k_t_groundtruth,
         r_k_t_reconstructed,
         r_k_t_attacker_reconstructed) = generate_signals_legitimate_and_attacker(
            K, required_T, L, M, SNR_dB, p_k, p_e, f_k_set_legit, f_cfo_max, state=state
        )
        if mode == 'both':
            reconstructed = r_k_t_reconstructed
        else:
            reconstructed = r_k_t_attacker_reconstructed
        _, features = calculate_features(r_k_t_groundtruth, reconstructed, window_size)
        return features[:, :num_sample_per_sim, :]

    att_results = _run_simulations(att_worker, num_of_att_sample)

    features_attack = _fill_blocks(
        np.zeros((K, total_att_samples, 2)),
        att_results,
        num_sample_per_sim,
    )

    return features_legit, features_attack, arp_groundtruth_mean
