#!/usr/bin/env python3
"""
Archive Discourse posts and render topics to Markdown from multiple sites.

Each topic is rendered to Markdown from post data fetched from the Discourse
API; images referenced by the posts are downloaded alongside the rendered files.
The API endpoints used are:
  - https://{host}/categories.json (for the category list)
  - https://{host}/c/{slug}/{id}.json (for listing topics by category)
  - https://{host}/t/{topic_id}.json (for topic metadata and the first posts)
  - https://{host}/t/{topic_id}/posts.json?post_ids[]=... (for the remaining posts)
  - https://{host}/posts/{post_id}.json (for individual posts)

Usage:
  ./discourse2github.py --urls https://forum.example.org,... --target-dir ./archive
"""

import argparse
import concurrent.futures
import datetime
import functools
import json
import logging
import os
import re
import sys
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import html2text  # pip install html2text
from bs4 import BeautifulSoup  # pip install beautifulsoup4

# Logging setup: prefer rich's handler when the package is installed.
lvl = 'INFO'
if os.environ.get('DEBUG'):
    # Any non-empty DEBUG environment value switches to verbose logging.
    lvl = 'DEBUG'
try:
    from rich.logging import RichHandler
except ImportError:
    # rich is optional; fall back to the stdlib default handler.
    logging.basicConfig(level=lvl)
else:
    logging.basicConfig(level=lvl, datefmt="[%X]", handlers=[RichHandler()])
# Module-wide logger used by every function in this file.
log = logging.getLogger('archive')

# Config constants
# NOTE(review): the three values below are named like a batch size, a pause in
# seconds and an iteration cap — confirm their intended meaning against their users.
BATCH_SIZE = 100
SLEEP_SEC = 2
MAX_ITER = 1000
# Upper bound on attempts made by the retry loops in fetch_url() and download_img().
RETRY_MAX = 5  # Maximum retries on error

# Argument parser. Every option falls back to an environment variable
# (DISCOURSE_URLS, DEBUG, TARGET_DIR) and then to a built-in default.
parser = argparse.ArgumentParser(description='Archive and render Discourse topics.')
parser.add_argument('--urls', help='Comma-separated Discourse URLs',
                    default=os.environ.get('DISCOURSE_URLS', 'https://forum.hackliberty.org'))
# The default is coerced to bool so params.debug has one type regardless of
# whether it came from the flag or from the DEBUG environment variable.
parser.add_argument('--debug', action='store_true', help='Enable debug logging',
                    default=bool(os.environ.get('DEBUG')))
# type=Path makes target_dir a Path whether it comes from the command line or
# from the default (argparse leaves non-string defaults untouched).
parser.add_argument('-t', '--target-dir', help='Base directory for archives', type=Path,
                    default=Path(os.environ.get('TARGET_DIR', './archive')))

@functools.cache
def args():
    """Parse the command line and return the resulting namespace.

    The result is memoised by ``functools.cache``, so ``parse_args()`` runs at
    most once and later calls return the same namespace object.
    """
    namespace = parser.parse_args()
    return namespace

def parse_sites(urls: str) -> list[str]:
    """Split a comma-separated URL string into normalised site base URLs.

    Each entry is stripped of surrounding whitespace and trailing slashes.
    Entries that are empty *after* that normalisation (e.g. ``""`` or ``"/"``)
    are dropped, so callers never receive an empty site string.

    Args:
        urls: Comma-separated list of site URLs.

    Returns:
        The cleaned URLs in their original order.
    """
    return [site for entry in urls.split(',') if (site := entry.strip().rstrip('/'))]

# API credentials (optional)
# Read once at import time. fetch_url() sends them as the "Api-Key" and
# "Api-Username" request headers only when both values are non-empty.
API_KEY = os.environ.get("DISCOURSE_API_KEY", "")
API_USER = os.environ.get("DISCOURSE_API_USERNAME", "")

