#!/usr/bin/env python3
"""
rename-ai-snaps — Scan PNGs for AI prompt metadata and propose descriptive filenames.

Usage:
  ./rename-ai-snaps [path] [--no-interactive]

Scans directory (default ~/Pictures) for ComfyUI PNGs, reads their embedded prompt,
extracts short keyword descriptions, and interactively renames them to:
  schmeeve-AI-{keywords}.png
"""

import json
import os
import re
import sys
import time
from pathlib import Path

try:
    from PIL import Image
except ImportError:
    print("Error: Pillow (PIL) is required. Install with: pip install Pillow")
    sys.exit(1)

# ── stopwords and filter sets ──────────────────────────────────────────────

# Prompt tags that describe rendering quality rather than image content.
# Consulted by is_quality_only() and extract_keywords().
QUALITY_TAGS = {
    "score_6_up", "score_7_up", "score_8_up", "score_9",
    "score_6", "score_7", "score_8",
    "masterpiece", "best quality", "good quality", "normal quality",
    "high quality", "highly detailed", "very detailed", "extreme detail",
    "very_aesthetic", "absurdres", "8k", "4k",
    "photorealistic", "photograph",
    "depth of field", "solo focus", "cinematic",
    "newest", "amazing", "stunning",
}

# Single intensifier words; extract_keywords() never emits them as keywords.
QUALITY_WORDS = {
    "best", "good", "high", "top", "ultra", "super",
    "mega", "hyper", "extreme", "extra", "ultimate",
}

# Single words about technique/rendering; extract_keywords() never emits them.
TECHNICAL_WORDS = {
    "detailed", "focus", "quality", "aesthetic", "realistic",
    "cinematic", "lighting", "rendering", "shading", "texture",
    "newest", "absurdres",
}

# Common English function words (plus a few generic verbs and numbers) that
# extract_keywords() drops from the proposed filename.
STOP_WORDS = {
    "the", "a", "an", "of", "in", "on", "at", "to", "for", "with",
    "and", "or", "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did", "will", "would",
    "could", "should", "may", "might", "can", "shall", "this",
    "that", "these", "those", "it", "its", "by", "from", "as",
    "into", "through", "during", "before", "after", "above", "below",
    "between", "out", "off", "over", "under", "again", "further",
    "then", "once", "here", "there", "when", "where", "why", "how",
    "all", "each", "every", "both", "few", "more", "most", "other",
    "some", "such", "no", "nor", "not", "only", "own", "same", "so",
    "than", "too", "very", "just", "about", "up", "down",
    "make", "get", "set", "put", "take", "give", "show", "use",
    "like", "look", "see", "want", "need", "let", "close", "full",
    "add", "new", "one", "two", "five",
    "also", "well", "back", "still", "even", "much",
    "you", "your", "my", "me", "we", "our", "they", "them", "their",
}

# Subject/appearance words that extract_keywords() drops from the proposed
# filename alongside STOP_WORDS.
GENERIC_WORDS = {
    "man", "men", "guy", "guys", "boy", "boys", "woman", "women",
    "girl", "girls", "people", "person", "human", "figure",
    "photo", "image", "picture", "shot", "view", "pose", "posing",
    "face", "head", "body", "skin", "hair", "eyes", "hand", "hands",
    "dark", "light", "bright", "color", "colour",
}

# Phrases typical of negative prompts. is_negative_text() looks for these as
# substrings of the lowercased text; extract_keywords() also skips any word
# that equals one of the entries.
NEGATIVE_INDICATORS = {
    "deformed", "distorted", "disfigured", "poorly drawn", "bad anatomy",
    "extra digits", "missing digits", "extra limbs", "missing limbs",
    "ugly", "tiling", "low quality", "worst quality", "normal quality",
    "lowres", "monochrome", "grayscale", "text", "watermark",
    "branding", "border", "cropped", "signature", "username",
    "error", "mutation", "mutated", "out of frame", "duplicate", "cloned",
    "body out of frame", "bad hands", "bad face", "blurry",
}



def spinner():
    """Yield braille spinner frames forever, restarting after the last one."""
    frames = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
    while True:
        # One full revolution per pass of the outer loop.
        yield from frames


def is_negative_text(text):
    """Guess whether *text* is a negative prompt.

    The text counts as negative when at least two entries of
    NEGATIVE_INDICATORS occur in it as case-insensitive substrings.
    """
    haystack = text.lower()
    hits = sum(1 for indicator in NEGATIVE_INDICATORS if indicator in haystack)
    return hits >= 2


