"""
This script computes various mean metrics (IoU, Precision, Recall, F1 Score, Chamfer Distance)
between feature points extracted from .npy files and ground-truth points selected from .pcd
point clouds using the per-point annotations stored in .json files. It also provides
visualization of the point clouds and of the errors (false positives and false negatives).
"""

import numpy as np
import open3d as o3d
import json
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree

def load_feature_points(npy_file):
    """Load an array of feature points from a ``.npy`` file.

    Args:
        npy_file: Path (or file-like object) accepted by ``np.load``.

    Returns:
        The array stored in the file. The rest of the script presumably
        expects an (N, 3) array of XYZ coordinates — confirm against the
        producer of these files.
    """
    points = np.load(npy_file)
    return points

def load_ground_truth_pcd(point_cloud, json_file):
    """Return the annotated (ground-truth) points of a point cloud.

    The cloud at ``point_cloud`` is read with Open3D, and the JSON file at
    ``json_file`` is read for the annotations. The per-point labels are looked
    up at ``data["attributes"]["point_annotations"]``. Every point whose label
    equals 1 is kept. If either key is missing, the mask is empty and no
    points are returned.

    Args:
        point_cloud: Path to the point-cloud file (e.g. ``.pcd``).
        json_file: Path to the JSON annotation file.

    Returns:
        ``np.ndarray`` holding the selected rows of the cloud's points.
    """
    cloud = o3d.io.read_point_cloud(point_cloud)
    with open(json_file, 'r') as f:
        labels = json.load(f)

    mask = np.array(labels.get("attributes", {}).get("point_annotations", []))
    # Positions of the points labelled 1 in the annotation mask.
    keep = np.where(mask == 1)[0]

    all_points = np.asarray(cloud.points)
    return all_points[keep]

def compute_iou(feature_points, ground_truth_points):
    """Intersection-over-union of two point sets, using exact coordinate matches.

    Each point is turned into a tuple, so two points count as the same only
    if every coordinate is exactly equal. Duplicate points collapse to one.

    Returns:
        ``|F & G| / |F | G|`` as a float, or 0.0 when both sets are empty.
    """
    predicted = {tuple(p) for p in feature_points}
    actual = {tuple(p) for p in ground_truth_points}

    union_size = len(predicted | actual)
    if not union_size:
        return 0.0
    return len(predicted & actual) / union_size

def compute_precision(feature_points, ground_truth_points):
    """Fraction of predicted points that also appear in the ground truth.

    Points are compared as exact coordinate tuples; duplicates collapse.

    Returns:
        ``TP / (TP + FP)`` as a float, or 0.0 when there are no predicted points.
    """
    predicted = {tuple(p) for p in feature_points}
    actual = {tuple(p) for p in ground_truth_points}

    tp = len(predicted & actual)
    fp = len(predicted - actual)
    if tp + fp <= 0:
        return 0.0
    return tp / (tp + fp)

def compute_recall(feature_points, ground_truth_points):
    """Fraction of ground-truth points that were also predicted.

    Points are compared as exact coordinate tuples; duplicates collapse.

    Returns:
        ``TP / (TP + FN)`` as a float, or 0.0 when there are no ground-truth points.
    """
    predicted = {tuple(p) for p in feature_points}
    actual = {tuple(p) for p in ground_truth_points}

    tp = len(actual & predicted)
    fn = len(actual - predicted)
    if tp + fn <= 0:
        return 0.0
    return tp / (tp + fn)

def compute_f1_score(feature_points, ground_truth_points):
    """Harmonic mean of ``compute_precision`` and ``compute_recall``.

    Returns:
        ``2 * P * R / (P + R)`` as a float, or 0.0 when precision and recall
        are both zero.
    """
    p = compute_precision(feature_points, ground_truth_points)
    r = compute_recall(feature_points, ground_truth_points)

    denominator = p + r
    if denominator <= 0:
        return 0.0
    return 2 * (p * r) / denominator

def chamfer_distance(points_a, points_b):
    """Symmetric Chamfer distance between two point sets (squared-distance form).

    For every point in one set, a KD-tree finds the Euclidean distance to its
    nearest neighbour in the other set. The result is
    ``mean(d_a_to_b ** 2) + mean(d_b_to_a ** 2)``.

    Args:
        points_a: Array-like of shape (N, D).
        points_b: Array-like of shape (M, D), with the same D as ``points_a``.

    Returns:
        The Chamfer distance as a float. Returns ``nan`` when either set is
        empty, because nearest neighbours are undefined; without this check
        NumPy warns about a mean of an empty slice, or SciPy fails, depending
        on the version.
    """
    points_a = np.asarray(points_a, dtype=float)
    points_b = np.asarray(points_b, dtype=float)
    if points_a.size == 0 or points_b.size == 0:
        return float("nan")

    # Distance from each point of A to its nearest neighbour in B, and vice versa.
    dist_a, _ = cKDTree(points_b).query(points_a)
    dist_b, _ = cKDTree(points_a).query(points_b)

    return float(np.mean(dist_a ** 2) + np.mean(dist_b ** 2))