def fetch_url(url: str, timeout=15) -> Optional[str]:
    """
    Fetch a URL with a retry loop and return the decoded response body.

    The API key/username headers are attached when both credentials are set.
    A 404 response returns None immediately. Any other error is logged and
    retried with exponential backoff (3s, 6s, 12s, ...) until RETRY_MAX
    attempts have been made; no sleep happens after the final attempt.

    Args:
        url: Absolute URL to fetch.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        The response body as text, or None on 404 or when every attempt failed.
    """
    request = urllib.request.Request(url)
    # Add API headers if available.
    if API_KEY and API_USER:
        request.add_header("Api-Key", API_KEY)
        request.add_header("Api-Username", API_USER)

    backoff = 3
    for attempt in range(1, RETRY_MAX + 1):
        try:
            log.debug("Attempt %d: Fetching URL: %s", attempt, url)
            with urllib.request.urlopen(request, timeout=timeout) as resp:
                data = resp.read().decode()
                log.debug(
                    "Successfully fetched URL: %s | HTTP Status: %s | Response length: %d bytes",
                    url, resp.status, len(data)
                )
                return data
        except Exception as e:
            is_http_error = isinstance(e, urllib.error.HTTPError)
            if is_http_error and e.code == 404:
                log.warning("Resource not found (404) for %s, skipping further retries", url)
                return None
            log.warning(
                "%s fetching %s: %s (attempt %d/%d)",
                "HTTPError" if is_http_error else "Error",
                url, e, attempt, RETRY_MAX, exc_info=True,
            )
            # Only wait when another attempt will actually follow.
            if attempt < RETRY_MAX:
                time.sleep(backoff)
                backoff *= 2
    log.error("Failed fetching %s after %d attempts.", url, RETRY_MAX)
    return None

def fetch_json(url: str, timeout=15) -> dict:
    """
    Download a URL via fetch_url() and decode the body as JSON.

    Debug logging reports the raw payload length and, for JSON objects, the
    top-level keys. The decoded value is returned as-is even when it is not a
    dict. Returns None when the download produced no data or the body is not
    valid JSON.
    """
    payload = fetch_url(url, timeout)
    if payload is None:
        log.debug("No data returned for URL: %s", url)
        return None

    log.debug("Fetched raw data from %s (length: %d bytes)", url, len(payload))
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError as err:
        log.error("JSON decode error for %s: %s", url, err, exc_info=True)
        return None

    if not isinstance(parsed, dict):
        log.debug("JSON parsed from %s is not a dict (type: %s)", url, type(parsed).__name__)
    else:
        log.debug("JSON parsed from %s, keys: %s", url, list(parsed.keys()))
    return parsed


def truncate_fn(name: str, max_len=255) -> str:
    """Shorten a file name to at most ``max_len`` characters, keeping its extension.

    Only the final extension (as split by ``os.path.splitext``) is preserved,
    so names containing extra dots (e.g. a dotted username) are not mangled.
    If the extension alone does not fit, the name is simply cut at ``max_len``.

    Args:
        name: File name (not a full path) to shorten.
        max_len: Maximum length of the result, counted in characters.

    Returns:
        ``name`` unchanged when it already fits, otherwise a truncated name
        whose length never exceeds ``max_len``.
    """
    if len(name) <= max_len:
        return name
    root, ext = os.path.splitext(name)
    room = max_len - len(ext)
    if room <= 0:
        # The extension itself is too long to keep; a hard cut is all we can do.
        return name[:max_len]
    return root[:room] + ext

# --- Helpers for images & HTML content ---
def fix_url(url: str) -> str:
    """Turn a protocol-relative URL (``//host/...``) into an https URL.

    Any other URL is returned unchanged.
    """
    if url.startswith("//"):
        return "https:" + url
    return url

def download_img(url: str, dest: Path, tid: Optional[int] = None, timeout=15):
    """Download an image to ``dest`` unless that file already exists.

    Transient failures are retried with exponential backoff (2s, 4s, ...) up
    to RETRY_MAX attempts. Failures that cannot succeed on retry — an HTTP 404
    or a ValueError from urllib (e.g. a URL without a scheme) — abort at once
    instead of sleeping through the whole backoff schedule.

    Args:
        url: Image URL; protocol-relative URLs are completed by fix_url().
        dest: Target file path; parent directories are created as needed.
        tid: Topic id, used only in log messages.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        None. Failures are logged, never raised.
    """
    if dest.exists():
        log.debug("Img exists for topic %s: %s", tid, dest)
        return

    backoff = 2
    for attempt in range(1, RETRY_MAX + 1):
        try:
            log.info("Downloading img for topic %s: %s", tid, url)
            with urllib.request.urlopen(fix_url(url), timeout=timeout) as r:
                data = r.read()
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_bytes(data)
            log.info("Saved img for topic %s to %s", tid, dest)
            return
        except Exception as e:
            log.warning("Failed downloading img for topic %s from %s: %s (attempt %d/%d)",
                        tid, url, e, attempt, RETRY_MAX)
            not_found = isinstance(e, urllib.error.HTTPError) and e.code == 404
            if not_found or isinstance(e, ValueError):
                # Retrying cannot change the outcome for these errors.
                log.error("Giving up on image %s for topic %s: error is not retryable", url, tid)
                return
            # Only wait when another attempt will actually follow.
            if attempt < RETRY_MAX:
                time.sleep(backoff)
                backoff *= 2
    log.error("Exceeded maximum retries downloading image %s for topic %s", url, tid)

