import json
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import joblib

# Base figure size (width, height) in inches, passed as matplotlib `figsize`;
# the multi-panel plots scale the height down from this base.
FIG_WIDTH, FIG_HEIGHT = 10, 7

def calculate_hold_difficulty_scores(problems_json):
    """Score each hold by the average numeric grade of the problems using it.

    Parameters
    ----------
    problems_json : dict
        Mapping of problem id -> problem record. Each record's ``'Grade'``
        (default ``'6A'``) is mapped to a number (unknown grades count as 1),
        and every move with a non-empty ``'Description'`` contributes that
        number to the hold named by the description.

    Returns
    -------
    dict
        Hold position -> mean numeric grade over all moves that used it, in
        order of first appearance.
    """
    # Simplified grade ladder: 6A -> 1 ... 8C+ -> 18.
    grade_to_numeric = {
        '6A': 1, '6A+': 2, '6B': 3, '6B+': 4, '6C': 5, '6C+': 6,
        '7A': 7, '7A+': 8, '7B': 9, '7B+': 10, '7C': 11, '7C+': 12,
        '8A': 13, '8A+': 14, '8B': 15, '8B+': 16, '8C': 17, '8C+': 18
    }

    grade_sums = {}
    use_counts = {}
    for problem in problems_json.values():
        numeric_grade = grade_to_numeric.get(problem.get('Grade', '6A'), 1)
        for move in problem.get('Moves', []):
            position = move.get('Description', '')
            if not position:
                continue
            grade_sums[position] = grade_sums.get(position, 0) + numeric_grade
            use_counts[position] = use_counts.get(position, 0) + 1

    # Every recorded position has been used at least once, so no zero division.
    return {position: grade_sums[position] / use_counts[position]
            for position in grade_sums}

def extract_features(problem, hold_difficulty_scores=None):
    """Build a flat feature dict describing one problem's holds.

    Each move's ``'Description'`` (e.g. ``'A18'``) is parsed as a leading
    column letter (``'A'`` -> 0) followed by a row number, which is stored
    as ``18 - row``.

    Parameters
    ----------
    problem : dict
        A problem record; its ``'Moves'`` list (missing -> no holds) supplies
        the hold positions.
    hold_difficulty_scores : dict or None
        Optional mapping of hold position -> difficulty score. When non-empty,
        summary statistics of the problem's hold scores are added; positions
        absent from the mapping count as 1.

    Returns
    -------
    dict
        ``num_holds``, ``avg_x``, ``avg_y``, ``std_x``, ``std_y`` and, when
        scores are supplied, ``avg_hold_difficulty``, ``max_hold_difficulty``,
        ``min_hold_difficulty``, ``std_hold_difficulty`` and
        ``difficulty_range``. A problem with no holds gets 0 for every
        statistic, so every problem yields the same set of keys.
    """
    holds = problem.get('Moves', [])
    positions = [hold['Description'] for hold in holds]
    # Parse each position once and reuse the coordinates for mean and std.
    xs = [ord(pos[0]) - ord('A') for pos in positions]
    ys = [18 - int(pos[1:]) for pos in positions]

    # Basic spatial features
    features = {
        'num_holds': len(holds),
        'avg_x': np.mean(xs) if holds else 0,
        'avg_y': np.mean(ys) if holds else 0,
        'std_x': np.std(xs) if holds else 0,
        'std_y': np.std(ys) if holds else 0,
    }

    # Add hold difficulty features if available
    if hold_difficulty_scores:
        if holds:
            difficulties = [hold_difficulty_scores.get(pos, 1) for pos in positions]
            highest = np.max(difficulties)
            lowest = np.min(difficulties)
            features.update({
                'avg_hold_difficulty': np.mean(difficulties),
                'max_hold_difficulty': highest,
                'min_hold_difficulty': lowest,
                'std_hold_difficulty': np.std(difficulties),
                'difficulty_range': highest - lowest,
            })
        else:
            # Emit the same keys for hold-less problems. Otherwise a DataFrame
            # built from these dicts would hold NaN in these columns for such rows.
            features.update(dict.fromkeys(
                ('avg_hold_difficulty', 'max_hold_difficulty', 'min_hold_difficulty',
                 'std_hold_difficulty', 'difficulty_range'),
                0,
            ))

    return features

