#!/usr/bin/env python3
"""
Archive Discourse posts and render topics to Markdown from multiple sites.

This script downloads posts from one or more Discourse servers via their APIs,
archives new posts as JSON files (skipping those already saved or archived),
renders topics to Markdown files for each batch of posts concurrently, and updates
a metadata file after each post is indexed.

Usage:
  ./archive.py --urls https://forum.hackliberty.org,https://forum.qubes-os.org --target-dir ./archive
"""

import argparse
import concurrent.futures
import functools
import json
import logging
import os
import sys
import time
import urllib.request
import datetime
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlparse

# Logging configuration: verbosity comes from the DEBUG environment variable,
# and the optional 'rich' package supplies a prettier handler when installed.
loglevel = 'INFO'
if os.environ.get('DEBUG'):
    loglevel = 'DEBUG'

try:
    from rich.logging import RichHandler
    logging.basicConfig(
        level=loglevel,
        datefmt="[%X]",
        handlers=[RichHandler()],
    )
except ImportError:
    # 'rich' is not installed: fall back to the standard library defaults.
    logging.basicConfig(level=loglevel)

log = logging.getLogger('archive')

# Argument parser, built once at module level and consumed by args().
# Defaults are taken from the DISCOURSE_URLS, DEBUG and TARGET_DIR environment
# variables when they are set.
parser = argparse.ArgumentParser(
    description='Archive topics from one or more Discourse installations and render to markdown')
parser.add_argument(
    '--urls',
    help='Comma-separated URLs of Discourse servers (for example: "https://forum.hackliberty.org,https://forum.qubes-os.org")',
    default=os.environ.get('DISCOURSE_URLS', 'https://forum.hackliberty.org'))
parser.add_argument(
    '--debug', action='store_true', default=os.environ.get('DEBUG', False))
# type=Path makes the parsed value a Path whether it comes from the command
# line or from the default, instead of str in one case and Path in the other.
parser.add_argument(
    '-t', '--target-dir', help='Target base directory for the archives',
    type=Path,
    default=Path(os.environ.get('TARGET_DIR', './archive')))

@functools.cache
def args():
    """Parse the command line with the module-level parser.

    The result is memoized by functools.cache, so later calls return the
    namespace produced by the first call.
    """
    namespace = parser.parse_args()
    return namespace

def parse_sites(urls_string: str) -> list:
    """Split a comma-separated URL string into normalized site URLs.

    Each entry is stripped of surrounding whitespace and of trailing slashes;
    entries that are empty after whitespace stripping are dropped.
    """
    sites = []
    for candidate in urls_string.split(','):
        trimmed = candidate.strip()
        if not trimmed:
            continue
        sites.append(trimmed.rstrip('/'))
    return sites

def http_get(site_url: str, path: str, timeout: float = 60.0) -> str:
    """Fetch ``site_url + path`` with HTTP GET, retrying with exponential backoff.

    Args:
        site_url: Base URL of the site, without a trailing slash.
        path: Path (and optional query string) appended to ``site_url``.
        timeout: Socket timeout in seconds passed to ``urlopen`` so a stalled
            connection fails and is retried instead of hanging forever.

    Returns:
        The decoded response body.

    Any exception during the request triggers a retry. The delay starts at
    3 seconds and doubles after each failure; once the next delay would reach
    the limit, the error is logged and the process exits with status 1.
    """
    full_url = f"{site_url}{path}"
    log.debug("HTTP GET %s", full_url)
    backoff = 3
    backoff_limit = 256
    while True:
        try:
            with urllib.request.urlopen(full_url, timeout=timeout) as response:
                return response.read().decode()
        except Exception as e:
            # Give up *before* sleeping: waiting out a final delay only to
            # exit without another attempt would be wasted time.
            if backoff * 2 >= backoff_limit:
                log.exception("Rate limit or unrecoverable error for %s", full_url)
                sys.exit(1)
            log.debug("Error fetching %s: %s -- Retrying in %d seconds", full_url, e, backoff)
            time.sleep(backoff)
            backoff *= 2

def http_get_json(site_url: str, path: str) -> dict:
    """Download ``path`` from ``site_url`` and parse the body as JSON.

    A body that is not valid JSON is reported with a warning and the
    ``json.JSONDecodeError`` is re-raised to the caller.
    """
    body = http_get(site_url, path)
    try:
        decoded = json.loads(body)
    except json.JSONDecodeError:
        log.warning("Unable to decode JSON response from %r", path)
        raise
    return decoded