def proc_srcset(srcset: str, tdir: Path, rel: str, tid: int) -> str:
    """Download every image named in a ``srcset`` value and rewrite it locally.

    Each comma-separated candidate is downloaded into ``tdir`` and replaced by
    ``rel``/<file name> (with forward slashes), keeping its descriptor (the
    second token) when present. Candidates whose URL has no file name are
    logged and dropped.

    Returns:
        The rewritten ``srcset`` string, candidates joined by ", ".
    """
    rewritten = []
    for candidate in srcset.split(","):
        tokens = candidate.strip().split()
        if not tokens:
            continue
        absolute = fix_url(tokens[0])
        filename = os.path.basename(urlparse(absolute).path)
        if not filename:
            log.warning("Empty filename in srcset for topic %s: %s", tid, absolute)
            continue
        download_img(absolute, tdir / filename, tid)
        local_ref = os.path.join(rel, filename).replace(os.sep, '/')
        if len(tokens) > 1:
            rewritten.append(f"{local_ref} {tokens[1]}")
        else:
            rewritten.append(local_ref)
    return ", ".join(rewritten)

def is_img_link(url: str) -> bool:
    """Return True when the URL's path ends in a known image extension.

    The check is case-insensitive and ignores the query string and fragment.
    """
    image_extensions = (".png", ".jpg", ".jpeg", ".gif", ".webp")
    filename = os.path.basename(urlparse(url).path)
    return filename.lower().endswith(image_extensions)

def remove_img_anchor(soup):
    """Unwrap every ``<a>`` element that contains an ``<img>``.

    The anchor tag is replaced by its own children, so the image (and any
    sibling content inside the link) stays in place. The soup is modified in
    place and also returned.
    """
    wrapping_anchors = [anchor for anchor in soup.find_all("a") if anchor.find("img")]
    for anchor in wrapping_anchors:
        anchor.replace_with(*anchor.contents)
    return soup

def proc_html(html, tdir: Path, rel: str, tid: int) -> str:
    """Localise the images of a post's HTML and return the rewritten HTML.

    - ``<img src>`` targets are downloaded into ``tdir`` and rewritten to
      ``rel``/<file name>; ``srcset`` values are handled by proc_srcset().
    - ``<a href>`` targets that look like image files are downloaded and
      rewritten the same way, and the link's text is blanked.
    - Finally, anchors wrapping an ``<img>`` are unwrapped.

    Args:
        html: HTML fragment of the post.
        tdir: Directory that receives the downloaded images.
        rel: Path prefix written into the HTML in front of each file name.
        tid: Topic id, used for logging.
    """
    soup = BeautifulSoup(html, "html.parser")
    processed = 0

    for image in soup.find_all("img"):
        source = image.get("src")
        if source:
            source = fix_url(source)
            filename = os.path.basename(urlparse(source).path)
            if not filename:
                log.warning("Empty filename in src for topic %s: %s", tid, source)
            else:
                download_img(source, tdir / filename, tid)
                processed += 1
                image["src"] = os.path.join(rel, filename).replace(os.sep, '/')
        srcset = image.get("srcset")
        if srcset:
            image["srcset"] = proc_srcset(srcset, tdir, rel, tid)

    for anchor in soup.find_all("a"):
        target = anchor.get("href")
        if not target:
            continue
        target = fix_url(target)
        if not is_img_link(target):
            continue
        filename = os.path.basename(urlparse(target).path)
        if not filename:
            log.warning("Empty filename in href for topic %s: %s", tid, target)
            continue
        download_img(target, tdir / filename, tid)
        processed += 1
        anchor["href"] = os.path.join(rel, filename).replace(os.sep, '/')
        # Blank the link text of image links.
        if anchor.string:
            anchor.string.replace_with("")

    remove_img_anchor(soup)
    log.debug("Processed %d images for topic %s", processed, tid)
    return str(soup)

def slugify(s: str) -> str:
    """Convert arbitrary text into a lowercase, hyphen-separated slug.

    Characters other than ``a-z``, ``0-9``, whitespace and ``-`` are removed,
    runs of whitespace/hyphens collapse into a single ``-``, and hyphens left
    at either end are trimmed.

    Returns:
        The slug, or ``"untitled"`` when nothing usable remains.
    """
    s = re.sub(r'[^a-z0-9\s-]', '', s.strip().lower())
    # Trim after collapsing so removed punctuation cannot leave edge hyphens
    # and a hyphen-only result still falls through to "untitled".
    s = re.sub(r'[\s-]+', '-', s).strip('-')
    return s or "untitled"