def load_data(json_path):
    """Load problems from a JSON file and build the feature matrix and labels.

    Parameters
    ----------
    json_path : str or path-like
        Path to a JSON object mapping problem id -> problem record.

    Returns
    -------
    tuple of (pandas.DataFrame, list of str)
        One feature row per problem (built by ``extract_features``) and the
        matching ``'Grade'`` labels (``'Unknown'`` when missing). Both follow
        the order of the problems in the JSON object.

    NOTE: the hold difficulty scores are averaged over the grades of *every*
    problem loaded here, so each row's features carry information from its
    own label. A model evaluated on a split of this output will look
    optimistic. Derive the scores from training problems only when measuring
    generalisation.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        problems = json.load(f)

    # Calculate hold difficulty scores
    print("Calculating hold difficulty scores...")
    hold_difficulty_scores = calculate_hold_difficulty_scores(problems)
    print(f"Calculated difficulty scores for {len(hold_difficulty_scores)} unique holds")

    records = list(problems.values())
    X = pd.DataFrame([extract_features(record, hold_difficulty_scores) for record in records])
    y = [record.get('Grade', 'Unknown') for record in records]
    return X, y

def plot_feature_importance(X, model):
    """Display a horizontal bar chart of the model's per-feature importances.

    Bars run from least important (bottom) to most important (top), and each
    bar is annotated with its importance to three decimals. ``X.columns``
    supplies the feature names, in the same order as
    ``model.feature_importances_``.
    """
    ranked = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=True)
    slots = range(len(ranked))

    # Reset to matplotlib defaults, then apply this chart's font settings.
    # These rcParams assignments are global and persist after the call.
    plt.style.use('default')
    plt.rcParams['font.size'] = 9
    plt.rcParams['font.family'] = 'DejaVu Sans'

    fig, ax = plt.subplots(figsize=(FIG_WIDTH, FIG_HEIGHT))

    bar_container = ax.barh(
        slots,
        ranked['importance'],
        color='steelblue',
        alpha=0.8,
        edgecolor='midnightblue',
        linewidth=0.5,
    )

    ax.set_yticks(slots)
    ax.set_yticklabels(ranked['feature'], fontsize=9, fontweight='normal')
    ax.set_xlabel('Feature Importance', fontsize=10, fontweight='bold')
    ax.set_title(
        'Climbing Problem Difficulty - Feature Importance',
        fontsize=12,
        fontweight='bold',
        pad=20,
    )

    # Dashed vertical grid drawn underneath the bars.
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Numeric label just to the right of each bar's tip.
    for rect in bar_container:
        length = rect.get_width()
        mid_y = rect.get_y() + rect.get_height() / 2
        ax.text(length + 0.002, mid_y, f'{length:.3f}',
                ha='left', va='center', fontweight='bold', fontsize=8)

    plt.tight_layout(pad=2.0)
    plt.show()

def plot_confusion_matrix(y_test, y_pred, le):
    """Show raw-count and row-normalised confusion-matrix heatmaps side by side.

    Parameters
    ----------
    y_test, y_pred : array-like of int
        True and predicted labels encoded by ``le`` (indices into
        ``le.classes_``).
    le : fitted LabelEncoder
        Its ``classes_`` provide the axis tick labels and fix the matrix to
        one row/column per class.

    The right-hand heatmap divides each row by its total. For each actual
    grade it shows the fraction of problems predicted as each grade. Rows
    with no test samples are shown as zeros.
    """
    class_labels = le.classes_
    # Pin the label set to every encoded class. Without it the matrix only
    # covers classes present in y_test/y_pred, and the tick labels taken from
    # le.classes_ would no longer line up with (or match the count of) cells.
    cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(class_labels)))

    # Row-normalise without dividing by zero for classes absent from y_test.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)

    # Set style
    plt.style.use('default')
    plt.rcParams['font.size'] = 7
    plt.rcParams['font.family'] = 'DejaVu Sans'

    # Two panels across the full figure width, at 70% of the base height.
    fig_width, fig_height = FIG_WIDTH, FIG_HEIGHT
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(fig_width, fig_height * 0.7))

    # Create heatmap with raw counts (left subplot)
    sns.heatmap(cm, annot=True, fmt='d', cmap='YlOrRd',
                xticklabels=class_labels, yticklabels=class_labels,
                ax=ax1, cbar_kws={'label': 'Number of Predictions'})

    ax1.set_title('Confusion Matrix - Raw Counts',
                  fontsize=11, fontweight='bold', pad=20)
    ax1.set_xlabel('Predicted Grade', fontsize=9, fontweight='bold')
    ax1.set_ylabel('Actual Grade', fontsize=9, fontweight='bold')

    # Rotate labels for better readability
    ax1.set_xticklabels(ax1.get_xticklabels(), rotation=45, ha='right')
    ax1.set_yticklabels(ax1.get_yticklabels(), rotation=0)

    # Create normalized heatmap (right subplot)
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='RdYlBu_r',
                xticklabels=class_labels, yticklabels=class_labels,
                ax=ax2, cbar_kws={'label': 'Normalized Accuracy'})

    ax2.set_title('Confusion Matrix - Normalized',
                  fontsize=11, fontweight='bold', pad=20)
    ax2.set_xlabel('Predicted Grade', fontsize=9, fontweight='bold')
    ax2.set_ylabel('Actual Grade', fontsize=9, fontweight='bold')

    # Rotate labels for better readability
    ax2.set_xticklabels(ax2.get_xticklabels(), rotation=45, ha='right')
    ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0)

    # Add overall title
    fig.suptitle('Climbing Problem Difficulty Prediction - Confusion Matrix Analysis',
                 fontsize=13, fontweight='bold', y=0.95)

    # Adjust layout with more padding
    plt.tight_layout(pad=3.0)

    plt.show()

def analyze_feature_categories(X, model):
    """Split the model's feature importance into three feature groups.

    Groups are hold-difficulty statistics, spatial statistics
    (``avg_x``/``avg_y``/``std_x``/``std_y``), and ``num_holds``. Features
    outside these groups still count toward the total.

    Returns
    -------
    tuple
        ``(hold_difficulty_percent, spatial_percent, num_holds_percent,
        feature_importance)``. Each percentage is that group's share of the
        summed importance times 100. ``feature_importance`` is an unsorted
        DataFrame with ``feature`` and ``importance`` columns, in ``X.columns``
        order.
    """
    importance_table = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_
    })
    grand_total = importance_table['importance'].sum()

    def group_share(member_names):
        # Percentage of the total importance held by the named features.
        in_group = importance_table['feature'].isin(member_names)
        return (importance_table.loc[in_group, 'importance'].sum() / grand_total) * 100

    difficulty_share = group_share(['avg_hold_difficulty', 'max_hold_difficulty',
                                    'min_hold_difficulty', 'std_hold_difficulty',
                                    'difficulty_range'])
    spatial_share = group_share(['std_x', 'std_y', 'avg_x', 'avg_y'])
    count_share = group_share(['num_holds'])
    return difficulty_share, spatial_share, count_share, importance_table

def plot_feature_comparison(hold_difficulty_percent, spatial_percent, num_holds_percent):
    """Display a pie chart and a bar chart of the three feature-group shares.

    The three arguments are percentages (as returned by
    ``analyze_feature_categories``). They are drawn in the order hold
    difficulty, spatial, number of holds, with one fixed colour per group.
    """
    group_names = ['Hold Difficulty\nFeatures', 'Spatial Distance\nFeatures', 'Number of Holds']
    shares = [hold_difficulty_percent, spatial_percent, num_holds_percent]
    palette = ['salmon', 'turquoise', 'gold']

    # Global style reset plus this chart's font settings.
    plt.style.use('default')
    plt.rcParams['font.size'] = 9
    plt.rcParams['font.family'] = 'DejaVu Sans'
    fig, (pie_ax, bar_ax) = plt.subplots(1, 2, figsize=(FIG_WIDTH, FIG_HEIGHT * 0.5))

    # Left panel: pie with every wedge pulled slightly out.
    _, _, pct_texts = pie_ax.pie(
        shares,
        labels=group_names,
        colors=palette,
        autopct='%1.1f%%',
        startangle=90,
        explode=(0.05, 0.05, 0.05),
        textprops={'fontsize': 9, 'fontweight': 'bold'},
    )
    pie_ax.set_title('Feature Importance: Hold Difficulty, Spatial, Number of Holds',
                     fontsize=11, fontweight='bold', pad=20)
    for pct_text in pct_texts:
        pct_text.set_color('white')
        pct_text.set_fontweight('bold')
        pct_text.set_fontsize(10)

    # Right panel: same shares as bars, with 20% headroom for the value labels.
    columns = bar_ax.bar(group_names, shares, color=palette, alpha=0.8,
                         edgecolor='black', linewidth=1)
    bar_ax.set_ylabel('Importance (%)', fontsize=10, fontweight='bold')
    bar_ax.set_title('Feature Category Comparison', fontsize=11, fontweight='bold', pad=20)
    tallest = max(shares)
    bar_ax.set_ylim(0, tallest * 1.2)
    bar_ax.grid(axis='y', alpha=0.3, linestyle='--')
    bar_ax.set_axisbelow(True)

    for column, share in zip(columns, shares):
        centre_x = column.get_x() + column.get_width() / 2.
        label_y = column.get_height() + tallest * 0.02
        bar_ax.text(centre_x, label_y, f'{share:.1f}%', ha='center', va='bottom',
                    fontweight='bold', fontsize=10)

    plt.tight_layout(pad=3.0)
    plt.show()

def print_detailed_results(X, y, model, le, y_test, y_pred, X_train, y_train, X_test):
    """Print a text summary of the trained model and dataset to stdout.

    Parameters
    ----------
    X : pandas.DataFrame
        Full feature matrix (used for its row count and column names).
    y : sequence of str
        Grade label for every problem.
    model : fitted classifier
        Must expose ``score`` and ``feature_importances_``.
    le : fitted LabelEncoder
        Maps encoded labels back to grade names via ``classes_``.
    y_test, y_pred : array-like of int
        Encoded true and predicted labels for the test split.
    X_train, y_train, X_test : training features/labels and test features.
    """
    print("=" * 60)
    print(" CLIMBING PROBLEM DIFFICULTY PREDICTION RESULTS ")
    print("=" * 60)

    # Model Performance
    print("\n MODEL PERFORMANCE:")
    print(f"   Training Accuracy: {model.score(X_train, y_train):.1%}")
    # Score against the true test labels. Scoring against y_pred compares the
    # model with its own predictions and always reports 100%.
    print(f"   Testing Accuracy:  {model.score(X_test, y_test):.1%}")

    # Dataset Info
    print("\n DATASET INFORMATION:")
    print(f"   Total Problems:    {len(X)}")
    print(f"   Training Samples:  {len(X_train)}")
    print(f"   Testing Samples:   {len(y_test)}")
    print(f"   Unique Grades:     {len(le.classes_)}")

    # Grade Distribution: only the first 10 grades in sorted label order are listed.
    grade_counts = pd.Series(y).value_counts().sort_index()
    print("\n GRADE DISTRIBUTION:")
    for grade, count in grade_counts.head(10).items():
        percentage = (count / len(y)) * 100
        print(f"   {grade}: {count:4d} problems ({percentage:5.1f}%)")

    # Feature Importance
    print("\n TOP FEATURES FOR PREDICTING DIFFICULTY:")
    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=False)

    for i, (_, row) in enumerate(feature_importance.head(8).iterrows(), 1):
        print(f"   {i:2d}. {row['feature']:<20} {row['importance']:.3f}")

    # Classification Report. Pass every encoded class explicitly so the report
    # lines up with le.classes_ even when some grades are absent from the test
    # split. Otherwise target_names has more entries than the labels found and
    # the call raises. Classes with no predictions report 0 instead of warning.
    print("\n DETAILED CLASSIFICATION REPORT:")
    print(classification_report(y_test, y_pred,
                                labels=np.arange(len(le.classes_)),
                                target_names=le.classes_,
                                zero_division=0))

def main():
    """Train and evaluate the grade classifier, print results, plot, and save.

    Hold difficulty scores are averages of problem grades, so they are
    computed from the training problems only. Computing them over all
    problems (as ``load_data`` does) would feed each test problem's own grade
    into its features and inflate the reported test accuracy.
    """
    json_path = 'moonboard_problems_setup_2016.json'

    print(" Loading climbing problems data...")
    with open(json_path, 'r', encoding='utf-8') as f:
        problems = json.load(f)
    problem_ids = list(problems)
    y = [problems[pid].get('Grade', 'Unknown') for pid in problem_ids]

    print(" Training machine learning model...")
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)
    # Split row positions first so the label-derived hold scores can be
    # fitted on the training problems alone.
    train_idx, test_idx = train_test_split(
        np.arange(len(problem_ids)), test_size=0.2, random_state=42)

    print("Calculating hold difficulty scores from training problems...")
    train_problems = {problem_ids[i]: problems[problem_ids[i]] for i in train_idx}
    hold_difficulty_scores = calculate_hold_difficulty_scores(train_problems)
    print(f"Calculated difficulty scores for {len(hold_difficulty_scores)} unique holds")

    # Holds never used by a training problem get extract_features' default score.
    X = pd.DataFrame([extract_features(problems[pid], hold_difficulty_scores)
                      for pid in problem_ids])
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y_encoded[train_idx], y_encoded[test_idx]

    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Make predictions
    y_pred = model.predict(X_test)

    # Print detailed results
    print_detailed_results(X, y, model, le, y_test, y_pred, X_train, y_train, X_test)

    # Show feature comparison (pie/bar chart) as the first popup
    hold_difficulty_percent, spatial_percent, num_holds_percent, _ = analyze_feature_categories(X, model)
    plot_feature_comparison(hold_difficulty_percent, spatial_percent, num_holds_percent)

    # Show feature importance plot
    plot_feature_importance(X, model)

    # Show confusion matrix plot
    plot_confusion_matrix(y_test, y_pred, le)

    # Save results for later viewing
    joblib.dump({
        'X': X,
        'y': y,
        'model': model,
        'le': le,
        'X_train': X_train,
        'X_test': X_test,
        'y_train': y_train,
        'y_test': y_test,
        'y_pred': y_pred
    }, 'ml_results.joblib')
    print("\n Analysis complete! Results saved to ml_results.joblib. Check the plots for visual insights.")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()