def plot_point_clouds(feature_points, ground_truth_points):
    """Draw feature points (red) and ground-truth points (blue) in one 3D scatter plot.

    The Z axis is fixed to the range [-10, 10]. The figure is displayed with
    ``plt.show()`` and closed afterwards.

    Args:
        feature_points: 2-D array; columns 0, 1, 2 are plotted as X, Y, Z.
        ground_truth_points: 2-D array; columns 0, 1, 2 are plotted as X, Y, Z.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    layers = (
        (feature_points, 'r', 'Feature Points'),
        (ground_truth_points, 'b', 'Ground Truth Points'),
    )
    for pts, colour, name in layers:
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                   c=colour, marker='o', s=5, label=name)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Point Cloud Visualization')
    ax.set_zlim(-10, 10)
    ax.legend()

    plt.show()
    plt.close()

def plot_point_clouds_with_errors(feature_points, ground_truth_points):
    """Plot a 2x2 top-down (X/Y) grid comparing feature points with ground truth.

    Panels, in row-major order:
      * ground truth,
      * feature points,
      * false positives (feature points absent from the ground truth),
      * false negatives (ground-truth points absent from the feature points).

    Points match only on exact coordinate equality (tuple comparison). The
    figure is displayed with ``plt.show()`` and closed afterwards.

    Args:
        feature_points: 2-D array of predicted points; columns 0 and 1 are plotted.
        ground_truth_points: 2-D array of ground-truth points; columns 0 and 1
            are plotted.
    """
    feature_points = np.asarray(feature_points)
    ground_truth_points = np.asarray(ground_truth_points)
    feature_set = set(map(tuple, feature_points))
    gt_set = set(map(tuple, ground_truth_points))

    # A boolean mask keeps the (0, D) shape even when there are no errors,
    # so the ``[:, 0]`` column indexing below stays valid in that case.
    fp_mask = np.array([tuple(p) not in gt_set for p in feature_points], dtype=bool)
    fn_mask = np.array([tuple(p) not in feature_set for p in ground_truth_points], dtype=bool)
    false_positives = feature_points[fp_mask]
    false_negatives = ground_truth_points[fn_mask]

    panels = (
        (ground_truth_points, 'b', 'Ground Truth'),
        (feature_points, 'r', 'Feature Points'),
        (false_positives, 'g', 'False Positives'),
        (false_negatives, 'orange', 'False Negatives'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    for ax, (points, colour, title) in zip(axes.flat, panels):
        ax.scatter(points[:, 0], points[:, 1], c=colour, s=2, label=title)
        ax.set_title(title)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')

    plt.tight_layout()
    plt.show()
    plt.close(fig)

def _filter_points(points, max_range=10 * (2 ** 0.5), min_range=2.0, min_y=-5.0):
    """Keep only the points that lie inside the evaluation region.

    A row is kept when all of the following hold:
      * its Euclidean norm over all columns is <= ``max_range`` (inside the costmap);
      * its Y coordinate is >= ``min_y`` (the forest side);
      * its norm is > ``min_range`` (not too close to the Jackal);
      * it is outside the box -5 < X < 0, -1.5 < Y < 1.5 (points that
        self-intersect the Jackal).

    Args:
        points: Array of shape (N, D) with D >= 2; column 0 is X, column 1 is Y.

    Returns:
        The rows of ``points`` that satisfy every condition, in their original order.
    """
    points = np.asarray(points)
    norms = np.linalg.norm(points, axis=1)
    x, y = points[:, 0], points[:, 1]
    on_jackal = (x > -5) & (x < 0) & (y > -1.5) & (y < 1.5)
    keep = (norms <= max_range) & (y >= min_y) & (norms > min_range) & ~on_jackal
    return points[keep]


def main(num_frames=15, show_plots=False):
    """Evaluate every frame, print its metrics, then print the mean of each metric.

    Args:
        num_frames: Number of frames to evaluate (frame indices 0 .. num_frames - 1).
        show_plots: If True, show the 3D and error plots for each frame.
    """
    ious = []
    precisions = []
    recalls = []
    f1s = []
    chamfers = []

    for index in range(num_frames):
        # NOTE(review): feature files are numbered from 01 while snapshots and
        # labels start at 0 -- presumably the same frame; confirm the offset.
        npy_file = f"feature_points/feature_points_{(index + 1):02d}.npy"
        json_file = f"annotated_snapshots/labels{index}.json"
        point_cloud = f"snapshots/snapshot_{index:02d}.pcd"

        feature_points = load_feature_points(npy_file)
        ground_truth_points = load_ground_truth_pcd(point_cloud, json_file)

        # Apply the same region filter to both sets so they are compared like-for-like.
        feature_points = _filter_points(feature_points)
        ground_truth_points = _filter_points(ground_truth_points)

        if show_plots:
            plot_point_clouds(feature_points, ground_truth_points)
            plot_point_clouds_with_errors(feature_points, ground_truth_points)

        iou_score = compute_iou(feature_points, ground_truth_points)
        precision_score = compute_precision(feature_points, ground_truth_points)
        recall_score = compute_recall(feature_points, ground_truth_points)
        f1_score = compute_f1_score(feature_points, ground_truth_points)
        chamfer_score = chamfer_distance(feature_points, ground_truth_points)

        print(f"IoU Score: {iou_score}")
        print(f"Precision: {precision_score}")
        print(f"Recall: {recall_score}")
        print(f"F1 Score: {f1_score}")
        print(f"Chamfer Score: {chamfer_score}")
        print("========================")

        ious.append(iou_score)
        precisions.append(precision_score)
        recalls.append(recall_score)
        f1s.append(f1_score)
        chamfers.append(chamfer_score)

    print("Means:")
    print(f"IoU Mean: {np.mean(ious)}")
    print(f"Precision Mean: {np.mean(precisions)}")
    print(f"Recall Mean: {np.mean(recalls)}")
    print(f"F1s Mean: {np.mean(f1s)}")
    # The Chamfer score can be nan for a frame with an empty point set; skip
    # such frames rather than letting one of them make the whole mean nan.
    print(f"Chamfer Mean: {np.nanmean(chamfers)}")


if __name__ == "__main__":
    main()