# --- Data models ---
@dataclass(frozen=True)
class PostTopic:
    """Immutable (frozen) record holding a topic's id, slug, title and category id."""
    # Field meanings below are inferred from the names — verify against callers.
    id: int           # topic id
    slug: str         # topic slug
    title: str        # topic title
    category_id: int  # id of the topic's category

@dataclass
class Post:
    """A single post: its id, its topic slug and the raw JSON dict it came from."""
    id: int
    slug: str
    raw: dict

    def _timestamp(self, key: str) -> datetime.datetime:
        """Parse the ISO-8601 value stored under ``key`` in the raw JSON."""
        # A trailing "Z" is rewritten as an explicit UTC offset before parsing.
        return datetime.datetime.fromisoformat(self.raw[key].replace("Z", "+00:00"))

    def created_at(self) -> datetime.datetime:
        """Return the post's ``created_at`` timestamp as a datetime."""
        return self._timestamp('created_at')

    def updated_at(self) -> datetime.datetime:
        """Return the post's ``updated_at`` timestamp as a datetime."""
        return self._timestamp('updated_at')

    def save(self, d: Path) -> None:
        """Write the raw post JSON below ``d``, in a per-month sub-folder.

        The file is left untouched when a copy with the same ``updated_at``
        value is already on disk.
        """
        filename = truncate_fn(
            f"{str(self.id).zfill(10)}-{self.raw.get('username', 'anonymous')}"
            f"-{self.raw.get('topic_slug', 'unknown')}.json"
        )
        month_folder = self.created_at().strftime('%Y-%m-%B')
        target = d / month_folder / filename
        # Only write if changed.
        if target.exists():
            try:
                on_disk = json.loads(target.read_text(encoding='utf-8'))
                if on_disk.get("updated_at") == self.raw.get("updated_at"):
                    log.debug("Post %s unchanged; skip saving.", self.id)
                    return
            except Exception as err:
                # An unreadable existing file is simply overwritten below.
                log.debug("Error reading %s: %s", target, err)
        target.parent.mkdir(parents=True, exist_ok=True)
        log.info("Saving post %s to %s", self.id, target)
        target.write_text(json.dumps(self.raw, indent=2), encoding='utf-8')

    @classmethod
    def from_json(cls, j: dict) -> 'Post':
        """Build a Post from a post JSON dict (``topic_slug`` defaults to 'unknown')."""
        topic_slug = j.get('topic_slug', 'unknown')
        return cls(id=j['id'], slug=topic_slug, raw=j)

@dataclass
class Topic:
    """A rendered topic: identifying metadata plus its Markdown text."""
    id: int
    slug: str
    title: str
    category_id: int
    created_at_str: str
    markdown: str = ""  # initial markdown content

    def created_at(self) -> datetime.datetime:
        """Parse ``created_at_str`` (ISO-8601, optional trailing "Z") into a datetime."""
        iso_text = self.created_at_str.replace("Z", "+00:00")
        return datetime.datetime.fromisoformat(iso_text)

    def save_rendered(self, d: Path) -> Path:
        """Write the Markdown below ``d`` in a per-month folder.

        Returns:
            The written file's path relative to the parent of ``d``.
        """
        created = self.created_at()
        filename = truncate_fn(f"{created.date()}-{self.slug}-id{self.id}.md")
        target = d / created.strftime('%Y-%m-%B') / filename
        target.parent.mkdir(parents=True, exist_ok=True)
        log.info("Saving rendered topic %s to %s", self.id, target)
        target.write_text(self.markdown, encoding='utf-8')
        return target.relative_to(d.parent)

# --- API fetching for topics and posts ---
def fetch_topic_meta(site: str, topic_id: int) -> dict:
    """Fetch ``{site}/t/{topic_id}.json`` and return the decoded JSON.

    Returns None (after logging a warning) when fetch_json() yields nothing.
    """
    meta = fetch_json(f"{site}/t/{topic_id}.json")
    if meta is None:
        log.warning("Topic metadata not found for topic %s", topic_id)
    return meta