# ----- Helper: Truncate Filename -----
def truncate_filename(filename: str, max_length: int = 255) -> str:
    """
    Truncate a file name to at most ``max_length`` characters (default 255),
    preserving its extension.

    Only the final extension (for example ``.json``) is preserved. Earlier
    dots are treated as part of the name, because user names and slugs may
    themselves contain dots. The length is counted in characters, not bytes.

    Returns the file name unchanged when it already fits.
    """
    if len(filename) <= max_length:
        return filename

    p = Path(filename)
    # Path.stem and Path.suffix are complementary (name == stem + suffix),
    # so cutting the stem and re-appending the suffix never duplicates text.
    suffix = p.suffix
    max_stem_length = max_length - len(suffix)
    if max_stem_length <= 0:
        # In the unlikely event that the suffix itself is longer than max_length,
        # simply return a truncated version of the entire filename.
        return filename[:max_length]
    return p.stem[:max_stem_length] + suffix

# ----- Data Models -----

@dataclass(frozen=True)
class PostTopic:
    """Immutable summary of the topic a post belongs to (built by Post.get_topic)."""
    # Topic identifier; used to build the /t/<id>.json and /raw/<id> request paths.
    id: int
    # Topic slug as reported in the post JSON.
    slug: str
    # Topic title as reported in the post JSON.
    title: str

@dataclass(frozen=True)
class Post:
    """A single Discourse post together with its raw JSON payload."""
    id: int
    slug: str
    raw: dict

    @classmethod
    def from_json(cls, j: dict) -> 'Post':
        """Build a Post from a post JSON object; the 'id' key is required."""
        post_id = j['id']
        topic_slug = j.get('topic_slug', 'unknown')
        return cls(id=post_id, slug=topic_slug, raw=j)

    def get_created_at(self) -> datetime.datetime:
        """Return the post's 'created_at' timestamp parsed as a datetime."""
        stamp = self.raw['created_at']
        # Rewrite a "Z" designator as an explicit offset before parsing.
        return datetime.datetime.fromisoformat(stamp.replace("Z", "+00:00"))

    def get_topic(self) -> PostTopic:
        """Describe the topic this post belongs to, falling back to the post's own fields."""
        payload = self.raw
        topic_id = payload.get('topic_id', self.id)
        topic_slug = payload.get('topic_slug', self.slug)
        topic_title = payload.get('topic_title', payload.get('title', 'No Title'))
        return PostTopic(id=topic_id, slug=topic_slug, title=topic_title)

    def save(self, dir: Path):
        """Write the raw post JSON under a per-month folder in ``dir`` unless the file already exists."""
        author = self.raw.get('username', 'anonymous')
        topic_slug = self.raw.get('topic_slug', 'unknown')
        padded_id = str(self.id).zfill(10)
        # Keep the file name within the length limit.
        filename = truncate_filename(f"{padded_id}-{author}-{topic_slug}.json")
        month_folder = self.get_created_at().strftime('%Y-%m-%B')
        destination = dir / month_folder / filename

        if not destination.exists():
            destination.parent.mkdir(parents=True, exist_ok=True)
            log.info("Saving post %s to %s", self.id, destination)
            destination.write_text(json.dumps(self.raw, indent=2), encoding='utf-8')
        else:
            log.debug("Post %s already saved, skipping", self.id)

@dataclass(frozen=True)
class Topic:
    """A Discourse topic: its JSON description plus the assembled Markdown body."""
    id: int
    slug: str
    raw: dict
    markdown: str

    @classmethod
    def from_json(cls, t: dict, markdown: str) -> 'Topic':
        """Build a Topic from topic JSON, using the first non-empty of 'slug' / 'topic_slug'."""
        chosen_slug = next(
            (t[key] for key in ('slug', 'topic_slug') if t.get(key)),
            "unknown",
        )
        topic_id = t.get('id', 0)
        return cls(id=topic_id, slug=chosen_slug, raw=t, markdown=markdown)

    def get_created_at(self) -> datetime.datetime:
        """Return the topic's 'created_at' timestamp parsed as a datetime."""
        stamp = self.raw['created_at']
        # Rewrite a "Z" designator as an explicit offset before parsing.
        return datetime.datetime.fromisoformat(stamp.replace("Z", "+00:00"))

    def save_rendered(self, dir: Path):
        """
        Write the topic as a Markdown file under a per-month folder in ``dir``.

        The file name combines the creation date, the slug and the topic id,
        and is shortened when it exceeds the file name length limit. The
        content is a level-one heading with the topic title followed by the
        Markdown body.
        """
        created = self.get_created_at()
        filename = truncate_filename(f"{created.date().isoformat()}-{self.slug}-id{self.id}.md")
        destination = dir / created.strftime('%Y-%m-%B') / filename
        destination.parent.mkdir(parents=True, exist_ok=True)
        log.info("Saving rendered topic %s to %s", self.id, destination)
        heading = self.raw.get('title', 'No Title')
        destination.write_text(f"# {heading}\n\n{self.markdown}", encoding='utf-8')

