import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import os
import pickle
from lib.generate_features import generate_feature_sets
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from matplotlib.lines import Line2D

def _register_bundled_font(*relative_parts):
    """Add a font file shipped next to this module to Matplotlib's font manager.

    The path is built by joining this module's directory with *relative_parts*.
    Returns the absolute-or-relative path that was registered.
    """
    bundled_path = os.path.join(os.path.dirname(__file__), *relative_parts)
    fm.fontManager.addfont(bundled_path)
    return bundled_path


# Register the bundled Noto Sans SC font and make it the default family so that
# the Chinese text used in labels/titles can be rendered.
font_path = _register_bundled_font('lib', 'Noto_Sans_SC', 'static', 'NotoSansSC-Regular.ttf')
plt.rcParams['font.family'] = 'Noto Sans SC'

def main(L_val, num_user, window_size, max_cfo, snr, num_of_train_samples, num_of_test_samples):
    """Fit a one-class SVM on legitimate features and sample its decision surface.

    Features come from ``generate_feature_sets`` (state='ear', mode='both').
    The OCSVM is trained on the min-max-scaled legitimate training split of user
    ``observed_k``. It is then evaluated against that user's eavesdropped
    ("legit + attacker") features. Its decision function is sampled on a 2-D grid
    so a caller can draw the decision boundary.

    Args:
        L_val: Passed through as ``L`` to ``generate_feature_sets``.
        num_user: Number of legitimate users ``K``.
        window_size: Sliding-window size used for feature calculation.
        max_cfo: Maximum CFO magnitude. The legitimate CFOs are drawn uniformly
            from ``[-max_cfo, max_cfo)``.
        snr: SNR in dB.
        num_of_train_samples: Requested number of training samples.
        num_of_test_samples: Requested number of test samples. Together with
            ``num_of_train_samples`` it fixes the train/test split ratio.

    Returns:
        list: ``[features_legit_test, features_att_test, xx_orig, yy_orig, Z,
        arp_groundtruth, cfo_groundtruth, ocsvm]``. Its elements are:

        * the two feature arrays are the unscaled legitimate and eavesdropped
          test sets;
        * ``xx_orig``/``yy_orig`` are the evaluation grid mapped back to the
          original (unscaled) feature space;
        * ``Z`` is the OCSVM decision function evaluated on that grid;
        * the ground-truth ARP and CFO belong to user ``observed_k``;
        * ``ocsvm`` is the fitted model.

    Raises:
        ValueError: If either sample count is not positive.

    Side effects:
        Reseeds NumPy's global RNG with 42. Pickles the fitted model to
        ``./saved_models/ocsvm_temp.pkl``, creating the directory if needed.
    """
    if num_of_train_samples <= 0 or num_of_test_samples <= 0:
        raise ValueError(
            "num_of_train_samples and num_of_test_samples must both be positive, "
            f"got {num_of_train_samples} and {num_of_test_samples}"
        )

    K = num_user
    T = 50
    L = L_val
    M = 9
    p_k = 5.0
    p_e = 5.0
    SNR_dB = snr
    f_cfo_max = max_cfo
    num_sample_per_sim = 10  # Number of samples to extract per simulation
    total_samples = num_of_train_samples + num_of_test_samples
    # Presumably each simulation yields num_sample_per_sim samples, so this many
    # simulations (truncated) approximately cover the requested total -- confirm
    # against generate_feature_sets.
    num_of_sim = int(total_samples / num_sample_per_sim)
    # Fraction of samples held out for testing, derived from the requested counts.
    test_ratio = num_of_test_samples / total_samples

    # One-class SVM hyper-parameters.
    nu = 0.05
    gamma = 5.0

    observed_k = 0  # Index of the user whose features are classified.

    np.random.seed(42)

    # Constant CFO per legitimate user, uniform in [-f_cfo_max, f_cfo_max).
    f_k_set_legit = np.random.rand(K) * f_cfo_max * 2 - f_cfo_max

    features_legit, features_legit_att, arp_groundtruth = generate_feature_sets(
        K=K,
        T=T,
        L=L,
        M=M,
        SNR_dB=SNR_dB,
        p_k=p_k,
        p_e=p_e,
        f_k_set_legit=f_k_set_legit,
        f_cfo_max=f_cfo_max,
        num_of_leg_sample=num_of_sim,
        num_of_att_sample=num_of_sim,
        state='ear',
        mode='both',
        window_size=window_size,
        num_sample_per_sim=num_sample_per_sim,
    )

    # Presumably indexed per user (observed_k) -- verify against generate_feature_sets.
    legit_features = features_legit[observed_k]
    attacker_features = features_legit_att[observed_k]

    # Honor the caller's requested train/test proportion.
    features_legit_train, features_legit_test = train_test_split(
        legit_features, test_size=test_ratio, shuffle=True, random_state=42
    )
    # The attacker "training" part is discarded: the OCSVM only sees legit data.
    _, features_att_test = train_test_split(
        attacker_features, test_size=test_ratio, shuffle=True, random_state=42
    )

    # Fit the scaler on legitimate training data only, then apply it everywhere.
    scaler = MinMaxScaler()
    features_legit_train_scaled = scaler.fit_transform(features_legit_train)
    features_legit_test_scaled = scaler.transform(features_legit_test)
    features_att_test_scaled = scaler.transform(features_att_test)

    ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu)
    ocsvm.fit(features_legit_train_scaled)

    # Evaluate: +1 = predicted legitimate (inlier), -1 = predicted attacker (outlier).
    n_legit_test = len(features_legit_test)
    n_att_test = len(features_att_test)
    test_features_scaled = np.vstack((features_legit_test_scaled, features_att_test_scaled))
    test_labels = np.concatenate((np.ones(n_legit_test), -np.ones(n_att_test)))
    predictions = ocsvm.predict(test_features_scaled)

    # These metrics are computed for inspection only and are not part of the return value.
    accuracy = np.mean(predictions == test_labels) if len(test_labels) > 0 else 0.0
    # PMD (missed detection): attacker samples predicted as legit.
    pmd = np.mean(predictions[n_legit_test:] == 1) if n_att_test > 0 else 0.0
    # FAR (false alarm): legit samples predicted as attacker.
    far = np.mean(predictions[:n_legit_test] == -1) if n_legit_test > 0 else 0.0

    # Sample the decision function on a grid defined in SCALED feature space.
    # The region extends well beyond the [0, 1] training range so the boundary is
    # visible around outliers too.
    x_min_scaled, x_max_scaled = -5, 15
    y_min_scaled, y_max_scaled = -15, 15
    h = 0.1
    xx_scaled, yy_scaled = np.meshgrid(
        np.arange(x_min_scaled, x_max_scaled, h),
        np.arange(y_min_scaled, y_max_scaled, h),
    )
    # The model is evaluated on 2 columns, so the features are 2-D; they are
    # presumably (ARP, CFO).
    mesh_scaled = np.c_[xx_scaled.ravel(), yy_scaled.ravel()]
    Z = ocsvm.decision_function(mesh_scaled).reshape(xx_scaled.shape)

    # Map the grid back to the original feature space for plotting.
    mesh_original = scaler.inverse_transform(mesh_scaled)
    xx_orig = mesh_original[:, 0].reshape(xx_scaled.shape)
    yy_orig = mesh_original[:, 1].reshape(yy_scaled.shape)

    model_path = "./saved_models/ocsvm_temp.pkl"
    # Create the output directory so the dump does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, 'wb') as f:
        pickle.dump(ocsvm, f)

    return [features_legit_test, features_att_test, np.ascontiguousarray(xx_orig),
            np.ascontiguousarray(yy_orig), np.ascontiguousarray(Z),
            arp_groundtruth[observed_k], f_k_set_legit[observed_k],
            ocsvm]

if __name__ == '__main__':
    # Example single-user run. main() takes
    # (L_val, num_user, window_size, max_cfo, snr,
    #  num_of_train_samples, num_of_test_samples).
    # NOTE(review): max_cfo, snr and the sample counts below are example values --
    # confirm they match the intended experiment.
    results = main(
        L_val=150,
        num_user=1,
        window_size=10,
        max_cfo=0.003,
        snr=20,
        num_of_train_samples=700,
        num_of_test_samples=300,
    )