import numpy as np
import os
from lib.generate_features import generate_feature_sets
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import pandas as pd

def _detection_metrics(y_true, y_pred):
    """Return ``(accuracy, MDR, PFA)`` for labels where -1 = legit, +1 = attacker.

    MDR (miss detection rate) is ``1 - TPR`` over attacker samples; PFA (false
    alarm rate) is the FPR over legitimate samples. When a class is absent
    from ``y_true`` the corresponding rate's denominator is zero and TPR/FPR
    is taken as 0 (so MDR becomes 1 and PFA 0).
    """
    is_attacker = y_true == 1
    is_legit = y_true == -1
    tp = np.sum(is_attacker & (y_pred == 1))
    fn = np.sum(is_attacker & (y_pred == -1))
    fp = np.sum(is_legit & (y_pred == 1))
    tn = np.sum(is_legit & (y_pred == -1))
    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
    return accuracy_score(y_true, y_pred), 1 - tpr, fpr


def evaluate_user_with_aes(user_id, model, scaler, X_legit, X_attacker, ul_sets):
    """Evaluate a single user's model with AES enhancement.

    The legitimate and attacker samples are scaled with ``scaler``, pooled,
    and split (stratified, ``random_state=42``) into a 20% calibration set and
    an 80% evaluation set. A logistic regression fitted on the calibration
    set maps ``model.decision_function`` scores to a probability of being
    legitimate. For every ``(u, l)`` pair in ``ul_sets``, evaluation samples
    whose calibrated probability lies in ``[l, u]`` are escalated to AES,
    which is modeled here as returning the true label.

    Args:
        user_id: Identifier used only in the printed report.
        model: Fitted one-class model exposing ``decision_function`` where a
            positive score means "inlier" (e.g. sklearn ``OneClassSVM``).
        scaler: Fitted scaler exposing ``transform`` (same one used in training).
        X_legit: 2-D feature array of legitimate samples (label -1).
        X_attacker: 2-D feature array of attacker samples (label +1).
        ul_sets: Iterable of ``(u, l)`` probability thresholds in ``[0, 1]``.

    Returns:
        ``(base_acc, base_mdr, base_fpa, results)``. ``results`` is a list of
        dicts, one per ``(u, l)`` pair, with keys ``'u'``, ``'l'``,
        ``'base_accuracy'``, ``'accuracy'``, ``'MDR'``, ``'FPA'`` and
        ``'reduction'`` (the fraction of samples NOT escalated to AES).
    """
    X_legit_scaled = scaler.transform(X_legit)
    X_attacker_scaled = scaler.transform(X_attacker)

    # Labels: -1 for legitimate, +1 for attacker.
    X_test_scaled = np.vstack((X_legit_scaled, X_attacker_scaled))
    y_test = np.hstack((-np.ones(X_legit.shape[0]), np.ones(X_attacker.shape[0])))

    X_cal, X_test_final, y_cal, y_test_final = train_test_split(
        X_test_scaled, y_test, test_size=0.8, random_state=42, stratify=y_test
    )

    # Calibrate decision scores into P(legit) (Platt-style logistic fit).
    calibrator = LogisticRegression()
    calibrator.fit(model.decision_function(X_cal).reshape(-1, 1), y_cal == -1)

    df_test = model.decision_function(X_test_final)
    # classes_ is sorted as [False, True], so column 1 is P(legit).
    prob_legit = calibrator.predict_proba(df_test.reshape(-1, 1))[:, 1]
    # Flipped labels: positive score (inlier) -> -1 legit, otherwise +1 attacker.
    # np.sign would yield 0 for a score of exactly 0, which matches neither
    # class; libsvm treats score <= 0 as an outlier, so map it to attacker.
    predictions_flipped = np.where(df_test > 0, -1.0, 1.0)

    base_acc, base_mdr, base_fpa = _detection_metrics(y_test_final, predictions_flipped)

    print(f"\nUser {user_id} Base Model Performance:")
    print(f"Accuracy: {base_acc:.4f}")
    print(f"Miss Detection Rate (MDR): {base_mdr:.4f}")
    print(f"False Alarm Rate (PFA): {base_fpa:.4f}")

    results = []
    for upper, lower in ul_sets:
        # Uncertain band: these samples are escalated and receive the true label.
        escalate = (prob_legit >= lower) & (prob_legit <= upper)
        adjusted_predictions = np.where(escalate, y_test_final, predictions_flipped)

        acc, mdr, fpa = _detection_metrics(y_test_final, adjusted_predictions)
        aes_rate = np.count_nonzero(escalate) / len(y_test_final)

        results.append({
            'u': upper,
            'l': lower,
            'base_accuracy': base_acc,
            'accuracy': acc,
            'MDR': mdr,
            'FPA': fpa,
            'reduction': 1 - aes_rate,
        })

    return base_acc, base_mdr, base_fpa, results

