#!/usr/bin/env python3
# Copyright (c) Noitom Robotics. All rights reserved.
"""
Example: visualize hands_keypoint_3d.json as a single-panel video
showing RGB frames overlaid with the 2D hand skeleton.

This is the skeleton-only variant (no depth / point cloud rendering).

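Single file, no pose_align dependency. Requires: numpy, opencv-python, and the
`ffmpeg` executable on PATH (used to encode the output video; when it is
missing, the rendered PNG frames are kept instead).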

Expected directory layout for --folder:
  camera_params/head_param.json   - camera intrinsics (RGB + distortion)
  hands_keypoint_3d.json          - 3D keypoints + MANO parameters
  rgb_head.mp4 + rgb_head.csv     - source RGB video and per-frame timestamps
  rgb_undistorted/*.png           - optional; if present, used directly
                                    (otherwise the script streams + undistorts
                                     rgb_head.mp4 in memory, no files written
                                     to the source directory)

Usage:
    python example_kp_vis.py --folder /path/to/video_dir
    python example_kp_vis.py --folder /path/to/video_dir --output_dir /path/to/vis
"""
import os
import json
import csv
import argparse
import subprocess
import shutil
from pathlib import Path
from types import SimpleNamespace
from typing import List, Tuple, Optional, Iterator

import numpy as np
import cv2

# ====== Constants ======
# Joint order used by every index-based table in this file: index 0 is the
# wrist, followed by four joints per finger from thumb to pinky. The names are
# also the keys looked up in each hand's `keypoints_3d_cam_m` dict.
JOINT_NAMES = (
    ['wrist']
    + [f'thumb_{joint}' for joint in ('cmc', 'mcp', 'ip', 'tip')]
    + [
        f'{finger}_{joint}'
        for finger in ('index', 'middle', 'ring', 'pinky')
        for joint in ('mcp', 'pip', 'dip', 'tip')
    ]
)

# Bones as (start, end) index pairs into JOINT_NAMES: for each finger a
# wrist-to-base link plus the three links along the finger, then the three
# links joining neighbouring finger bases (index-middle-ring-pinky).
HAND_CONNECTIONS = [
    bone
    for base in (1, 5, 9, 13, 17)
    for bone in ((0, base), (base, base + 1), (base + 1, base + 2), (base + 2, base + 3))
] + [(5, 9), (9, 13), (13, 17)]

# Frame rate of the encoded video; also used to synthesize a timestamp for
# video frames that have no row in rgb_head.csv.
FPS = 30

# Drop the last N seconds (matches `exclude_reason="tail"` in the JSON).
TAIL_EXCLUDE_SEC = 0.0

# Note on left-hand handling (see schema doc, Usage Notes #5):
#   The pipeline uses only MANO_RIGHT.pkl. For a left hand, after running the
#   right-hand MANO forward pass, the vertices and joints must be mirrored
#   along the X-axis (verts[:, 0] *= -1, joints[:, 0] *= -1) BEFORE being
#   compared with `keypoints_3d_cam_m`.
#
#   This visualization script does NOT need to do that flip itself, because
#   `keypoints_3d_cam_m` is already the post-flip, camera-frame result. We
#   simply project those 3D points to 2D and draw them. The `is_right` flag
#   is used only for picking colors and the "R"/"L" label.

# Per-part drawing colours. The 3-tuples are handed unchanged to the cv2
# drawing calls. 'wrist' colours the wrist joint dot, 'label' the "R"/"L" text,
# and 'palm' the links that join two different fingers.
RIGHT_HAND_COLORS = {
    'thumb': (0, 200, 255),
    'index': (0, 100, 255),
    'middle': (0, 50, 220),
    'ring': (50, 0, 200),
    'pinky': (80, 0, 180),
    'palm': (0, 140, 255),
    'wrist': (0, 255, 255),
    'label': (0, 140, 255),
}
LEFT_HAND_COLORS = {
    'thumb': (200, 200, 0),
    'index': (200, 150, 0),
    'middle': (255, 100, 0),
    'ring': (255, 50, 50),
    'pinky': (200, 0, 100),
    'palm': (220, 180, 0),
    'wrist': (255, 255, 0),
    'label': (220, 180, 0),
}


def get_bone_color(start_idx, end_idx, is_right):
    """Return the drawing colour for the bone `start_idx -> end_idx`.

    Joint indices follow the 21-joint layout: 0 is the wrist, 1-4 thumb,
    5-8 index, 9-12 middle, 13-16 ring, 17-20 pinky.

    Args:
        start_idx: index of the bone's first joint.
        end_idx: index of the bone's second joint.
        is_right: True selects RIGHT_HAND_COLORS, otherwise LEFT_HAND_COLORS.

    Returns:
        The colour tuple of the finger the bone belongs to, or the 'palm'
        colour when the bone joins two different fingers. Calling with
        `start_idx == 0` yields the colour of the finger that owns `end_idx`.
    """
    colors = RIGHT_HAND_COLORS if is_right else LEFT_HAND_COLORS
    # A link between two non-wrist joints of different fingers (the
    # knuckle-to-knuckle links) is a palm bone. This must be decided before
    # the range checks below, which classify purely on `end_idx` and would
    # otherwise colour such a link like the finger it ends on.
    if start_idx > 0 and end_idx > 0 and (start_idx - 1) // 4 != (end_idx - 1) // 4:
        return colors['palm']
    if start_idx <= 4 and end_idx <= 4:
        return colors['thumb']
    elif 5 <= end_idx <= 8 or (start_idx == 0 and end_idx == 5):
        return colors['index']
    elif 9 <= end_idx <= 12 or (start_idx == 0 and end_idx == 9):
        return colors['middle']
    elif 13 <= end_idx <= 16 or (start_idx == 0 and end_idx == 13):
        return colors['ring']
    elif 17 <= end_idx <= 20 or (start_idx == 0 and end_idx == 17):
        return colors['pinky']
    return colors['palm']


# ====== Camera ======
def load_camera(folder: str) -> SimpleNamespace:
    """Read the RGB intrinsics stored in `<folder>/camera_params/head_param.json`.

    The values are taken from the `rgb_camera.intrinsic` object of that file.

    Returns:
        A SimpleNamespace carrying:
          fx, fy, cx, cy   - focal lengths and principal point (`ppx`/`ppy`)
          dist             - (5,) coefficients ordered k1, k2, p1, p2, k3;
                             coefficients absent from the file default to 0
          orig_K           - (3, 3) intrinsic matrix built from the raw values
          undistorted_K    - (3, 3) matrix to use with undistorted images; a
                             copy of orig_K when no undistortion is needed
          need_undistort   - True when any |coefficient| exceeds 0.1
          resolution       - (W, H); defaults to (1280, 960) if not in the file

    Raises:
        FileNotFoundError: the parameter file does not exist.
    """
    param_file = Path(folder) / "camera_params" / "head_param.json"
    if not param_file.exists():
        raise FileNotFoundError(f"Missing {param_file}")
    with open(param_file) as fh:
        content = json.load(fh)
    intr = content["rgb_camera"]["intrinsic"]

    fx = float(intr["fx"])
    fy = float(intr["fy"])
    cx = float(intr["ppx"])
    cy = float(intr["ppy"])

    coeff_keys = ("k1", "k2", "p1", "p2", "k3")
    dist = np.array([float(intr.get(key, 0)) for key in coeff_keys], dtype=np.float64)

    orig_K = np.eye(3, dtype=np.float64)
    orig_K[0, 0], orig_K[0, 2] = fx, cx
    orig_K[1, 1], orig_K[1, 2] = fy, cy

    width = int(intr.get("width", 1280))
    height = int(intr.get("height", 960))
    resolution = (width, height)

    # Use the same threshold as the pipeline (extract_mesh_keypoints, depth_fusion, etc.)
    need_undistort = bool((np.abs(dist) > 0.1).any())

    undistorted_K = orig_K.copy()
    if need_undistort:
        # alpha = 0 (fourth argument), same image size in and out.
        undistorted_K, _ = cv2.getOptimalNewCameraMatrix(orig_K, dist, resolution, 0, resolution)

    return SimpleNamespace(
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        dist=dist,
        orig_K=orig_K,
        undistorted_K=undistorted_K,
        need_undistort=need_undistort,
        resolution=resolution,
    )


def project_to_2d(pts_3d: np.ndarray, camera: SimpleNamespace) -> np.ndarray:
    """Pinhole-project camera-frame points to pixel coordinates.

    Args:
        pts_3d: (N, 3) points, or a single (3,) point which is treated as (1, 3).
        camera: object exposing the 3x3 matrix `undistorted_K`.

    Returns:
        (N, 2) array of pixel coordinates in the undistorted image. The
        homogeneous depth is floored at 1e-8 before dividing, so points with
        zero or negative depth do not raise but do not give meaningful pixels.
    """
    points = np.asarray(pts_3d, dtype=np.float64)
    if points.ndim == 1:
        points = points[np.newaxis, :]
    homogeneous = np.matmul(camera.undistorted_K, points.T).T
    xy = homogeneous[:, :2]
    depth = np.clip(homogeneous[:, 2:3], 1e-8, None)
    return xy / depth


# ====== RGB stream (in-memory, no file writes) ======
def _read_csv_timestamps(csv_path: Path) -> dict:
    """Map frame_index -> timestamp_str (matches the naming used in rgb_undistorted/).

    For each row the timestamp string is taken from the `filename` column
    (basename without ".png") when that basename parses as a float; otherwise
    it is `timestamp_s` formatted with six decimals.

    Tolerates CSV files whose header contains stray whitespace
    (e.g. " timestamp_s") or that start with a UTF-8 byte-order mark.

    Args:
        csv_path: path of the per-frame CSV (needs a `frame_index` column and
            either a usable `filename` or a `timestamp_s` column).

    Returns:
        dict mapping int frame index to its timestamp string.

    Raises:
        KeyError: a required column is missing.
        ValueError: `frame_index` or the fallback `timestamp_s` is not numeric.
    """
    out = {}
    # utf-8-sig drops a leading BOM, which would otherwise become part of the
    # first header name ("\ufeffframe_index") and survive str.strip().
    # newline="" lets the csv module do its own line-ending handling.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for raw in reader:
            # Key None holds surplus cells of over-long rows; ignore those.
            row = {k.strip(): (v.strip() if isinstance(v, str) else v) for k, v in raw.items() if k is not None}
            idx = int(row["frame_index"])
            fn = row.get("filename", "")
            if fn:
                # Accept both "/" and "\" separators and keep only the basename.
                ts_str = fn.replace("\\", "/").split("/")[-1].replace(".png", "")
                try:
                    float(ts_str)  # validate
                    out[idx] = ts_str
                    continue
                except ValueError:
                    pass
            out[idx] = f"{float(row['timestamp_s']):.6f}"
    return out


def stream_undistorted_rgb(folder: Path, camera: SimpleNamespace) -> Iterator[Tuple[str, np.ndarray]]:
    """Generate `(timestamp_str, undistorted_bgr)` pairs for one recording.

    Source selection:
      1. `rgb_undistorted/*.png`, when that directory holds at least one PNG.
         Files are visited in numeric order of their stem, the stem is the
         timestamp string, and unreadable images are skipped.
      2. Otherwise `rgb_head.mp4` decoded frame by frame, with timestamps from
         `rgb_head.csv` (frames without a CSV row get `index / FPS`). Frames
         are remapped in memory when `camera.need_undistort` is set.
         Nothing is written under `folder`.

    Side effect: if the video size differs from `camera.resolution` and
    undistortion is needed, `camera.undistorted_K` and `camera.resolution`
    are replaced to match the actual video size.

    Because this is a generator, the exceptions below surface on the first
    `next()`, not when the function is called.

    Raises:
        FileNotFoundError: neither source is available.
        RuntimeError: the video cannot be opened.
    """
    png_dir = folder / "rgb_undistorted"
    if png_dir.is_dir():
        png_files = sorted(png_dir.glob("*.png"), key=lambda path: float(path.stem))
        if png_files:
            for png in png_files:
                image = cv2.imread(str(png))
                if image is None:
                    continue
                yield png.stem, image
            return

    video_path = folder / "rgb_head.mp4"
    csv_path = folder / "rgb_head.csv"
    if not (video_path.is_file() and csv_path.is_file()):
        raise FileNotFoundError(
            f"Neither rgb_undistorted/ nor (rgb_head.mp4 + rgb_head.csv) found in {folder}"
        )

    ts_map = _read_csv_timestamps(csv_path)

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    remap_tables = None
    if camera.need_undistort:
        if size != camera.resolution:
            # Resolution mismatch — rebuild the undistorted_K for the actual video size.
            rebuilt_K, _ = cv2.getOptimalNewCameraMatrix(camera.orig_K, camera.dist, size, 0, size)
            camera.undistorted_K = rebuilt_K
            camera.resolution = size
        remap_tables = cv2.initUndistortRectifyMap(
            camera.orig_K, camera.dist, None, camera.undistorted_K, size, cv2.CV_32FC1,
        )

    try:
        frame_no = 0
        ok, frame = cap.read()
        while ok:
            # JSON frame keys and csv timestamp_s share the same :.6f format.
            ts_str = ts_map.get(frame_no, f"{frame_no / float(FPS):.6f}")
            if remap_tables is not None:
                map_x, map_y = remap_tables[0], remap_tables[1]
                frame = cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR)
            yield ts_str, frame
            frame_no += 1
            ok, frame = cap.read()
    finally:
        cap.release()


# ====== 2D skeleton rendering ======
def draw_skeleton_2d(img, kp_2d, is_right, is_low_conf=False):
    """Draw one hand (bones, joint dots and an "R"/"L" tag) onto `img` in place.

    Args:
        img: image array to draw on; modified in place.
        kp_2d: (21, 2) pixel coordinates ordered like JOINT_NAMES.
        is_right: selects the right- or left-hand palette and the tag letter.
        is_low_conf: when True, bones are dashed 1 px lines, joints are
            hollow circles of radius 3 and the tag gets a "?" suffix;
            otherwise bones are solid 2 px lines and joints filled, radius 4.

    Returns:
        The same `img` object.

    A bone is drawn only if both of its endpoints fall inside the image; a
    joint dot or the tag only if the joint / wrist itself does.
    """
    img_h, img_w = img.shape[:2]
    palette = RIGHT_HAND_COLORS if is_right else LEFT_HAND_COLORS
    if is_low_conf:
        line_px, dot_radius, dot_fill = 1, 3, 1
    else:
        line_px, dot_radius, dot_fill = 2, 4, -1
    kp_2d = np.asarray(kp_2d)

    def on_canvas(x, y):
        return 0 <= x < img_w and 0 <= y < img_h

    # Bones
    for start, end in HAND_CONNECTIONS:
        a = kp_2d[start].astype(int)
        b = kp_2d[end].astype(int)
        if not (on_canvas(a[0], a[1]) and on_canvas(b[0], b[1])):
            continue
        bone_color = get_bone_color(start, end, is_right)
        if not is_low_conf:
            cv2.line(img, tuple(a), tuple(b), bone_color, line_px, cv2.LINE_AA)
            continue
        length = np.linalg.norm(b.astype(float) - a.astype(float))
        if not length > 1:
            continue
        # Split the bone into roughly 6 px slots and paint every other one.
        slots = max(int(length / 6), 1)
        span = b - a
        for slot in range(0, slots, 2):
            frac_from = slot / slots
            frac_to = min((slot + 1) / slots, 1.0)
            dash_from = (a + frac_from * span).astype(int)
            dash_to = (a + frac_to * span).astype(int)
            cv2.line(img, tuple(dash_from), tuple(dash_to), bone_color, 1, cv2.LINE_AA)

    # Joint dots
    for joint_idx, (u, v) in enumerate(kp_2d):
        px, py = int(u), int(v)
        if not on_canvas(px, py):
            continue
        if joint_idx > 0:
            dot_color = get_bone_color(0, joint_idx, is_right)
        else:
            dot_color = palette["wrist"]
        cv2.circle(img, (px, py), dot_radius, dot_color, dot_fill)

    # Hand tag next to the wrist: black outline first, coloured text on top.
    wrist_x, wrist_y = int(kp_2d[0, 0]), int(kp_2d[0, 1])
    if on_canvas(wrist_x, wrist_y):
        if is_right:
            tag = "R?" if is_low_conf else "R"
        else:
            tag = "L?" if is_low_conf else "L"
        anchor = (wrist_x + 10, max(20, wrist_y - 20))
        cv2.putText(img, tag, anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(img, tag, anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.8, palette["label"], 2, cv2.LINE_AA)
    return img


# ====== Main pipeline ======
def _render_frame(img: np.ndarray, ts_str: str, frame_info: dict, camera: SimpleNamespace) -> np.ndarray:
    """Draw all hands of one frame plus the text overlays; return the image cropped to even size."""
    height = img.shape[0]
    hands = frame_info.get("hands", [])

    for hand_info in hands:
        kp_dict = hand_info.get("keypoints_3d_cam_m", {})
        if not kp_dict:
            continue
        is_right = hand_info.get("is_right", False)
        # Treat both "low" and None as low-confidence (drawn dashed).
        is_low = hand_info.get("confidence") != "high"
        # `keypoints_3d_cam_m` is already in camera frame and already accounts
        # for the right-model-mirrored-for-left-hand convention. No additional
        # X-axis flip is needed for visualization.
        kp_3d = np.array([kp_dict.get(n, [0, 0, 0]) for n in JOINT_NAMES], dtype=np.float64)
        kp_2d = project_to_2d(kp_3d, camera)
        # Joints missing from the JSON (filled with the origin above), joints at
        # or behind the camera plane, and non-finite projections have no valid
        # pixel. Left alone, a missing joint projects to pixel (0, 0) and is
        # drawn in the image corner. Moving them to (-1, -1) puts them outside
        # the image, where the drawing routine's bounds checks skip them.
        missing = np.array([n not in kp_dict for n in JOINT_NAMES])
        unusable = missing | (kp_3d[:, 2] <= 0) | ~np.isfinite(kp_2d).all(axis=1)
        kp_2d[unusable] = -1.0
        draw_skeleton_2d(img, kp_2d, is_right, is_low_conf=is_low)

    cv2.putText(img, "RGB + 2D Skeleton", (10, 22),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (50, 255, 50), 2)
    n_low = sum(1 for h in hands if h.get("confidence") != "high")
    if n_low > 0:
        cv2.putText(img, f"Non-high conf: {n_low} hand(s)", (10, 44),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, (100, 100, 255), 1, cv2.LINE_AA)
    cv2.putText(img, f"ts: {float(ts_str):.6f}", (10, height - 15),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)

    # yuv420p encoding needs even width and height; drop one column/row if odd.
    if img.shape[1] % 2 == 1:
        img = img[:, :-1]
    if img.shape[0] % 2 == 1:
        img = img[:-1, :]
    return img


def _encode_video(frames_dir: Path, written: List[Path], out_video: str) -> bool:
    """Encode the PNGs in `written` to `out_video` with ffmpeg; return True on success."""
    list_file = frames_dir / "_list.txt"
    ordered = sorted(written, key=lambda p: float(p.stem))
    with open(list_file, "w") as f:
        for p in ordered:
            f.write(f"file '{p.name}'\n")
            f.write(f"duration {1.0 / FPS}\n")
        # The last entry is listed once more after its duration line
        # (ffmpeg concat-list convention for image sequences).
        f.write(f"file '{ordered[-1].name}'\n")

    try:
        # ffmpeg runs inside frames_dir so the bare file names in the list
        # resolve; the list is therefore passed by name and `out_video` must be
        # an absolute path.
        subprocess.run([
            "ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file.name,
            "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2", "-c:v", "libx264", "-crf", "18",
            "-pix_fmt", "yuv420p", "-r", str(FPS), out_video
        ], capture_output=True, timeout=600, cwd=str(frames_dir), check=True)
    except FileNotFoundError:
        print(f"  WARNING: ffmpeg not found, frames saved in {frames_dir}")
        return False
    except subprocess.CalledProcessError as e:
        print(f"  ffmpeg failed: {(e.stderr or b'').decode(errors='replace')[:300]}")
        return False
    except subprocess.TimeoutExpired:
        print(f"  ffmpeg timed out, frames saved in {frames_dir}")
        return False
    print(f"  Video: {out_video} ({len(written)} frames)")
    return True


def process_video(folder: str, output_dir: str) -> None:
    """Render `<folder>/hands_keypoint_3d.json` over the RGB frames as an MP4.

    Each RGB frame whose timestamp is a key of the JSON `frames` object is
    overlaid with the projected hand skeletons and saved as a PNG in
    `<output_dir>/_frames_<name>/`. The PNGs are then encoded with ffmpeg to
    `<output_dir>/<name>_skeleton.mp4`, where `<name>` is the last component of
    `folder`. The temporary frame directory is removed after a successful
    encode and kept when ffmpeg is missing, fails or times out.

    Problems with the inputs (missing JSON, camera file, or RGB source; video
    that cannot be opened) are reported on stdout and end the function early.

    Args:
        folder: recording directory (see the module docstring for its layout).
        output_dir: directory receiving the video; created if necessary.
    """
    folder_p = Path(folder)
    name = folder_p.name
    json_path = folder_p / "hands_keypoint_3d.json"
    if not json_path.is_file():
        print(f"  Missing {json_path}")
        return
    with open(json_path) as f:
        data = json.load(f)

    try:
        camera = load_camera(folder)
    except FileNotFoundError as e:
        print(f"  {e}")
        return

    frames_data = data.get("frames", {})
    sorted_ts_all = sorted(frames_data, key=float)
    # Drop the last TAIL_EXCLUDE_SEC seconds (same effect as filtering by `excluded == true`).
    if sorted_ts_all and TAIL_EXCLUDE_SEC > 0:
        max_ts = float(sorted_ts_all[-1])
        kept_ts = {t for t in sorted_ts_all if float(t) <= max_ts - TAIL_EXCLUDE_SEC}
    else:
        kept_ts = set(sorted_ts_all)

    # Absolute paths: ffmpeg is later run with the frame directory as its
    # working directory, so paths relative to the caller's cwd would break.
    out_dir = Path(os.path.abspath(output_dir))
    frames_dir = out_dir / f"_frames_{name}"
    frames_dir.mkdir(parents=True, exist_ok=True)

    written: List[Path] = []
    try:
        # The RGB source is a generator: its "source missing" / "cannot open"
        # errors are raised on the first iteration, so the loop itself has to
        # sit inside the try block.
        for ts_str, img in stream_undistorted_rgb(folder_p, camera):
            if ts_str not in kept_ts:
                continue
            img = _render_frame(img, ts_str, frames_data.get(ts_str, {}), camera)
            out_path = frames_dir / f"{ts_str}.png"
            if not cv2.imwrite(str(out_path), img):
                # Skip frames that were not written so the ffmpeg list only
                # names files that exist.
                print(f"  WARNING: could not write {out_path}")
                continue
            written.append(out_path)
    except (FileNotFoundError, RuntimeError) as e:
        print(f"  {e}")
        shutil.rmtree(frames_dir, ignore_errors=True)
        return

    if not written:
        print(f"  No frames rendered for {name}")
        shutil.rmtree(frames_dir, ignore_errors=True)
        return

    out_video = str(out_dir / f"{name}_skeleton.mp4")
    if _encode_video(frames_dir, written, out_video):
        shutil.rmtree(frames_dir, ignore_errors=True)


if __name__ == "__main__":
    # Command-line entry point: render one recording folder to a skeleton video.
    cli = argparse.ArgumentParser(description="Visualize hands_keypoint_3d.json")
    cli.add_argument(
        "--folder",
        required=True,
        help="Path to a single video directory containing hands_keypoint_3d.json",
    )
    cli.add_argument(
        "--output_dir",
        default=None,
        help="Output directory (default: <folder>/../kp_vis_output)",
    )
    opts = cli.parse_args()

    # Without an explicit --output_dir, write next to the input folder.
    out_dir = opts.output_dir
    if out_dir is None:
        out_dir = str(Path(opts.folder).parent / "kp_vis_output")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Processing {opts.folder}...")
    process_video(opts.folder, out_dir)
    print("Done!")