def fetch_single_post(site: str, post_id: int) -> dict:
    """
    Fetch ``{site}/posts/{post_id}.json`` and return the decoded JSON.

    On success a short summary (topic slug, username, creation time) and the
    size of the re-serialised JSON are logged at debug level. Returns None,
    with a warning, when the post could not be fetched.
    """
    post = fetch_json(f"{site}/posts/{post_id}.json")
    if post is None:
        log.warning("Post %s not found on site %s", post_id, site)
        return post

    # Log detailed post information if available
    author = post.get("username", "unknown")
    slug = post.get("topic_slug", "unknown")
    created = post.get("created_at", "unknown time")
    log.debug("Fetched post %s: topic_slug='%s', username='%s', created_at='%s'",
              post_id, slug, author, created)
    log.debug("Post %s JSON size: %d bytes", post_id, len(json.dumps(post)))
    return post

# --- Rendering functions using fresh API post data ---
# Image caption residue that html2text leaves behind, of the form
# "<name> <width>×<height> <size> KB|MB". The first pattern detects lines
# ending in such a caption; the second strips captions that remain inline.
_IMG_META_LINE_RE = re.compile(
    r'\S+\s*\d+\s*[×x]\s*\d+\s+\d+(\.\d+)?\s*(KB|MB)$', flags=re.IGNORECASE)
_IMG_META_INLINE_RE = re.compile(
    r'(\S+)\s*\d+\s*[×x]\s*\d+\s+\d+(\.\d+)?\s*(KB|MB)', flags=re.IGNORECASE)


def _parse_iso(ts: str) -> datetime.datetime:
    """Parse an ISO-8601 timestamp, accepting a trailing "Z" for UTC."""
    return datetime.datetime.fromisoformat(ts.replace("Z", "+00:00"))


def _collect_posts(site: str, topic_id: int, topic_meta: dict) -> list:
    """Return every post of the topic, sorted by ``post_number``.

    The topic JSON carries only an initial batch of full posts plus the id
    stream of all posts; the missing ones are fetched in chunks of 20 ids.
    """
    post_stream = topic_meta.get("post_stream") or {}
    posts = list(post_stream.get("posts") or [])
    # A set makes the membership test O(1) instead of rebuilding a list per id.
    known_ids = {p.get("id") for p in posts}
    extra_ids = [pid for pid in (post_stream.get("stream") or []) if pid not in known_ids]
    log.debug("Topic %s: %d posts in initial load, %d extra IDs detected.",
              topic_id, len(posts), len(extra_ids))

    chunk_size = 20
    for start in range(0, len(extra_ids), chunk_size):
        chunk = extra_ids[start:start + chunk_size]
        # Build query string with multiple post_ids[] parameters
        qs = "&".join(f"post_ids[]={pid}" for pid in chunk)
        posts_extra_url = f"{site}/t/{topic_id}/posts.json?{qs}"
        extra_response = fetch_json(posts_extra_url)
        if extra_response and "post_stream" in extra_response and "posts" in extra_response["post_stream"]:
            posts.extend(extra_response["post_stream"]["posts"])
        else:
            log.warning("Failed fetching extra posts for topic %s with URL: %s", topic_id, posts_extra_url)

    # Sort by post_number to preserve the original order.
    posts.sort(key=lambda p: p.get("post_number", 0))
    return posts


def _clean_post_markdown(md_post: str) -> str:
    """Drop image-caption lines and strip inline caption residue."""
    kept = [line for line in md_post.splitlines() if not _IMG_META_LINE_RE.search(line)]
    return _IMG_META_INLINE_RE.sub(r'\1', "\n".join(kept))


