"""
This script subscribes to a ROS2 PointCloud2 topic and saves incoming point clouds as PCD 
files along with corresponding feature points from a Marker topic.
"""

import rclpy
from rclpy.node import Node
from sensor_msgs.msg import PointCloud2
from visualization_msgs.msg import Marker
import numpy as np
import open3d as o3d
import sensor_msgs_py.point_cloud2 as pc2

class PointCloudSaver(Node):
    """Save PointCloud2 snapshots as PCD files and Marker feature points as text/npy.

    Each saved point cloud is written to ``snapshot_NN.pcd``. Feature points
    received on the Marker topic are written to ``feature_points_NN.txt`` and
    ``feature_points_NN.npy``, where ``NN`` is the index of the next point
    cloud snapshot to be saved. Once ``max_snapshots`` point clouds have been
    saved, rclpy is shut down so that ``rclpy.spin`` returns.

    Args:
        max_snapshots: Number of point clouds to save before shutting down.
        cloud_topic: PointCloud2 topic to record.
        feature_topic: Marker topic whose ``points`` are recorded.
    """

    def __init__(self, max_snapshots=100, cloud_topic='/ouster/points',
                 feature_topic='/feature_points'):
        super().__init__('pointcloud_saver')
        # Separate attributes so neither subscription handle is overwritten.
        self.cloud_subscription = self.create_subscription(
            PointCloud2, cloud_topic, self.callback, 10)
        self.feature_subscription = self.create_subscription(
            Marker, feature_topic, self.callback1, 10)
        self.counter = 0  # number of point-cloud snapshots saved so far
        self.max_snapshots = max_snapshots

    def callback1(self, msg):
        """Save the feature points carried by a Marker message.

        The points are written to ``feature_points_NN.txt`` (header ``x y z``)
        and ``feature_points_NN.npy``, where ``NN`` is the current snapshot
        counter. A second message arriving before the next snapshot therefore
        overwrites the files of the first.
        """
        if self.counter >= self.max_snapshots:
            # All snapshots are done; don't write feature files with no cloud.
            return

        # reshape keeps the array 2-D (0, 3) even when the marker has no points.
        feature_points = np.array(
            [[p.x, p.y, p.z] for p in msg.points], dtype=np.float64
        ).reshape(-1, 3)

        if feature_points.size == 0:
            self.get_logger().warn("Received empty feature points!")
            return

        stem = f"feature_points_{self.counter:02d}"
        # Human-readable copy.
        np.savetxt(f"{stem}.txt", feature_points, fmt="%.6f", header="x y z")
        # Binary copy for fast reloading with np.load.
        np.save(f"{stem}.npy", feature_points)

        self.get_logger().info(f"Saved {len(feature_points)} feature points.")

    def callback(self, msg):
        """Save one PointCloud2 message as ``snapshot_NN.pcd``.

        Points containing NaNs are dropped. Empty clouds and failed writes are
        logged and do not advance the counter. After the last snapshot is
        saved, rclpy is shut down.
        """
        if self.counter >= self.max_snapshots:
            # Already finished; ignore messages delivered before spin exits.
            return

        points = self._cloud_to_xyz(msg)
        if points.size == 0:
            self.get_logger().warn("Empty point cloud received!")
            return

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        filename = f"snapshot_{self.counter:02d}.pcd"
        if not o3d.io.write_point_cloud(filename, pcd):
            self.get_logger().error(f"Failed to write: {filename}")
            return
        self.get_logger().info(f"Saved: {filename}")

        self.counter += 1
        if self.counter >= self.max_snapshots:
            self.get_logger().info("Captured all snapshots. Exiting.")
            # Shut down right after the last save instead of waiting for one
            # more message; guard so a dead context is never shut down twice.
            if rclpy.ok():
                rclpy.shutdown()

    @staticmethod
    def _cloud_to_xyz(msg):
        """Return the non-NaN x/y/z coordinates of ``msg`` as an (N, 3) float64 array."""
        # NOTE(review): assumes read_points returns a numpy structured array
        # (sensor_msgs_py in ROS 2 Humble and later) -- confirm for older distros.
        cloud = pc2.read_points(msg, field_names=("x", "y", "z"), skip_nans=True)
        # Open3D's Vector3dVector expects float64 (N, 3) data.
        return np.column_stack(
            (cloud["x"], cloud["y"], cloud["z"])
        ).astype(np.float64)

def main():
    """Run the PointCloudSaver node until it finishes or is interrupted."""
    rclpy.init()
    node = PointCloudSaver()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop early; exit without a traceback.
        pass
    finally:
        node.destroy_node()
        # The node's callback may already have shut rclpy down after the
        # last snapshot; shutting down an inactive context raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
