import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import os
import importlib
import pickle
from lib import generate_features
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from matplotlib.lines import Line2D
from sklearn.model_selection import ParameterGrid  # Added for grid search
importlib.reload(generate_features)
generate_feature_sets = generate_features.generate_feature_sets
# Register the bundled Noto Sans SC font so Chinese plot labels render correctly.
# The font is only needed for plotting (the plotting code in main() is currently
# disabled), so a missing font file is reported instead of aborting the import.
font_path = os.path.join(os.path.dirname(__file__), 'lib', 'Noto_Sans_SC', 'static', 'NotoSansSC-Regular.ttf')
if os.path.isfile(font_path):
    fm.fontManager.addfont(font_path)
    plt.rcParams['font.family'] = 'Noto Sans SC'
else:
    print(f"Warning: font file not found, keeping matplotlib's default font: {font_path}")

def main(L_val, num_user, window_size, max_cfo, snr, num_of_train_samples, num_of_test_samples):
    K = num_user
    T = 50
    L = L_val
    M = 9
    p_k = 5.0
    p_e = 5.0
    SNR_dB = snr
    num_sample_per_sim = 10  # Number of samples to extract per simulation
    num_of_sim = int( (num_of_train_samples + num_of_test_samples)
                     /num_sample_per_sim)
    test_ratio = num_of_test_samples/(num_of_train_samples + num_of_test_samples)

    window_size = window_size  # Sliding window size for feature calculation
    T_total = num_sample_per_sim - 1 + window_size

    
    f_cfo_max = max_cfo

    # Original nu and gamma (for fallback)
    nu = 0.05
    gamma = 5.0

    observed_k = 0 # Eaves-dropper always use K = 0 as its pilot sequence

    np.random.seed(42)

    # Generate constant CFOs for all samples
    f_k_set_legit = np.random.rand(K) * f_cfo_max * 2 - f_cfo_max
    # f_k_set_legit = f_k_set_legit*0 + 0.05
    """"""

    # Generate features for legit and auth attackers
    features_legit, features_legit_att_auth, arp_groundtruth = generate_feature_sets(
        K=K,
        T=T,
        L=L,
        M=M,
        SNR_dB=SNR_dB,
        p_k=p_k,
        p_e=p_e,
        f_k_set_legit=f_k_set_legit,
        f_cfo_max=f_cfo_max,
        num_of_leg_sample=num_of_sim,
        num_of_att_sample=num_of_sim,
        state='auth',
        mode='both', 
        window_size=window_size,
        num_sample_per_sim=num_sample_per_sim
    )

    # Generate features for ear attackers (reuse features_legit, but regenerate for consistency in randomness)
    _, features_legit_att_ear, _ = generate_feature_sets(
        K=K,
        T=T,
        L=L,
        M=M,
        SNR_dB=SNR_dB,
        p_k=p_k,
        p_e=p_e,
        f_k_set_legit=f_k_set_legit,
        f_cfo_max=f_cfo_max,
        num_of_leg_sample=num_of_sim,
        num_of_att_sample=num_of_sim,
        state='ear',
        mode='both', 
        window_size=window_size,
        num_sample_per_sim=num_sample_per_sim
    )

    legit_features = features_legit[observed_k]
    features_att_full_auth = features_legit_att_auth[observed_k]
    features_att_full_ear = features_legit_att_ear[observed_k]

    features_legit_train, features_legit_test = train_test_split(
        legit_features, test_size=test_ratio, shuffle=True, random_state=42
    )
    _, features_att_test_auth = train_test_split(
        features_att_full_auth, test_size=test_ratio, shuffle=True, random_state=42
    )
    _, features_att_test_ear = train_test_split(
        features_att_full_ear, test_size=test_ratio, shuffle=True, random_state=42
    )

    # Mix att_test with 50:0 auth:ear ratio (all auth)
    total_test_att = len(features_att_test_auth)
    ratio_auth = 1 - K / int(L/2) # int(L/2) is number of possible indices.
                                # Only when attacker chooses Vk of a legitimate user,
                                # an attack occurs. This probability is  1 / int(L/2)
    """"""

    n_auth = int(total_test_att * ratio_auth)
    n_ear = total_test_att - n_auth

    indices_auth = np.random.choice(len(features_att_test_auth), n_auth, replace=False)
    indices_ear = np.random.choice(len(features_att_test_ear), n_ear, replace=False)

    mixed_att_test = np.vstack((
        features_att_test_auth[indices_auth],
        features_att_test_ear[indices_ear]
    ))

    # For visualization: treat auth as legit (blue), ear as red
    auth_test_selected = features_att_test_auth[indices_auth]
    ear_test_selected = features_att_test_ear[indices_ear]
    legit_plus_auth_test = np.vstack((features_legit_test, auth_test_selected))

    scaler = MinMaxScaler()
    features_legit_train_scaled = scaler.fit_transform(features_legit_train)
    features_legit_test_scaled = scaler.transform(features_legit_test)
    mixed_att_test_scaled = scaler.transform(mixed_att_test)

    # --- Hyperparameter tuning with grid search ---
    param_grid = {
        'nu': [0.01, 0.05, 0.1, 0.2],  # Common range for nu (outlier fraction)
        'gamma': ['scale', 0.1, 1.0, 5.0, 10.0]  # RBF kernel width; 'scale' is 1/(n_features * X.var())
    }
    
    grid = ParameterGrid(param_grid)
    best_score = float('inf')
    best_params = {'nu': nu, 'gamma': gamma}  # Fallback to original
    mdrs = []
    fars = []
    scores = []
    param_combos = []
    
    # Weights for MDR and FAR (more weight to MDR)
    mdr_weight = 1.0
    far_weight = 1.2
    
    n_legit_test = len(features_legit_test)
    n_legit_plus_auth = n_legit_test + n_auth
    
    for params in grid:
        ocsvm_temp = OneClassSVM(kernel='rbf', gamma=params['gamma'], nu=params['nu'])
        ocsvm_temp.fit(features_legit_train_scaled)
        
        test_features_scaled = np.vstack((features_legit_test_scaled, mixed_att_test_scaled))
        test_labels = np.concatenate((
            np.ones(n_legit_test),
            np.ones(n_auth),
            np.full(n_ear, -1)
        ))
        predictions = ocsvm_temp.predict(test_features_scaled)
        
        mdr = np.mean(predictions[n_legit_plus_auth:] == 1) if n_ear > 0 else 0.0
        far = np.mean(predictions[:n_legit_plus_auth] == -1) if n_legit_plus_auth > 0 else 0.0
        score = mdr_weight * mdr + far_weight * far
        
        mdrs.append(mdr)
        fars.append(far)
        scores.append(score)
        param_combos.append(params)
        
        if score < best_score:
            best_score = score
            best_params = params
    
    # Use best parameters
    nu = best_params['nu']
    gamma = best_params['gamma']
    
    print(f"Best parameters: nu={nu}, gamma={gamma}")
    print(f"Best score ({mdr_weight}*MDR + {far_weight}*FAR): {best_score:.4f}")
    print("All tested:")
    for i, (mdr_val, far_val, score_val, params) in enumerate(zip(mdrs, fars, scores, param_combos)):
        print(f"  nu={params['nu']}, gamma={params['gamma']}: MDR={mdr_val:.4f}, FAR={far_val:.4f}, score={score_val:.4f}")

    # Train final model with best params
    ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu)
    ocsvm.fit(features_legit_train_scaled)

    test_features_scaled = np.vstack((features_legit_test_scaled, mixed_att_test_scaled))
    n_legit_test = len(features_legit_test)
    test_labels = np.concatenate((
        np.ones(n_legit_test),
        np.ones(n_auth),
        np.full(n_ear, -1)
    ))
    predictions = ocsvm.predict(test_features_scaled)

    accuracy = np.mean(predictions == test_labels) if len(test_labels) > 0 else 0.0

    # Separate metrics
    n_legit_plus_auth = n_legit_test + n_auth
    far = np.mean(predictions[:n_legit_plus_auth] == -1) if n_legit_plus_auth > 0 else 0.0
    mdr = np.mean(predictions[n_legit_plus_auth:] == 1) if n_ear > 0 else 0.0
    """
    # Plot in original feature space
    plt.scatter(legit_plus_auth_test[:, 0], legit_plus_auth_test[:, 1],
                color='blue', alpha=0.1, s=int(eta*8), label='合法用户')
    plt.scatter(ear_test_selected[:, 0], ear_test_selected[:, 1],
                color='red', alpha=0.1, s=int(eta*8), label='合法用户+攻击者')
    """
    # --- Plot the SVM decision boundary with new method ---
    # 1. Define the scaled region for visualization
    x_min_scaled, x_max_scaled = -0.1, 1.1
    y_min_scaled, y_max_scaled = -0.1, 1.1

    # 2. Create meshgrid in SCALED space
    h = 0.02
    xx_scaled, yy_scaled = np.meshgrid(
        np.arange(x_min_scaled, x_max_scaled, h),
        np.arange(y_min_scaled, y_max_scaled, h)
    )

    # 3. Compute decision function on SCALED meshgrid
    mesh_scaled = np.c_[xx_scaled.ravel(), yy_scaled.ravel()]
    Z = ocsvm.decision_function(mesh_scaled)
    Z = Z.reshape(xx_scaled.shape)

    # 4. Convert meshgrid to ORIGINAL space
    mesh_original = scaler.inverse_transform(mesh_scaled)
    xx_orig = mesh_original[:, 0].reshape(xx_scaled.shape)
    yy_orig = mesh_original[:, 1].reshape(xx_scaled.shape)
    """
    # 6. Plot the decision boundary in original feature space
    plt.contour(xx_orig, yy_orig, Z, levels=[0], linestyles='-', colors='black', linewidths=1)

    # Plot ground truth CFO and ARP as a star at (ARP, CFO) and create opaque legend proxies
    gt_label = f'真实值 (ARP={arp_groundtruth[observed_k]:.2f}, CFO={f_k_set_legit[observed_k]:.4f})'
    plt.scatter(float(arp_groundtruth[observed_k]), float(f_k_set_legit[observed_k]),
                color='orange', marker='*', s=100,
                label=gt_label, zorder=5)

    # Create opaque proxy artists for legend (scatter points keep alpha=0.1 on the plot)
    blue_proxy = Line2D([0], [0], marker='o', color='blue', linestyle='', markersize=6,
                        markeredgecolor='none', alpha=1, label='合法用户')
    red_proxy = Line2D([0], [0], marker='o', color='red', linestyle='', markersize=6,
                       markeredgecolor='none', alpha=1, label='合法用户+攻击者')
    decision_proxy = Line2D([0], [0], linestyle='-', color='black', linewidth=2, label='决策边界')
    star_proxy = Line2D([0], [0], marker='*', color='orange', linestyle='', markersize=10, alpha=1, label=gt_label)

    legend_elements = [red_proxy, blue_proxy, decision_proxy, star_proxy]

    plt.xlabel('ARP特征', fontsize=12)
    plt.ylabel('CFO特征', fontsize=12)
    title = (f'认证下二维特征分布\n'
             f'K={K}, T={window_size}, L={L}, M={M}, p_k={p_k}, p_e={p_e} \n'
             f'SNR={SNR_dB}dB, , CFO_max={f_cfo_max:.6f}\n'
             f'Num Attacker={num_of_sim}, Samples Per Attacker={num_sample_per_sim}, Samples={num_of_sim * num_sample_per_sim}')
    title += f'\nAcc={accuracy:.2%}, MDR={mdr:.2%}, FAR={far:.2%}'
    # title += f'\nBest nu={nu}, gamma={gamma}'
    plt.title(title, fontsize=int(eta*15))
    plt.grid(True, alpha=0.3)
    
    # Use opaque proxy handles for legend so markers appear opaque there while plotted points remain transparent
    plt.legend(handles=legend_elements, fontsize=10, prop={'family': 'Noto Sans SC'})
     
    plt.xlim(0, 1500)
    plt.ylim(-0.003, 0.003)
    
    plt.tight_layout()
    plt.savefig('singleuser_2Dfeature_usblib_decision_auth_both.png', dpi=300)
    """

    with open("./saved_models/ocsvm_temp.pkl", 'wb') as f:
        pickle.dump(ocsvm, f)
    return [legit_plus_auth_test, ear_test_selected, np.ascontiguousarray(xx_orig),
             np.ascontiguousarray(yy_orig), np.ascontiguousarray(Z),
             arp_groundtruth[observed_k], f_k_set_legit[observed_k]]
    

if __name__ == '__main__':
    L_val = 150
    num_user = 1
    window_size = 10
    # main() also requires max_cfo and snr. The previous call omitted them, which
    # raised a TypeError (and would have shifted num_train_samples into max_cfo).
    # NOTE(review): placeholder values — set these to the intended experiment settings.
    max_cfo = 0.002
    snr = 20
    num_train_samples = 1000
    num_test_samples = 600
    # Keyword arguments guard against positional-order mistakes in the long signature.
    main(
        L_val=L_val,
        num_user=num_user,
        window_size=window_size,
        max_cfo=max_cfo,
        snr=snr,
        num_of_train_samples=num_train_samples,
        num_of_test_samples=num_test_samples,
    )