def render_topic(site: str, topic_id: int, tops_dir: Path, cats: dict) -> Optional[dict]:
    """
    Render every post of a topic into a single Markdown file.

    Topic metadata comes from ``/t/{topic_id}.json``; posts missing from that
    response are fetched separately. Each post is converted to Markdown
    (images are downloaded into ``assets/images/{topic_id}`` next to
    ``tops_dir``) and appended to the file as soon as it is rendered. A post
    that fails to render is logged and skipped.

    Args:
        site: Base URL of the Discourse site.
        topic_id: Id of the topic to render.
        tops_dir: Directory receiving the rendered Markdown (per-month folders).
        cats: Mapping of category id to category name.

    Returns:
        A dict with ``id``, ``title``, ``relative_path`` and ``category``, or
        None when the topic metadata could not be fetched.
    """
    topic_meta = fetch_topic_meta(site, topic_id)
    if not topic_meta:
        log.warning("No metadata found for topic %s; skipping render.", topic_id)
        return None

    # Use the topic meta from /t/{topic_id}.json
    slug = topic_meta.get("slug", "unknown")
    title = topic_meta.get("title", "No Title")
    # "or" also covers keys that are present but null, which would otherwise
    # crash int() / .replace().
    category_id = int(topic_meta.get("category_id") or 0)
    created_at_str = topic_meta.get("created_at") or datetime.datetime.now().isoformat()
    created = _parse_iso(created_at_str)

    # Create assets dir for images.
    assets = tops_dir.parent / "assets" / "images" / f"{topic_id}"
    assets.mkdir(parents=True, exist_ok=True)
    md_dir = tops_dir / created.strftime('%Y-%m-%B')
    rel_path = os.path.relpath(assets, md_dir)

    # Create or truncate the markdown topic file
    fn = truncate_fn(f"{created.date()}-{slug}-id{topic_id}.md")
    topic_md_path = md_dir / fn
    topic_md_path.parent.mkdir(parents=True, exist_ok=True)
    log.info("Creating markdown file for topic %s at %s", topic_id, topic_md_path)
    # Write the topic title as header
    with topic_md_path.open(mode="w", encoding="utf8") as f:
        f.write(f"# {title}\n\n")

    conv = html2text.HTML2Text()
    conv.body_width = 0

    posts_meta = _collect_posts(site, topic_id, topic_meta)
    post_count = len(posts_meta)
    log.debug("Processing a total of %d posts for topic %s", post_count, topic_id)

    for post in posts_meta:
        try:
            post_id = post.get("id")
            log.debug("Processing post ID %s for topic %s", post_id, topic_id)
            # Create header for the post and fetch necessary dates
            cdt = _parse_iso(post.get("created_at"))
            udt = _parse_iso(post["updated_at"]) if post.get("updated_at") else cdt
            hdr = (f"> **Post #{post.get('post_number', 0)} • {post.get('username', 'unknown')}**\n"
                   f"> Created: {cdt.strftime('%Y-%m-%d %H:%M')}\n"
                   f"> Updated: {udt.strftime('%Y-%m-%d %H:%M')}")
            cooked = post.get("cooked", "")
            proc = proc_html(cooked, assets, rel_path, topic_id)
            md_post = _clean_post_markdown(conv.handle(proc))

            section = f"<!-- ✦✦✦ POST START ✦✦✦ -->\n\n{hdr}\n\n{md_post}\n\n<!-- ✦✦✦ POST END ✦✦✦ -->\n\n"
            with topic_md_path.open(mode="a", encoding="utf8") as f:
                f.write(section)
            log.debug("Appended post #%s (ID %s) to topic markdown file", post.get("post_number", "?"), post_id)
            time.sleep(0.2)  # small pause between posts
        except Exception as e:
            # One bad post must not abort the whole topic.
            log.error("Error processing post %s: %s", post.get("id"), e, exc_info=True)

    # After processing, read the file content and return the topic info.
    full_md = topic_md_path.read_text(encoding='utf8')
    topic_obj = Topic(
        id=topic_id,
        slug=slug,
        title=title,
        category_id=category_id,
        created_at_str=created_at_str,
        markdown=full_md,
    )
    rel_saved = topic_obj.save_rendered(tops_dir)  # This rewrites the file; that's acceptable.
    log.info("Rendered topic %s (%s) with %d posts", topic_obj.id, topic_obj.slug, post_count)
    return {"id": topic_id, "title": title, "relative_path": str(rel_saved), "category": cats.get(category_id, "Uncategorized")}


# --- README update functions ---
# Matches one README table-of-contents entry of the form
#   - [<title>](<relative path>) <!-- id: <topic id> -->
# The brackets and parentheses are literal and must be escaped; unescaped
# anchors in their place would keep the pattern from ever matching.
TOC_PAT = re.compile(r"- \[(?P<title>.+?)\]\((?P<rel>.+?)\) <!-- id: (?P<id>\d+) -->")
def read_readme(root: Path):
    """Collect the topic entries listed in ``root/README.md``.

    Every line matching TOC_PAT contributes one entry. Returns a dict keyed by
    topic id whose values hold ``id``, ``title`` and ``relative_path``; the
    dict is empty when the README does not exist. Errors while reading or
    parsing are logged, and whatever was collected so far is returned.
    """
    readme = root / "README.md"
    found = {}
    if not readme.exists():
        return found
    try:
        for raw_line in readme.read_text(encoding="utf-8").splitlines():
            match = TOC_PAT.match(raw_line.strip())
            if not match:
                continue
            topic_id = int(match.group("id"))
            found[topic_id] = {
                "id": topic_id,
                "title": match.group("title"),
                "relative_path": match.group("rel"),
            }
    except Exception as err:
        log.error("Failed parsing README.md: %s", err)
    return found