# ----- Helper Functions -----

def update_metadata(metadata_file: Path, metadata: dict):
    """Write the metadata as a JSON file to disk, replacing the old file atomically.

    The JSON is first written to a sibling ``<name>.tmp`` file and then moved
    over ``metadata_file`` with ``os.replace``. This function is called very
    frequently, so an interruption mid-write must not leave a truncated
    metadata file behind.
    """
    log.debug("Updating metadata: %s", metadata)
    tmp_file = metadata_file.with_name(metadata_file.name + '.tmp')
    tmp_file.write_text(json.dumps(metadata, indent=2), encoding='utf-8')
    os.replace(tmp_file, metadata_file)

def render_topic(site_url: str, topic: PostTopic, topics_dir: Path):
    """
    Download one topic and write it out as Markdown.

    The topic's JSON description and its raw Markdown are fetched; further
    raw pages (page 2, 3, ...) are appended until an empty page is returned.
    Failures are logged and the topic is skipped.
    """
    topic_id = topic.id

    try:
        log.info("Fetching topic %s JSON from %s", topic_id, site_url)
        topic_data = http_get_json(site_url, f"/t/{topic_id}.json")
    except Exception as e:
        log.warning("Failed to fetch topic JSON for topic %s: %s", topic_id, e)
        return

    log.info("Fetching raw markdown for topic %s from %s", topic_id, site_url)
    first_page = http_get(site_url, f"/raw/{topic_id}")
    if not first_page:
        log.warning("Could not retrieve markdown body for topic %s", topic_id)
        return

    # Collect the remaining pages; the first empty response ends the topic.
    pages = [first_page]
    page_num = 2
    while (extra_page := http_get(site_url, f"/raw/{topic_id}?page={page_num}")):
        pages.append(extra_page)
        page_num += 1
    body = "\n".join(pages)

    try:
        topic_obj = Topic.from_json(topic_data, body)
    except Exception as e:
        log.error("Failed to create Topic object for topic %s: %s", topic_id, e)
        return

    topic_obj.save_rendered(topics_dir)
    log.info("Saved rendered topic %s (%s)", topic_obj.id, topic_obj.slug)

def render_topics_concurrently(site_url: str, topics: dict, topics_dir: Path, max_workers: int = 8):
    """
    Render every topic in ``topics`` on a thread pool.

    topics: a dictionary of topic_id -> PostTopic.
    An exception raised while rendering one topic is logged and does not
    stop the remaining topics.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for entry in topics.values():
            pending.append(pool.submit(render_topic, site_url, entry, topics_dir))

        for finished in concurrent.futures.as_completed(pending):
            try:
                finished.result()
            except Exception as exc:
                log.error("A topic generated an exception: %s", exc)

def _load_sync_state(metadata_file: Path, site_url: str):
    """Read a site's metadata file and return (metadata, last_sync_date, archived_post_ids)."""
    metadata = {}
    last_sync_date = None
    archived_post_ids = set()
    if not metadata_file.exists():
        return metadata, last_sync_date, archived_post_ids
    try:
        # Read with the same encoding update_metadata() writes with.
        loaded = json.loads(metadata_file.read_text(encoding='utf-8'))
        if not isinstance(loaded, dict):
            raise ValueError("metadata root is not a JSON object")
        metadata = loaded
        if metadata.get('last_sync_date'):
            last_sync_date = datetime.datetime.fromisoformat(metadata['last_sync_date'])
        archived_post_ids = {int(x) for x in metadata.get('archived_post_ids') or []}
    except Exception as e:
        log.error("Failed to read/parse metadata file for %s: %s", site_url, e)
    return metadata, last_sync_date, archived_post_ids