def is_quality_only(text):
    """Heuristic: does this text block contain only quality/technical tags?

    Multi-word entries of QUALITY_TAGS ("best quality", "depth of field", ...)
    are removed as whole phrases before the rest is split into tokens, because
    they can never match once the text is broken into single words.  Tokens
    keep their digits so tags such as "score_9" stay intact.

    A remaining token is considered meaningful unless it is at most two
    characters long, purely numeric (e.g. the pieces of a ":1.2" weight),
    starts with "score", or is listed in QUALITY_TAGS, QUALITY_WORDS or
    TECHNICAL_WORDS.

    Returns True when no meaningful token is left, False otherwise.  Text
    that contains no letters or underscores at all yields False.
    """
    lower = text.lower()
    if not re.search(r"[a-z_]", lower):
        return False

    # Strip phrase tags first; longest first keeps the result deterministic
    # regardless of set iteration order.
    remainder = lower
    phrases = sorted((t for t in QUALITY_TAGS if " " in t), key=len, reverse=True)
    for phrase in phrases:
        remainder = remainder.replace(phrase, " ")

    for token in re.findall(r"[a-z0-9_]+", remainder):
        if len(token) <= 2 or token.isdigit():
            continue
        if token in QUALITY_TAGS or token.startswith("score"):
            continue
        bare = token.strip("_")
        if bare in QUALITY_TAGS or bare in QUALITY_WORDS or bare in TECHNICAL_WORDS:
            continue
        # Found a word that says something about the image content.
        return False
    return True


def extract_prompts(filepath):
    """Return (positive_prompt, negative_prompt) from a ComfyUI PNG.

    Reads the "prompt" entry of the image's metadata, parses it as JSON and
    collects every string longer than three characters found under the
    text-carrying input fields of any node.  Each collected string is then
    classified with is_negative_text(); positives that is_quality_only()
    flags are discarded.  The longest string of each class is returned.

    Args:
        filepath: Path of the image file to inspect.

    Returns:
        A ``(positive, negative)`` tuple.  Either element is None when no
        matching text was found; both are None when the file cannot be
        opened, carries no "prompt" metadata, or the metadata is not a JSON
        object.
    """
    text_fields = ("text", "prompt", "positive", "negative", "value", "string")

    try:
        # The context manager releases the file handle right away; without it
        # a scan of a large directory keeps one handle open per image.
        with Image.open(filepath) as img:
            raw = img.info.get("prompt")
    except Exception:
        # Best effort: anything that is not a readable image is skipped.
        return None, None

    if raw is None:
        return None, None

    try:
        data = json.loads(raw)
    except (ValueError, TypeError):
        return None, None
    if not isinstance(data, dict):
        # Valid JSON but not a node mapping; nothing to extract.
        return None, None

    # Gather all text fields from all nodes.  Linked inputs of the form
    # ["node_id", slot] need no separate resolution: the node they point to is
    # itself one of data.values(), so its text is collected by this same loop.
    candidates = []
    seen = set()
    for node in data.values():
        inputs = node.get("inputs") if isinstance(node, dict) else None
        if not isinstance(inputs, dict):
            continue
        for field in text_fields:
            val = inputs.get(field)
            if not isinstance(val, str):
                continue
            stripped = val.strip()
            if len(stripped) > 3 and stripped not in seen:
                seen.add(stripped)
                candidates.append(stripped)

    if not candidates:
        return None, None

    # Split into positive vs negative, classifying each candidate once.
    positives = []
    negatives = []
    for cand in candidates:
        if is_negative_text(cand):
            negatives.append(cand)
        else:
            positives.append(cand)

    # Quality-only texts are not useful for naming.
    positives = [c for c in positives if not is_quality_only(c)]

    pos = max(positives, key=len, default=None)
    neg = max(negatives, key=len, default=None)
    return pos, neg