def main():
    """Run one OC-SVM + AES evaluation for user ``observed_k``.

    Generates legitimate and attacker feature sets with
    ``generate_feature_sets``. It then fits a MinMax scaler and an RBF One-Class
    SVM on 70% of the legitimate features, and prints base and AES-enhanced
    detection metrics on the held-out samples.
    """
    # Simulation parameters forwarded to generate_feature_sets.
    K = 4
    T = 50
    L = 150
    M = 9
    p_k = 5.0
    p_e = 5.0
    SNR_dB = 20
    num_of_sim = 100
    num_sample_per_sim = 10  # Number of samples to extract per simulation
    window_size = 10  # Sliding window size for feature calculation

    # A bound of 1 / (L * f_CFO_range) with f_CFO_range = 10 used to be computed
    # here but was immediately overridden; this fixed value is what is used.
    f_cfo_max = 0.003

    # One-Class SVM hyperparameters.
    nu = 0.05
    gamma = 5.0

    state = 'ear'
    mode = 'pure'

    observed_k = 0

    np.random.seed(42)

    # Constant CFOs for the legitimate users, drawn uniformly in [-f_cfo_max, f_cfo_max).
    f_k_set_legit = np.random.rand(K) * f_cfo_max * 2 - f_cfo_max

    # Arguments shared by every generate_feature_sets call; only state/mode differ.
    sim_kwargs = dict(
        K=K,
        T=T,
        L=L,
        M=M,
        SNR_dB=SNR_dB,
        p_k=p_k,
        p_e=p_e,
        f_k_set_legit=f_k_set_legit,
        f_cfo_max=f_cfo_max,
        num_of_leg_sample=num_of_sim,
        num_of_att_sample=num_of_sim,
        window_size=window_size,
        num_sample_per_sim=num_sample_per_sim,
    )

    def split_70_30(features):
        """Shuffle-split ``features`` 70/30 with a fixed seed; return (train, test)."""
        return train_test_split(features, test_size=0.3, shuffle=True, random_state=42)

    if state == 'auth' and mode == 'both':
        # Legitimate features plus 'auth' attackers.
        features_legit, features_legit_att_auth, _ = generate_feature_sets(
            **sim_kwargs, state='auth', mode='both'
        )
        # 'ear' attackers; the legitimate output of this call is discarded.
        _, features_legit_att_ear, _ = generate_feature_sets(
            **sim_kwargs, state='ear', mode='both'
        )

        features_legit_train, features_legit_test = split_70_30(features_legit[observed_k])
        # Attacker train splits are not used; only the 30% test splits are.
        _, auth_test = split_70_30(features_legit_att_auth[observed_k])
        _, ear_test = split_70_30(features_legit_att_ear[observed_k])

        # Mix attacker test set with ratio aligned to original (1 - 1 / (L // 2)).
        total_test_att = len(auth_test)
        ratio_auth = 1 - 1 / (L // 2)
        n_auth = int(total_test_att * ratio_auth)
        # Cap at the available ear samples: choice(replace=False) raises otherwise.
        n_ear = min(total_test_att - n_auth, len(ear_test))

        indices_auth = np.random.choice(len(auth_test), n_auth, replace=False)
        indices_ear = np.random.choice(len(ear_test), n_ear, replace=False)

        # 'both' mode: auth mimics count as legitimate, only ear samples are attackers.
        features_legit_test_combined = np.vstack((features_legit_test, auth_test[indices_auth]))
        features_att_test = ear_test[indices_ear]
    else:
        # Pure mode handling for any other state/mode combination.
        features_legit, features_legit_att, _ = generate_feature_sets(
            **sim_kwargs, state=state, mode=mode
        )
        features_legit_train, features_legit_test_combined = split_70_30(
            features_legit[observed_k]
        )
        _, features_att_test = split_70_30(features_legit_att[observed_k])

    scaler = MinMaxScaler()
    features_legit_train_scaled = scaler.fit_transform(features_legit_train)

    ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu)
    ocsvm.fit(features_legit_train_scaled)

    # (u, l) thresholds on the calibrated probability of being legitimate (0-1 range).
    ul_sets = [
        (0.99, 0.4),
        (0.9, 0.3),
        (0.8, 0.3),
    ]
    # Full sweep (every grid pair with l < u, same order as the earlier explicit list):
    # ul_sets = [(u, l) for u in (0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2)
    #            for l in (0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1) if l < u]

    print(f"\n--- Evaluating User {observed_k} under state={state}, mode={mode}, L = {L} ---")
    _, _, _, user_results = evaluate_user_with_aes(
        observed_k, ocsvm, scaler, features_legit_test_combined, features_att_test, ul_sets
    )

    # Display AES results for this user.
    df_user_results = pd.DataFrame(user_results)
    print(f"\nUser {observed_k} AES Performance:")
    print(df_user_results)

if __name__ == "__main__":  # run the experiment only when executed as a script
    main()