def append_readme(root: Path, ntop: dict):
    """Add one topic entry to ``root/README.md``, creating the file if needed.

    In an existing README the entry is inserted after the run of TOC_PAT
    lines that directly follows the "## Table of Contents" heading. If that
    heading is missing, a fresh header plus the entry is placed in front of
    the old content. If the README cannot be read, it is replaced by a fresh
    header plus the entry. Write errors are logged, not raised.

    Args:
        root: Directory containing README.md.
        ntop: Topic dict providing ``title``, ``relative_path`` and ``id``.
    """
    readme = root / "README.md"
    preamble = ["# Archived Discourse Topics", "", "## Table of Contents", ""]
    entry = f"- [{ntop['title']}]({ntop['relative_path']}) <!-- id: {ntop['id']} -->"
    fresh = "\n".join(preamble + [entry])

    if not readme.exists():
        content = fresh
    else:
        try:
            existing = readme.read_text(encoding="utf-8").splitlines()
            if "## Table of Contents" in existing:
                pos = existing.index("## Table of Contents") + 1
                # Skip the entries that immediately follow the heading.
                while pos < len(existing) and TOC_PAT.match(existing[pos].strip()):
                    pos += 1
                existing.insert(pos, entry)
                content = "\n".join(existing)
            else:
                content = "\n".join(preamble + [entry] + [""] + existing)
        except Exception as err:
            log.error("Error reading README.md: %s", err)
            content = fresh

    try:
        readme.write_text(content, encoding="utf-8")
        log.info("Updated README.md at %s", readme)
    except Exception as err:
        log.error("Failed writing README.md: %s", err)

def write_readme(site_dir: Path, tops: dict):
    """Rewrite ``site_dir/README.md`` as a table of contents grouped by category.

    Categories are listed alphabetically (topics without a ``category`` fall
    under "Uncategorized"); within a category, topics are ordered by id.
    Write errors are logged, not raised.
    """
    readme = site_dir / "README.md"

    by_category = {}
    for topic in tops.values():
        category = topic.get("category", "Uncategorized")
        if category not in by_category:
            by_category[category] = []
        by_category[category].append(topic)

    out = ["# Archived Discourse Topics", "", "## Table of Contents", ""]
    for category in sorted(by_category):
        out.append(f"### {category}")
        ordered = sorted(by_category[category], key=lambda x: x["id"])
        out.extend(
            f"- [{t['title']}]({t['relative_path']}) <!-- id: {t['id']} -->" for t in ordered
        )
        out.append("")

    try:
        readme.write_text("\n".join(out), encoding='utf-8')
        log.info("Finalized README.md at %s", readme)
    except Exception as err:
        log.error("Failed writing final README.md: %s", err)

def update_meta(meta_file: Path, meta: dict):
    """Serialise ``meta`` as indented JSON and overwrite ``meta_file`` with it."""
    log.debug("Updating meta: %s", meta)
    serialized = json.dumps(meta, indent=2)
    meta_file.write_text(serialized, encoding='utf-8')

# --- New function to fetch topic IDs using list topics endpoint ---
def fetch_topic_ids(site: str) -> list:
    """
    Collect the ids of all topics listed under the site's categories.

    Categories come from ``/categories.json``; each category's topics are read
    from ``/c/{slug}/{id}.json``. The listing is paginated, so further pages
    (``?page=N``) are requested while the response advertises a
    ``more_topics_url``. Paging also stops when a page adds no new ids for the
    category (a guard against servers that ignore ``page``) or after MAX_ITER
    pages.

    Returns:
        A sorted list of unique topic ids; empty when the category list
        could not be fetched.
    """
    topic_ids = set()
    # Get categories data
    cats_js = fetch_json(f"{site}/categories.json")
    if not cats_js:
        log.error("Failed to fetch categories from %s", site)
        return []
    cats = cats_js.get("category_list", {}).get("categories", [])
    for cat in cats:
        cat_id = cat.get("id")
        cat_slug = cat.get("slug")
        if not cat_id or not cat_slug:
            continue
        base_url = f"{site}/c/{cat_slug}/{cat_id}.json"
        seen_in_category = set()
        for page in range(MAX_ITER):
            # Page 0 is the plain listing URL; later pages add ?page=N.
            url = base_url if page == 0 else f"{base_url}?page={page}"
            js = fetch_json(url)
            if not js:
                log.warning("Failed to fetch topics for category %s using %s", cat_id, url)
                break
            topic_list = js.get("topic_list") or {}
            page_ids = {t.get("id") for t in (topic_list.get("topics") or []) if t.get("id")}
            new_ids = page_ids - seen_in_category
            seen_in_category |= page_ids
            if not new_ids or not topic_list.get("more_topics_url"):
                break
        topic_ids |= seen_in_category
    log.info("Fetched %d topic IDs from %s", len(topic_ids), site)
    # Sorted so the processing order is reproducible between runs.
    return sorted(topic_ids)