def process_site(site_url: str, base_target_dir: Path):
    """
    Archive posts and render topics for a single site.

    Each site gets its own subdirectory (named for its hostname) inside the
    base target directory, with 'posts' and 'rendered-topics' folders and its
    own '.metadata.json' file. The metadata records the newest post timestamp
    seen and the ids of all archived posts, and is rewritten after every saved
    post.

    Posts are paged from newest to oldest through '/posts.json?before=<id>'.
    On a resync the walk stops at the first post older than the stored sync
    date minus one day.
    """
    parsed = urlparse(site_url)
    site_name = parsed.hostname or site_url.replace("https://", "").replace("http://", "").split('/')[0]
    log.info("Processing site: %s", site_url)
    site_target_dir = base_target_dir / site_name
    posts_dir = site_target_dir / 'posts'
    topics_dir = site_target_dir / 'rendered-topics'
    posts_dir.mkdir(parents=True, exist_ok=True)
    topics_dir.mkdir(parents=True, exist_ok=True)
    metadata_file = site_target_dir / '.metadata.json'

    # Load stored metadata if it exists.
    metadata, last_sync_date, archived_post_ids = _load_sync_state(metadata_file, site_url)

    if last_sync_date:
        # Step back one day to catch updates.
        last_sync_date -= datetime.timedelta(days=1)
        log.info("Resyncing posts from %s for %s", last_sync_date.isoformat(), site_url)

    posts = http_get_json(site_url, '/posts.json').get('latest_posts', [])
    last_id = None  # pagination cursor: lowest post id seen so far
    max_created_at = last_sync_date
    should_stop = False

    while posts:
        log.info("Processing %d posts for %s", len(posts), site_url)
        topics_to_render = {}  # unique topics in this batch
        for json_post in posts:
            try:
                post = Post.from_json(json_post)
            except Exception as e:
                log.warning("Failed to deserialize post %s: %s", json_post, e)
                continue

            # Advance the cursor for EVERY post seen, including ones that are
            # skipped below. If it only moved for newly saved posts, a page
            # consisting solely of already-archived posts would leave it
            # unchanged and the same page would be requested forever.
            last_id = post.id if last_id is None else min(last_id, post.id)

            # Apply the date cut-off before the "already archived" check so a
            # resync stops once it reaches old posts instead of paging through
            # the whole, already archived, history.
            post_created = post.get_created_at()
            if last_sync_date is not None and post_created < last_sync_date:
                log.info("Post %s is older than last_sync_date; stopping batch for %s.", post.id, site_url)
                should_stop = True
                break

            if post.id in archived_post_ids:
                log.debug("Post %s already archived, skipping", post.id)
                continue

            post.save(posts_dir)
            archived_post_ids.add(post.id)

            topic = post.get_topic()
            topics_to_render[topic.id] = topic

            if max_created_at is None or post_created > max_created_at:
                max_created_at = post_created

            metadata['last_sync_date'] = max_created_at.isoformat() if max_created_at else None
            metadata['archived_post_ids'] = sorted(archived_post_ids)
            update_metadata(metadata_file, metadata)

        # Render topics concurrently for the current batch.
        if topics_to_render:
            log.info("Rendering %d topics concurrently for %s.", len(topics_to_render), site_url)
            render_topics_concurrently(site_url, topics_to_render, topics_dir, max_workers=8)

        if should_stop:
            log.info("Stopping pagination loop based on sync date for %s.", site_url)
            break

        if last_id is None or last_id <= 1:
            log.info("No valid last_id found for %s. Ending pagination loop.", site_url)
            break

        time.sleep(5)
        posts = http_get_json(site_url, f'/posts.json?before={last_id - 1}').get('latest_posts', [])
        # Fallback if posts come empty (step back gradually), without ever
        # asking the server for a non-positive post id.
        while not posts and last_id > 0:
            last_id -= 49
            if last_id <= 0:
                break
            posts = http_get_json(site_url, f'/posts.json?before={last_id}').get('latest_posts', [])
            time.sleep(1)

def main() -> None:
    """Script entry point: parse options, prepare the target directory and archive each site.

    Exits with status 1 when no usable site URL was supplied.
    """
    # Parse command-line parameters.
    parameters = args()

    # Honour --debug: the log level is otherwise fixed at import time from the
    # DEBUG environment variable, which would leave the flag without effect.
    if parameters.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Path() accepts both str and Path, so the value is normalized either way.
    base_target_dir = Path(parameters.target_dir)
    base_target_dir.mkdir(parents=True, exist_ok=True)

    sites = parse_sites(parameters.urls)
    if not sites:
        log.error("No valid sites provided. Exiting.")
        sys.exit(1)

    # Process each site.
    for site_url in sites:
        process_site(site_url, base_target_dir)

# Run the archiver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()