def extract_keywords(text, max_chars=40):
    """Generate a short dash-separated keyword description from prompt text.

    The text is lowercased, bracketed spans ``(...)``, ``[...]``, ``{...}``
    are removed together with their content, and the rest is split into
    segments on punctuation.  From each segment up to two not-yet-used words
    are taken, skipping short words and anything listed in the module's
    filter sets.  The first five keywords are joined with hyphens.

    Args:
        text: Prompt text; a falsy value yields None.
        max_chars: Maximum length of the returned description.

    Returns:
        The description, or None when no keyword survives filtering or the
        result is three characters or shorter.
    """
    if not text:
        return None

    lowered = text.lower()

    # Strip weighting/parenthetical syntax: (word:1.2), [word], {word}, (word)
    cleaned = re.sub(r"[\[\(\{][^\]\)\}]*[\]\)\}]", "", lowered)

    # Split on commas, periods, semicolons, colons, exclamation/question marks
    segments = re.split(r"[,.;:!?]+", cleaned)

    max_keywords = 5
    per_segment = 2
    seen = set()
    keywords = []
    for seg in segments:
        if len(keywords) >= max_keywords:
            # Only the first five keywords are ever used.
            break
        seg = seg.strip()
        if len(seg) < 4:
            continue

        taken = 0
        # The text is already lowercase, so a-z covers every ASCII letter.
        for word in re.findall(r"[a-z_]+", seg):
            if taken >= per_segment:
                break
            wl = word.strip("_")
            if len(wl) < 3:
                continue
            if wl in STOP_WORDS or wl in GENERIC_WORDS:
                continue
            if wl in QUALITY_TAGS or wl in QUALITY_WORDS:
                continue
            if wl in TECHNICAL_WORDS or wl in NEGATIVE_INDICATORS:
                continue
            if wl.startswith("score") or wl.startswith("step"):
                continue
            if wl in seen:
                continue
            seen.add(wl)
            keywords.append(wl)
            taken += 1

    if not keywords:
        return None

    desc = "-".join(keywords[:max_keywords])

    # Replace any non-alphanumeric (except hyphen) with hyphens
    desc = re.sub(r"[^a-z0-9-]", "-", desc)
    desc = re.sub(r"-+", "-", desc).strip("-")

    # Truncate at max_chars, breaking at a word boundary
    if len(desc) > max_chars:
        clipped = desc[:max_chars].rstrip("-")
        if max_chars > 10 and "-" in clipped:
            # The last piece is a fragment only when the cut fell inside a
            # word, i.e. neither the character just past the cut nor the one
            # just before it is a separator.  A complete word is kept.
            cut_mid_word = desc[max_chars] != "-" and desc[max_chars - 1] != "-"
            if cut_mid_word:
                shorter = clipped.rsplit("-", 1)[0]
                # Keep the fragment rather than end up with a tiny name.
                if len(shorter) > 10:
                    clipped = shorter
        desc = clipped

    return desc if desc and len(desc) > 3 else None


def propose_name(filepath):
    """Build 'schmeeve-AI-{keywords}.png' for *filepath*, or return None.

    Keywords come from the positive prompt when one was found, otherwise from
    the negative prompt.  None is returned when the file has no usable prompt
    or no keywords can be extracted from it.
    """
    positive, negative = extract_prompts(filepath)
    prompt_text = positive if positive else negative
    if prompt_text:
        keywords = extract_keywords(prompt_text)
        if keywords:
            return f"schmeeve-AI-{keywords}.png"
    return None


# ── main ───────────────────────────────────────────────────────────────────

def _unique_target(old_path, new_name):
    """Return a path beside *old_path* named *new_name*, suffixed if taken.

    When the preferred name already exists, ``_1``, ``_2``, ... is appended to
    its stem (with *old_path*'s extension) until an unused name is found.
    """
    target = old_path.with_name(new_name)
    if target.exists():
        stem = target.stem
        counter = 1
        while target.exists():
            target = old_path.with_name(f"{stem}_{counter}{old_path.suffix}")
            counter += 1
    return target


def _rename_file(old_path, new_name, announce_collision=False):
    """Rename *old_path* to *new_name* within its directory.

    Returns True on success.  If the OS refuses the rename, the error is
    printed and False is returned so the remaining files can still be
    processed.
    """
    target = _unique_target(old_path, new_name)
    if announce_collision and target.name != new_name:
        print(f"  (file existed, saved as {target.name})")
    try:
        old_path.rename(target)
    except OSError as exc:
        print(f"  Could not rename {old_path.name}: {exc}")
        return False
    return True


def _read_line():
    """Read one line from stdin; return it stripped, or None at end of input."""
    line = sys.stdin.readline()
    # readline() yields "" only at EOF; a bare Enter is "\n".
    return line.strip() if line else None


def _count_jpegs_with_prompt(scan_dir):
    """Count ``*.jpg`` files in *scan_dir* whose metadata has a "prompt" key."""
    count = 0
    for jpeg in scan_dir.glob("*.jpg"):
        try:
            # Close each image promptly instead of leaking its file handle.
            with Image.open(jpeg) as img:
                if "prompt" in img.info:
                    count += 1
        except Exception:
            # Best-effort scan: unreadable files are simply not counted.
            continue
    return count


