import numpy as np
from lib.zcseq import generate_zc_pilot_matrix

def generate_signals_legitimate_and_attacker(K, T, L, M, SNR_dB, p_k, p_e, f_k_set, f_cfo_max, state):
    """
    Generate signals for legitimate users and one attacker targeting user 0 (NumPy vectorized).

    Legitimate user k transmits pilot column ``user_ind[k]`` of the matrix returned by
    ``generate_zc_pilot_matrix(L, u=1)``. That pilot is rotated by the user's CFO
    ``f_k_set[k]`` and scaled by an independent complex Gaussian channel per OFDM symbol.
    The attacker uses either user 0's pilot ('ear') or a random other even-indexed pilot
    ('auth'). Its CFO is drawn uniformly from [-f_cfo_max, f_cfo_max). Complex Gaussian
    noise is scaled per OFDM symbol so that the noiseless combined signal has the
    requested SNR.

    All randomness comes from NumPy's global legacy RNG (``np.random``). Seed it with
    ``np.random.seed`` for reproducible output.

    Args:
        K: Number of users
        T: Number of OFDM symbols
        L: Number of subcarriers
        M: Oversampling factor (used as an integer; the FFT length is M * L)
        SNR_dB: Signal-to-noise ratio in dB
        p_k: Legitimate user power
        p_e: Attacker power
        f_k_set: Legitimate user CFOs, NumPy array of shape (K,)
        f_cfo_max: Maximum CFO magnitude for the attacker
        state: Attack state - 'ear' (eavesdropping) or 'auth' (authentication)

    Returns:
        r_k_t_groundtruth: Noise-free per-user signals, with the attacker added to
            user 0, shape (K, T, L)
        r_k_t_reconstructed: Per-user reconstruction from the noisy mixed signal,
            shape (K, T, L)
        r_k_t_attacker_reconstructed: Per-user reconstruction from the attacker
            signal plus the same noise realization, shape (K, T, L)

    Raises:
        ValueError: if ``state`` is not 'ear' or 'auth'. Also raised when ``state`` is
            'auth' and there is no even index other than user 0's pilot.
    """
    # Spread the K users' pilots evenly over the even-indexed pilot columns.
    all_even_indices = np.arange(0, L, 2)
    if K == 1:
        user_ind = np.array([0])
    else:
        step = len(all_even_indices) / K
        # NOTE(review): if K > len(all_even_indices), the rounded positions repeat, so
        # users silently share a pilot. For K >= 2 * len(all_even_indices), a position
        # can also round past the end and raise IndexError. Confirm the callers' K range.
        user_ind = all_even_indices[np.round(np.arange(0, K) * step).astype(int)]

    # The pilot matrix is the same for every OFDM symbol. Keep it 2-D (L, L) and let
    # broadcasting supply the time axis, instead of materialising a (T, L, L) copy.
    ZC_matrix = np.asarray(generate_zc_pilot_matrix(L, u=1))
    vk_selected = ZC_matrix[:, user_ind]  # (L, K)

    # Pilot for the single attacker targeting user 0, based on the attack state.
    if state == 'ear':
        # Attacker reuses user 0's pilot (eavesdropping).
        vk_attacker = vk_selected[:, 0:1]  # (L, 1)
    elif state == 'auth':
        # Attacker picks a random even-indexed pilot other than user 0's.
        # NOTE(review): this may coincide with another legitimate user's pilot.
        possible_indices = [i for i in all_even_indices if i != user_ind[0]]
        if len(possible_indices) == 0:
            raise ValueError("No available indices different from user 0.")
        random_index = np.random.choice(possible_indices)
        vk_attacker = ZC_matrix[:, random_index:random_index+1]  # (L, 1)
    else:
        raise ValueError(f"Invalid state: {state}. Must be 'ear' or 'auth'")

    # CFO for the single attacker (drawn after the pilot choice; draw order matters
    # for seeded reproducibility).
    f_e = np.random.uniform(-f_cfo_max, f_cfo_max)

    # Global sample index l = t * L + subcarrier, shape (T, L).
    l_time = np.arange(T)[:, np.newaxis]
    l_subcarrier = np.arange(L)[np.newaxis, :]
    l = l_time * L + l_subcarrier

    # CFO phase rotations.
    phase_legit = 2j * np.pi * f_k_set[np.newaxis, np.newaxis, :] * l[:, :, np.newaxis]
    wk_set_legit = np.exp(phase_legit)           # (T, L, K)
    phase_attacker = 2j * np.pi * f_e * l[:, :, np.newaxis]
    wk_set_attacker = np.exp(phase_attacker)     # (T, L, 1)

    # Unit-power complex Gaussian channels, one coefficient per OFDM symbol.
    h_k_t = (np.random.randn(T, K) + 1j * np.random.randn(T, K)) / np.sqrt(2)  # (T, K)
    h_e_t = (np.random.randn(T, 1) + 1j * np.random.randn(T, 1)) / np.sqrt(2)  # (T, 1)
    h_legit = h_k_t[:, np.newaxis, :]      # (T, 1, K)
    h_attacker = h_e_t[:, np.newaxis, :]   # (T, 1, 1)

    # Transmitted signals; the (L, K) / (L, 1) pilots broadcast over T.
    r_legit = np.sqrt(L * p_k) * vk_selected * wk_set_legit * h_legit        # (T, L, K)
    r_attacker = np.sqrt(L * p_e) * vk_attacker * wk_set_attacker * h_attacker  # (T, L, 1)

    # Ground truth per user: the attacker's signal lands on user 0 only.
    # r_legit is not used again, so it is updated in place.
    r_k_t = r_legit
    r_k_t[:, :, 0] += r_attacker[:, :, 0]
    r_t_without_noise = np.sum(r_k_t, axis=2)  # (T, L)
    r_attacker_sum = r_attacker[:, :, 0]       # (T, L)

    # Noise power is set per OFDM symbol from the noiseless combined signal's power.
    r_t_power = np.mean(np.abs(r_t_without_noise)**2, axis=1)  # (T,)
    snr_linear = 10**(SNR_dB / 10)
    noise_power = r_t_power / snr_linear
    noise = np.sqrt(noise_power[:, np.newaxis] / 2) * (np.random.randn(T, L) + 1j * np.random.randn(T, L))

    # The same noise realization is added to the mixed and attacker-only signals.
    r_t = r_t_without_noise + noise
    r_attacker_sum_noisy = r_attacker_sum + noise

    r_k_t_groundtruth = np.transpose(r_k_t, (2, 0, 1))  # (K, T, L)
    r_k_t_reconstructed = _despread_lowpass(r_t, vk_selected, L, M)
    r_k_t_attacker_reconstructed = _despread_lowpass(r_attacker_sum_noisy, vk_selected, L, M)

    return r_k_t_groundtruth, r_k_t_reconstructed, r_k_t_attacker_reconstructed


