import numpy as np
from lib.zcseq import generate_zc_pilot_matrix

def _select_user_indices(K, L):
    """Return K pilot-column indices spread evenly over the even indices 0, 2, 4, ... < L."""
    if K == 1:
        return np.array([0])
    even_indices = np.arange(0, L, 2)
    step = len(even_indices) / K
    return even_indices[np.round(np.arange(K) * step).astype(int)]


def _standard_normal(rng, shape):
    """Draw standard normal samples of `shape` from `rng`, or from the global np.random state if rng is None."""
    if rng is None:
        # Legacy path: identical to the original np.random.randn calls.
        return np.random.randn(*shape)
    return rng.standard_normal(shape)


def _despread_lowpass(x, pilot_conj, L, M):
    """Despread `x` with conjugate pilots, low-pass filter via an (M*L)-point FFT, return the first L samples.

    `x` and `pilot_conj` must broadcast together; the last axis is the length-L
    sample/subcarrier axis. The result has the broadcast shape.
    """
    despread = x * pilot_conj * np.sqrt(L)
    n_fft = M * L
    spectrum = np.fft.fft(despread, n_fft, axis=-1)  # fresh array, safe to modify in place
    edge = int(np.floor(M))
    # Zero 0-based bins [edge-1, M*L-edge): keeps the edge-1 bins starting at DC
    # and the top `edge` (negative-frequency) bins.
    # NOTE(review): for M == 1 this zeroes the DC bin as well; kept to match the
    # original slicing — confirm this is intended.
    spectrum[..., edge - 1:n_fft - edge] = 0
    return np.fft.ifft(spectrum, axis=-1)[..., :L]


def generate_signals_legitimate(K, T, L, M, SNR_dB, p_k, f_k_set,
                                pilot_matrix=None, rng=None):
    """Simulate K legitimate users over T slots and low-pass-filter each user's signal.

    Each user k is assigned one pilot column ``pilot_matrix[:, user_ind[k]]``,
    with the columns spread evenly over the even indices 0, 2, 4, ... < L. The
    pilot is modulated by a complex exponential at frequency ``f_k_set[k]`` and
    scaled by a per-slot complex Gaussian channel coefficient. The users'
    signals are summed. Complex Gaussian noise is then added per slot so that
    each slot has the requested SNR.

    Both outputs go through the same despread + low-pass step:
      1. Multiply by the user's conjugate pilot and sqrt(L).
      2. Take an (M*L)-point FFT along the sample axis.
      3. Zero the bins far from DC.
      4. Inverse-FFT and keep the first L samples.

    Args:
        K: Number of users.
        T: Number of time slots.
        L: Pilot length (samples/subcarriers per slot).
        M: FFT oversampling factor. The filter uses an (M*L)-point FFT, so
            M * L must be a positive integer.
        SNR_dB: Per-slot signal-to-noise ratio in dB. Noise power is derived
            from the measured power of the noiseless sum in each slot.
        p_k: Transmit power. Either a scalar, or anything broadcastable to (K,)
            for per-user powers.
        f_k_set: Per-user frequency, in cycles per sample of the global index
            l = t * L + n. Array-like of length K; a length-1 value is shared
            by all users. Plain lists are accepted.
        pilot_matrix: Optional pilot matrix with L rows whose columns are
            selected per user. Defaults to ``generate_zc_pilot_matrix(L, u=1)``.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) for the
            channel and noise draws. ``None`` uses the global ``np.random``
            state with the same draw order as before.

    Returns:
        Tuple ``(r_k_t_lpf, r_k_t_reconstructed)``. Both are complex arrays of
        shape (K, T, L):
          - r_k_t_lpf: each user's own noiseless component after
            despread + LPF (the filtered ground truth).
          - r_k_t_reconstructed: the noisy received sum after despreading with
            each user's pilot and applying the same LPF.
    """
    user_ind = _select_user_indices(K, L)

    if pilot_matrix is None:
        pilot_matrix = generate_zc_pilot_matrix(L, u=1)
    # Select each user's pilot column once: shape (L, K). The time axis comes
    # from broadcasting rather than stacking T copies of the full matrix.
    pilots = np.asarray(pilot_matrix)[:, user_ind]

    f_k = np.atleast_1d(f_k_set)

    # Global sample index l = t * L + n, shape (T, L). This keeps the
    # frequency term phase-continuous across slots.
    l = np.arange(T)[:, np.newaxis] * L + np.arange(L)[np.newaxis, :]
    wk_set = np.exp(2j * np.pi * f_k[np.newaxis, np.newaxis, :] * l[:, :, np.newaxis])  # (T, L, K)

    # One unit-average-power complex Gaussian coefficient per (slot, user): shape (T, K).
    h_t_k = (_standard_normal(rng, (T, K)) + 1j * _standard_normal(rng, (T, K))) / np.sqrt(2)

    # Per-user received components (T, L, K) and their noiseless sum (T, L).
    r_k_t = np.sqrt(L * p_k) * pilots[np.newaxis, :, :] * wk_set * h_t_k[:, np.newaxis, :]
    r_t_without_noise = np.sum(r_k_t, axis=2)

    # Scale the noise per slot from the measured signal power, so every slot is at SNR_dB.
    noise_power = np.mean(np.abs(r_t_without_noise) ** 2, axis=1) / 10 ** (SNR_dB / 10)  # (T,)
    noise = np.sqrt(noise_power[:, np.newaxis] / 2) * (
        _standard_normal(rng, (T, L)) + 1j * _standard_normal(rng, (T, L)))
    r_t = r_t_without_noise + noise  # (T, L)

    pilot_conj = np.conj(pilots.T)[:, np.newaxis, :]  # (K, 1, L)

    # The received sum (1, T, L) is despread against every user's pilot -> (K, T, L).
    r_k_t_reconstructed = _despread_lowpass(r_t[np.newaxis, :, :], pilot_conj, L, M)
    # Ground truth: each user's own component, despread with its own pilot.
    r_k_t_lpf = _despread_lowpass(np.transpose(r_k_t, (2, 0, 1)), pilot_conj, L, M)

    return r_k_t_lpf, r_k_t_reconstructed