def main():
    """Command-line entry point.

    Scans the target directory for ``*.png`` files, proposes a new name for
    each one that carries prompt metadata, lists the proposals, and then
    renames the files either automatically (``--no-interactive``) or after
    asking about each one.  End of input on stdin is treated as "quit".
    Exits with status 1 when the given path is not a directory.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Rename AI-generated PNGs based on embedded prompt metadata.",
    )
    parser.add_argument(
        "path", nargs="?", default=os.path.expanduser("~/Pictures"),
        help="Directory to scan for PNG files (default: ~/Pictures)",
    )
    parser.add_argument(
        "-n", "--no-interactive", action="store_true",
        help="Auto-rename without prompting",
    )
    args = parser.parse_args()

    scan_dir = Path(args.path).expanduser().resolve()
    if not scan_dir.is_dir():
        print(f"Error: {scan_dir} is not a directory")
        sys.exit(1)

    pngs = sorted(scan_dir.glob("*.png"))

    # ── Phase 1: Analyze with spinner ──
    sys.stdout.write("  Analyzing PNGs")
    sys.stdout.flush()
    spin = spinner()
    raw_proposals = {}
    for p in pngs:
        sys.stdout.write(f"\r  {next(spin)} Analyzing PNGs")
        sys.stdout.flush()
        # Skip already-renamed files
        if p.name.startswith("schmeeve-AI-"):
            continue
        name = propose_name(str(p))
        if name:
            raw_proposals[p] = name
        time.sleep(0.02)

    # Clear the spinner line
    sys.stdout.write("\r" + " " * 60 + "\r")
    sys.stdout.flush()

    # Deduplicate proposed names: the first file keeps the plain name, later
    # files proposing the same name get _1, _2, ... appended to the stem.
    proposals = {}
    name_counts = {}
    for p, base in raw_proposals.items():
        name = base
        if base in name_counts:
            name_counts[base] += 1
            stem = base.rsplit(".", 1)[0]
            name = f"{stem}_{name_counts[base]}.png"
        else:
            name_counts[base] = 0
        proposals[p] = name

    # ── Phase 2: Display proposed names ──
    if not proposals:
        print("  No renamable PNGs found.")
        return

    for i, (old_path, new_name) in enumerate(proposals.items(), 1):
        old = old_path.name
        # Truncate old name for display
        old_display = old if len(old) < 50 else old[:22] + "…" + old[-25:]
        print(f"  {i:>3}. {old_display}")
        print(f"       → {new_name}")

    print()
    print(f"  {len(proposals)} file(s) to rename.\n")

    # ── Phase 3: Rename (interactive or auto) ──
    if args.no_interactive:
        renamed = 0
        for old_path, new_name in proposals.items():
            if _rename_file(old_path, new_name):
                renamed += 1
        print(f"  Renamed {renamed} file(s).")
    else:
        renamed = 0
        skipped = 0
        failed = 0
        items = list(proposals.items())
        for i, (old_path, proposed) in enumerate(items):
            new = proposed
            remaining = len(items) - i - 1

            print(f"\n  [{i+1}/{len(items)}]")
            print(f"  Current: {old_path.name}")
            print(f"  New:     {new}")
            all_hint = f"  [a]=rename all {remaining}" if remaining else ""
            sys.stdout.write(f"  [Enter]=rename  [e]=edit  [s]=skip{all_hint}  [q]=quit: ")
            sys.stdout.flush()

            answer = _read_line()
            if answer is None:
                # End of input must not be mistaken for a bare Enter, which
                # would rename every remaining file without confirmation.
                print()
                choice = "q"
            else:
                choice = answer.lower()

            if choice == "q":
                if remaining:
                    print(f"  Skipping remaining {remaining} file(s).")
                break

            if choice == "s":
                skipped += 1
                continue

            if choice == "a":
                # Rename current and all remaining without further prompts
                for p, n in items[i:]:
                    if _rename_file(p, n):
                        renamed += 1
                    else:
                        failed += 1
                break

            if choice == "e":
                sys.stdout.write("  Edit name (will be prefixed 'schmeeve-AI-'): ")
                sys.stdout.flush()
                custom = _read_line()
                if not custom:
                    # empty (or end of input) = skip
                    skipped += 1
                    continue
                # Sanitize
                custom_desc = re.sub(r"[^a-z0-9-]", "-", custom.lower())
                custom_desc = re.sub(r"-+", "-", custom_desc).strip("-")
                if not custom_desc:
                    print("  Invalid name, skipping.")
                    skipped += 1
                    continue
                new = f"schmeeve-AI-{custom_desc}.png"
            elif choice != "":
                print(f"  Unknown option '{choice}', skipping.")
                skipped += 1
                continue

            # Perform rename (proposed name on Enter, custom name after edit)
            if _rename_file(old_path, new, announce_collision=True):
                renamed += 1
            else:
                failed += 1

        summary = f"\n  Renamed: {renamed}  Skipped: {skipped}"
        if failed:
            summary += f"  Failed: {failed}"
        print(summary)

    # JPEGs are not renamed yet; just report how many carry prompt metadata.
    remaining_ai = _count_jpegs_with_prompt(scan_dir)
    if remaining_ai:
        print(f"  Note: {remaining_ai} JPEG(s) with AI metadata found (not yet supported).")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