def _despread_lowpass(r_t, vk_selected, L, M):
    """Despread r_t with each user's pilot, low-pass filter it in the frequency domain,
    and keep the first L samples.

    Args:
        r_t: Received signal, shape (T, L).
        vk_selected: Selected pilot columns, shape (L, K).
        L: Number of subcarriers.
        M: Oversampling factor (integer); the FFT length is M * L.

    Returns:
        Reconstructed per-user signals, shape (K, T, L).
    """
    vk_conj = np.conj(vk_selected.T)[:, np.newaxis, :]    # (K, 1, L)
    s_k = r_t[np.newaxis, :, :] * vk_conj * np.sqrt(L)   # (K, T, L)

    # Zero-padded M*L-point FFT. S_k is a fresh array, so it is filtered in place.
    S_k = np.fft.fft(s_k, M * L, axis=2)                  # (K, T, M*L)
    start_idx = int(np.floor(M))
    end_idx = M * L - int(np.floor(M))
    # NOTE(review): this keeps bins 0..M-2 (M-1 low bins) and M*L-M..M*L-1
    # (M high bins). Confirm the asymmetry is intended.
    S_k[:, :, start_idx-1:end_idx] = 0

    return np.fft.ifft(S_k, axis=2)[:, :, :L]             # (K, T, L)