# --- Main processing of a site ---
def process_site(site: str, base: Path):
    """Archive one Discourse site below ``base/<hostname>``.

    Creates the ``posts`` and ``rendered-topics`` directories, loads the
    existing ``.metadata.json`` if present, renders every topic returned by
    fetch_topic_ids() on a pool of 10 worker threads, and records each
    rendered topic in the metadata file and README as it completes. A final
    README grouped by category is written when at least one topic rendered.

    Args:
        site: Base URL of the Discourse site.
        base: Root directory of all archives.
    """
    parsed = urlparse(site)
    sname = parsed.hostname or site.replace("https://", "").replace("http://", "").split('/')[0]
    log.info("Processing site: %s", site)
    sdir = base / sname
    posts_d = sdir / 'posts'
    tops_d = sdir / 'rendered-topics'
    posts_d.mkdir(parents=True, exist_ok=True)
    tops_d.mkdir(parents=True, exist_ok=True)
    meta_file = sdir / '.metadata.json'
    meta = {"archived_topic_ids": {}, "topics": {}}

    if meta_file.exists():
        try:
            # Read with the same encoding update_meta() writes.
            loaded = json.loads(meta_file.read_text(encoding='utf-8'))
            if isinstance(loaded, dict):
                meta = loaded
            else:
                log.error("Failed reading meta for %s: %s", site, "expected a JSON object")
        except Exception as e:
            log.error("Failed reading meta for %s: %s", site, e)

    topic_ids_to_process = fetch_topic_ids(site)
    log.debug("Topic IDs to process: %s", topic_ids_to_process)

    # Fetch the category mapping once and share it (read-only) with all
    # workers, rather than issuing one request per topic.
    cats = fetch_cats(site)
    rend_all = {}

    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_tid = {
            executor.submit(render_topic, site, tid, tops_d, cats): tid
            for tid in topic_ids_to_process
        }

        for future in as_completed(future_to_tid):
            tid = future_to_tid[future]
            try:
                rendered = future.result()
                if rendered:
                    rend_all[rendered["id"]] = rendered
                    meta.setdefault("topics", {})[str(rendered["id"])] = rendered
                    meta.setdefault("archived_topic_ids", {})[str(rendered["id"])] = {
                        "rendered_at": datetime.datetime.now().isoformat()
                    }
                    # Persist after every topic so an interrupted run keeps its progress.
                    update_meta(meta_file, meta)
                    append_readme(sdir, rendered)
            except Exception as e:
                log.error("Error rendering topic %s: %s", tid, e, exc_info=True)

    if rend_all:
        write_readme(sdir, rend_all)
    else:
        log.info("Site %s: No topics rendered; skipping final README.", site)
    update_meta(meta_file, meta)

def fetch_cats(site: str) -> dict:
    """Fetch the site's categories from ``/categories.json``.

    Returns:
        A dict mapping category id (int) to category name. The dict is empty
        when the request failed or returned something other than a JSON
        object. Entries lacking a usable ``id`` or ``name`` are skipped
        individually instead of discarding the whole mapping.
    """
    js = fetch_json(site + "/categories.json")
    if not isinstance(js, dict):
        log.error("Failed fetch categories from %s: %s", site, "no usable JSON response")
        return {}
    cats = (js.get("category_list") or {}).get("categories") or []
    mapping = {}
    for c in cats:
        try:
            mapping[int(c["id"])] = c["name"]
        except (KeyError, TypeError, ValueError) as e:
            log.warning("Skipping malformed category entry from %s: %r (%s)", site, c, e)
    log.info("Fetched %d categories from %s", len(mapping), site)
    return mapping

def main() -> None:
    """Entry point: parse arguments and archive every configured site in turn.

    Exits with status 1 when ``--urls`` yields no usable site.
    """
    params = args()
    if params.debug:
        # Honour --debug: raise the root logger's verbosity so the 'archive'
        # logger, which inherits its level, emits debug records too.
        logging.getLogger().setLevel(logging.DEBUG)
    # Path() accepts both str and Path, so no type check is needed.
    base = Path(params.target_dir)
    base.mkdir(parents=True, exist_ok=True)
    sites = parse_sites(params.urls)
    if not sites:
        log.error("No valid sites provided. Exiting.")
        sys.exit(1)
    for s in sites:
        process_site(s, base)

# Run the archiver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()