#!/usr/bin/env python3
"""
load_referenzstories.py

Usage:
  python load_referenzstories.py \
    --tenant 1 \
    --einordnung ./einordnung_bullets.txt \
    --stories ./referenzstories_tagged.txt
"""

import argparse
import os
import re
import html
import psycopg
import json
import zipfile
from psycopg.rows import dict_row
from docx import Document as DocxDocument
from dotenv import load_dotenv, find_dotenv

# load .env from project root (find_dotenv walks up directories)
load_dotenv(find_dotenv(), override=False)

# Optional sanitizer (recommended): pip install bleach
try:
    import bleach
except ImportError:
    bleach = None

# Optional: OpenAI embeddings
from openai import OpenAI

# ---- OpenAI key resolution (same idea as app.py) ----
FALLBACK_OPENAI_KEY = os.getenv("FALLBACK_OPENAI_KEY", "")

def _resolve_api_key(cli_key: str | None = None) -> str:
    k = (cli_key or os.getenv("OPENAI_API_KEY") or FALLBACK_OPENAI_KEY or "").strip()
    if not k:
        raise RuntimeError("No OpenAI API key provided. "
                           "Set --openai-key or OPENAI_API_KEY or FALLBACK_OPENAI_KEY.")
    return k

_openai = None  # shared OpenAI client, created lazily by init_openai_client()

def init_openai_client(cli_key: str | None = None):
    """Create the module-level OpenAI client once; later calls are no-ops.

    The key is resolved via _resolve_api_key (CLI value > OPENAI_API_KEY >
    FALLBACK_OPENAI_KEY) and a missing key raises RuntimeError there.
    """
    global _openai
    if _openai is None:
        # NOTE: OpenAI is already imported at module level (top of file);
        # the previous function-local re-import was redundant and removed.
        _openai = OpenAI(api_key=_resolve_api_key(cli_key))


# Embedding model name; overridable via env (default matches a 1536-dim model).
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")

# ---------- Tag patterns (canonical) ----------
# Used by the TXT parsers; the DOCX parsers match tags with startswith/endswith.
REF_RE      = re.compile(r"^<<<REF:\s*(\d{1,4})\s*>>>$", re.I)
SECTION_RE  = re.compile(r"^<<<SECTION:\s*([A-Za-zÄÖÜäöüß0-9 /_.-]+)\s*>>>$", re.I)
END_REF_RE  = re.compile(r"^<<<END-REF>>>$", re.I)
END_ALIAS_RE= re.compile(r"^<<<END>>>$", re.I)  # accepted alias (warn)

# A bullet line: "• text" or "- text" (group 2 is the bullet body).
BULLET_RE   = re.compile(r"^\s*(•|-)\s+(.*)$")

# ---------- Minimal HTML conversion ----------
# Whitelist fed to bleach.clean() when bleach is installed.
ALLOWED_TAGS = [
    "p", "br", "ul", "ol", "li", "blockquote", "strong", "b", "em", "i", "u"
]
ALLOWED_ATTRS = {}  # no attributes allowed on any tag

def _runs_to_markdown(p):
    """Join runs, preserving **bold**; ignore italics/underline to keep it simple."""
    parts = []
    for r in p.runs:
        t = r.text or ""
        if not t:
            continue
        parts.append(f"**{t}**" if r.bold else t)
    return "".join(parts).strip()

def _para_is_bullet(p):
    """Detect if a DOCX paragraph is a list item."""
    try:
        return p._p.pPr is not None and p._p.pPr.numPr is not None
    except Exception:
        return False

def parse_tagged_stories_docx(path: str):
    """Parse a tagged DOCX into [{'story_id': int, 'sections': {name: [lines...]}}].

    Recognized tag paragraphs (matched on the paragraph's raw text):
      <<<REF: N>>>              starts story N (closes any open section/story)
      <<<SECTION: X>>>          starts section X within the current story
      <<<END-REF>>> / <<<END>>> closes the current story
        NOTE(review): here <<<END>>> ends the story, while the TXT parser
        (parse_tagged_stories) treats it as a no-op alias — confirm which
        behavior is intended.

    Content paragraphs keep bold runs as **bold**; list paragraphs are
    prefixed with "• ". Content outside any section is ignored.
    """
    doc = DocxDocument(path)
    stories = []
    cur_ref = None        # story id currently being collected (None = outside a story)
    sections = {}         # section name -> list of rendered lines for current story
    cur_section = None    # section currently being collected (None = outside a section)
    section_lines = []    # lines accumulated for cur_section

    def commit_section():
        # Store a COPY of the collected lines under the current section name.
        nonlocal cur_section, section_lines, sections
        if cur_section is not None:
            sections[cur_section] = list(section_lines)
        cur_section, section_lines = None, []

    def commit_story():
        # Append the finished story and reset per-story state (fresh dict).
        nonlocal cur_ref, sections
        if cur_ref is not None:
            stories.append({"story_id": cur_ref, "sections": sections})
        cur_ref, sections = None, {}

    for p in doc.paragraphs:
        raw = (p.text or "").strip()

        # TAG detection must use raw text (not runs-with-bold)
        if raw.upper().startswith("<<<REF:") and raw.endswith(">>>"):
            # new story starts; close previous
            if cur_section is not None:
                commit_section()
            if cur_ref is not None:
                commit_story()
            # parse id
            try:
                cur_ref = int(raw.split("<<<REF:",1)[1].split(">>>",1)[0].strip())
            except Exception:
                raise ValueError(f"Bad REF tag: {raw}")
            continue

        if raw.upper() == "<<<END-REF>>>" or raw.upper() == "<<<END>>>":
            if cur_section is not None:
                commit_section()
            commit_story()
            continue

        if raw.upper().startswith("<<<SECTION:") and raw.endswith(">>>"):
            if cur_section is not None:
                commit_section()
            name = raw.split("<<<SECTION:",1)[1].split(">>>",1)[0].strip()
            cur_section = name
            continue

        # Content paragraphs
        if cur_section is None:
            # ignore content outside sections
            continue

        # Preserve bold in content using **…**
        line = _runs_to_markdown(p)
        if _para_is_bullet(p) and line:
            line = f"• {line}"
        section_lines.append(line)

    # EOF flush: close a dangling section/story at end of document
    if cur_section is not None:
        commit_section()
    if cur_ref is not None:
        commit_story()

    return stories

def parse_einordnung_docx(path: str):
    """
    Parse a DOCX file that contains multiple blocks:
      <<<REF: N>>>
      (bulleted paragraphs)
      <<<END-REF>>> or <<<END>>>
    Returns [{"story_id": N, "bullets": ["...", "..."]}, ...]

    Bullets are detected either via real DOCX list formatting or via a
    leading bullet glyph; when a block yields no bullets at all, every
    non-empty paragraph in it is used as a bullet instead (fallback).
    """
    # NOTE(review): DocxDocument is already imported at module level;
    # this local re-import looks redundant — confirm before removing.
    from docx import Document as DocxDocument
    doc = DocxDocument(path)

    out, cur_ref = [], None
    bullets: list[str] = []
    raw_lines_accum: list[str] = []  # fallback if bullets not detected

    # Accept lots of bullet glyphs + tabs/NBSP after them
    NBSP = "\u00A0"
    BULLET_RX = re.compile(r"^\s*([\u2022\u25CF\u25E6\u2023\u2043\u2219\-\–\*•·])[\s\u00A0\t]+")

    def commit():
        # Flush the current block (if any) into `out`, then reset state.
        nonlocal cur_ref, bullets, raw_lines_accum
        if cur_ref is not None:
            final_bullets = [b for b in bullets if b.strip()]
            if not final_bullets:
                # Fallback: treat each non-empty paragraph as a bullet
                final_bullets = [ln for ln in raw_lines_accum if ln.strip()]
            out.append({"story_id": cur_ref, "bullets": final_bullets})
        cur_ref, bullets, raw_lines_accum = None, [], []

    for p in doc.paragraphs:
        raw = (p.text or "").strip()

        # Tag lines use the raw text (not runs)
        if raw.upper().startswith("<<<REF:") and raw.endswith(">>>"):
            commit()
            try:
                cur_ref = int(raw.split("<<<REF:", 1)[1].split(">>>", 1)[0].strip())
            except Exception:
                raise ValueError(f"Bad REF tag: {raw}")
            continue

        if raw.upper() in ("<<<END-REF>>>", "<<<END>>>"):
            commit()
            continue

        if cur_ref is None:
            continue  # ignore stuff outside blocks

        # Build a line preserving **bold** from runs
        txt = _runs_to_markdown(p).strip()
        if not txt and raw:
            txt = raw  # just in case the runs yielded nothing

        # Remember all lines as a fallback
        if txt:
            raw_lines_accum.append(txt)

        # bullet if: actual list paragraph OR manual bullet glyph at start
        raw_no_nbsp = raw.lstrip(NBSP)
        is_bullet = _para_is_bullet(p) or BULLET_RX.match(raw_no_nbsp) or BULLET_RX.match(txt)
        if is_bullet:
            # Strip the leading glyph (regex is ^-anchored, so only the prefix goes)
            cleaned = BULLET_RX.sub("", txt).strip()
            if cleaned:
                bullets.append(cleaned)

    # EOF flush: close a dangling block at end of document
    commit()
    return out



def to_html_block(section_name: str, raw_lines: list[str]) -> str:
    """
    Render a section's raw text lines as minimal, clean HTML.

    Rules:
      - blank lines separate paragraphs (<p>),
      - bullet lines (• or -) are grouped into <ul><li>…</li></ul>,
      - a section named 'Zitat' is wrapped in <blockquote>,
      - **bold** spans become <strong> (HTML-escaped first),
      - the result is sanitized with bleach when that package is installed.
    """
    stripped = [ln.rstrip("\r\n") for ln in raw_lines]

    pieces: list[str] = []      # finished HTML fragments, in order
    pending_items: list[str] = []   # bullet items not yet emitted as a <ul>
    pending_text: list[str] = []    # plain lines not yet emitted as a <p>

    def emit_list():
        nonlocal pending_items
        if pending_items:
            body = "".join(f"<li>{_escape_inline(item)}</li>" for item in pending_items)
            pieces.append(f"<ul>{body}</ul>")
            pending_items = []

    def emit_para():
        nonlocal pending_text
        joined = " ".join(pending_text).strip()
        if joined:
            pieces.append(f"<p>{_escape_inline(joined)}</p>")
        pending_text = []

    for ln in stripped:
        if not ln.strip():
            # paragraph break: close both an open list and an open paragraph
            emit_list()
            emit_para()
        elif (m := BULLET_RE.match(ln)):
            # bullet: close an open paragraph, collect the item
            emit_para()
            pending_items.append(m.group(2).strip())
        else:
            pending_text.append(ln)

    # tail flush
    emit_list()
    emit_para()

    result = "".join(pieces)

    # Wrap quotes in <blockquote>
    if section_name.strip().lower() == "zitat":
        result = f"<blockquote>{result}</blockquote>"

    # Sanitize (only if bleach is available)
    if bleach:
        result = bleach.clean(
            result,
            tags=ALLOWED_TAGS,
            attributes=ALLOWED_ATTRS,
            strip=True,
        )
    return result

def _escape_inline(text: str) -> str:
    """
    Very light inline handling:
    - escape HTML
    - naive **bold** -> <strong> … </strong>
    """
    text = html.escape(text)
    # restore minimal markdown **bold** after escaping
    text = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", text)
    return text

def strip_to_plain_text(raw_lines: list[str]) -> str:
    """Join lines into newline-separated plain text (embedding input).

    Trailing CR/LF is removed per line; the whole result is stripped.
    """
    cleaned = [ln.rstrip("\r\n") for ln in raw_lines]
    return "\n".join(cleaned).strip()

# ---------- Parsing ----------
def parse_tagged_stories(path: str) -> list[dict]:
    """
    Parse a plain-text file with multiple stories tagged by <<<REF: N>>> and
    <<<SECTION: NAME>>>.

    Returns a list of dicts:
      {
        "story_id": int,
        "sections": { "Einordnung": [lines...], "Beitrag": [lines...], ... }
      }

    Notes:
      - <<<END>>> is accepted as a no-op alias; a story only closes on
        <<<END-REF>>> or the next <<<REF: N>>>.
      - Content outside any section (or outside any story) is ignored.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    stories = []
    cur_ref = None
    cur_section = None
    section_lines: list[str] = []
    sections: dict[str, list[str]] = {}

    def commit_section():
        nonlocal cur_section, section_lines, sections
        if cur_section is not None:
            # store a COPY of lines to avoid later mutation
            sections[cur_section] = list(section_lines)
        cur_section = None
        section_lines = []

    def commit_story():
        nonlocal cur_ref, sections
        if cur_ref is not None:
            stories.append({"story_id": cur_ref, "sections": sections})
        cur_ref = None
        # BUGFIX: start a fresh dict here. Previously `sections` kept pointing
        # at the committed story's dict, so a stray <<<SECTION:>>> between
        # <<<END-REF>>> and the next <<<REF:>>> mutated the already-committed
        # story. This also matches parse_tagged_stories_docx's behavior.
        sections = {}

    for raw in lines:
        line = raw.rstrip("\n")

        if END_ALIAS_RE.match(line):
            # accept alias, but no-op here; story ends only when END-REF or new REF
            continue

        if m := REF_RE.match(line):
            # starting a new story closes previous story
            if cur_section is not None:
                commit_section()
            if cur_ref is not None:
                commit_story()
            cur_ref = int(m.group(1))
            sections = {}
            continue

        if END_REF_RE.match(line):
            if cur_section is not None:
                commit_section()
            commit_story()
            continue

        if m := SECTION_RE.match(line):
            # starting a new section closes previous section
            if cur_section is not None:
                commit_section()
            cur_section = m.group(1).strip()
            section_lines = []
            continue

        # content: only collected inside a section of an open story
        if cur_section is not None and cur_ref is not None:
            section_lines.append(line)

    # EOF flush
    if cur_section is not None:
        commit_section()
    if cur_ref is not None:
        commit_story()

    return stories

def parse_einordnung_file(path: str) -> list[dict]:
    """
    Parse the Einordnung-only plain-text file with blocks of the form:
      <<<REF: N>>>
      • bullet
      • bullet
      <<<END>>>
    Returns a list of {"story_id": int, "bullets": [str, ...]}.
    Non-bullet content lines inside a block are ignored.
    """
    with open(path, "r", encoding="utf-8") as fh:
        file_lines = fh.readlines()

    results: list[dict] = []
    active_ref: int | None = None
    items: list[str] = []

    def flush():
        # Emit the current block (if one is open) and reset state.
        nonlocal active_ref, items
        if active_ref is not None:
            results.append({
                "story_id": active_ref,
                "bullets": [b for b in items if b.strip()],
            })
        active_ref, items = None, []

    for raw in file_lines:
        line = raw.rstrip("\n")
        if ref := REF_RE.match(line):
            flush()
            active_ref = int(ref.group(1))
            items = []
        elif END_REF_RE.match(line) or END_ALIAS_RE.match(line):
            flush()
        elif bullet := BULLET_RE.match(line):
            items.append(bullet.group(2).strip())
        # anything else is ignored in this file format

    flush()
    return results

def _is_docx(path: str) -> bool:
    # also detects .DOCX and cases where the extension is missing but file is a ZIP (docx)
    try:
        return path.lower().endswith(".docx") or (os.path.isfile(path) and zipfile.is_zipfile(path))
    except Exception:
        return path.lower().endswith(".docx")

def load_einordnung(path: str):
    """Dispatch to the DOCX or TXT Einordnung parser based on the file type."""
    if not _is_docx(path):
        print(f"[loader] Einordnung: using TXT parser → {path}")
        return parse_einordnung_file(path)
    print(f"[loader] Einordnung: using DOCX parser → {path}")
    return parse_einordnung_docx(path)

def load_stories(path: str):
    """Dispatch to the DOCX or TXT stories parser based on the file type."""
    if not _is_docx(path):
        print(f"[loader] Stories: using TXT parser → {path}")
        return parse_tagged_stories(path)
    print(f"[loader] Stories: using DOCX parser → {path}")
    return parse_tagged_stories_docx(path)

# ---------- DB + embeddings ----------
def get_db_conn():
    """
    Open a psycopg connection configured from DB_* environment variables
    (DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD). Rows come back as
    dicts via psycopg.rows.dict_row.

    Raises RuntimeError when DB_PASSWORD is unset.
    """
    password = os.getenv("DB_PASSWORD")
    if not password:
        # SECURITY FIX: the password used to be hard-coded in source as a
        # default here. Credentials must come from the environment / .env.
        raise RuntimeError("DB_PASSWORD is not set (check your environment or .env).")
    return psycopg.connect(
        host=os.getenv("DB_HOST", "localhost"),
        port=os.getenv("DB_PORT", "5432"),
        dbname=os.getenv("DB_NAME", "dynbot_db"),
        user=os.getenv("DB_USER", "chatbot_user"),
        password=password,
        row_factory=dict_row,
    )

# Dimension of the zero vector returned for empty input. Must match the
# embedding size of EMBED_MODEL (1536 for text-embedding-3-small; set
# EMBED_DIM=3072 when using text-embedding-3-large).
EMBED_DIM = int(os.getenv("EMBED_DIM", "1536"))

# NOTE: the duplicate `_openai = None` that used to sit here was removed;
# the lazy client singleton is declared once next to init_openai_client().
def embed(text: str) -> list[float]:
    """Return the embedding vector for *text* via the OpenAI embeddings API.

    Empty or whitespace-only input short-circuits to a zero vector of
    EMBED_DIM floats — no API call and no API key required. The client is
    initialized lazily on first real use.
    """
    global _openai
    t = (text or "").strip()
    if not t:
        # BUGFIX: check for empty input BEFORE touching the client, so that
        # embedding empty text never forces API-key resolution.
        return [0.0] * EMBED_DIM
    if _openai is None:
        init_openai_client()  # fallback if someone calls before main()
    resp = _openai.embeddings.create(model=EMBED_MODEL, input=t)
    return resp.data[0].embedding

# Upsert of one logical document row; the natural key is
# (tenant_id, title, section, story_id, bullet_index) — bullet_index is -1
# for whole-section rows and 0..n-1 for individual Einordnung bullets.
# embedding_json stores the vector as a JSON-encoded array (see embed()).
UPSERT_SQL = """
INSERT INTO documents
  (tenant_id, title, content, section, story_id, bullet_index, content_text, content_html, embedding_json)
VALUES
  (%(tenant_id)s, %(title)s, %(content)s, %(section)s, %(story_id)s, %(bullet_index)s, %(content_text)s, %(content_html)s, %(embedding_json)s)
ON CONFLICT (tenant_id, title, section, story_id, bullet_index)
DO UPDATE SET
  content        = EXCLUDED.content,
  content_text   = EXCLUDED.content_text,
  content_html   = EXCLUDED.content_html,
  embedding_json = EXCLUDED.embedding_json;
"""

def upsert_row(cur, row):
    """Execute UPSERT_SQL with *row* as the named-parameter dict on cursor *cur*."""
    cur.execute(UPSERT_SQL, row)

# ---------- Main ingest ----------
def main():
    """CLI entry point: parse the Einordnung and stories files, then upsert
    per-tenant rows into the documents table.

    For every Einordnung block we write:
      * one row for the full Einordnung section (bullet_index = -1),
      * one row per Einordnung bullet           (bullet_index = 0..n-1),
      * one row per remaining story section     (bullet_index = -1).
    Each row carries plain text, minimal HTML, and a JSON-encoded embedding.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--tenant", type=int, required=True)
    ap.add_argument("--einordnung", type=str, required=True,
                    help="einordnung_bullets.txt or .docx")
    ap.add_argument("--stories", type=str, required=True,
                    help="referenzstories_tagged.txt or .docx")
    args = ap.parse_args()

    # init OpenAI client once (uses env vars)
    init_openai_client()

    ein_blocks = load_einordnung(args.einordnung)
    stories    = load_stories(args.stories)
    # FIX: this summary used to be printed twice.
    print(f"[loader] Parsed Einordnung blocks: {len(ein_blocks)}")
    print(f"[loader] Parsed stories: {len(stories)}")
    for b in ein_blocks:
        print(f"[loader] story {b['story_id']:02d}: {len(b['bullets'])} Einordnung bullets")

    # Map for quick story_id -> sections lookups
    story_map = {s['story_id']: s['sections'] for s in stories}

    # Ingest is driven by the Einordnung blocks; stories without a matching
    # block would be skipped silently, so surface them explicitly.
    skipped = sorted(set(story_map) - {b["story_id"] for b in ein_blocks})
    if skipped:
        print(f"[loader] WARNING: stories without Einordnung block (not ingested): {skipped}")

    with get_db_conn() as conn, conn.cursor() as cur:
        # Sanity logging: confirm which database/server we actually reached.
        cur.execute("select current_database(), current_user, inet_server_addr(), inet_server_port()")
        row = cur.fetchone()  # dict
        print(f"[loader] Connected to DB={row['current_database']} as {row['current_user']} on {row['inet_server_addr']}:{row['inet_server_port']}")

        # Sanity check: the documents table must carry the extended columns.
        cur.execute("""
        select count(*) as cnt
        from information_schema.columns
        where table_name='documents'
            and column_name in ('section','story_id','bullet_index','content_text','content_html','embedding_json')
        """)
        colrow = cur.fetchone()  # dict
        print(f"[loader] documents has new cols (count of matches): {colrow['cnt']}")

        for block in ein_blocks:
            sid = block["story_id"]
            title = f"Referenzstory_{sid:02d}"

            # --- Full Einordnung (prefer stories file; fall back to bullets) ---
            ein_lines = (story_map.get(sid, {}).get("Einordnung")
                         or [f"• {b}" for b in block["bullets"]])

            ein_plain = strip_to_plain_text(ein_lines)
            ein_html  = to_html_block("Einordnung", ein_lines)
            ein_embed = embed(ein_plain)
            upsert_row(cur, {
                "tenant_id": args.tenant,
                "title": title,
                "content": ein_plain,
                "section": "Einordnung",
                "story_id": sid,
                "bullet_index": -1,                 # -1 marks the whole section
                "content_text": ein_plain,
                "content_html": ein_html,
                "embedding_json": json.dumps(ein_embed),
            })

            # --- One row per Einordnung bullet ---
            for idx, bullet in enumerate(block["bullets"]):
                bullet_plain = bullet.strip()
                if not bullet_plain:
                    continue
                bullet_html  = to_html_block("Einordnung", [f"• {bullet_plain}"])
                bullet_embed = embed(bullet_plain)

                upsert_row(cur, {
                    "tenant_id": args.tenant,
                    "title": title,
                    "content": bullet_plain,
                    "section": "Einordnung",
                    "story_id": sid,
                    "bullet_index": idx,            # bullet position is part of the key
                    "content_text": bullet_plain,
                    "content_html": bullet_html,
                    "embedding_json": json.dumps(bullet_embed),
                })

            # --- All remaining sections of the story ---
            if sid in story_map:
                for sec_name, lines in story_map[sid].items():
                    if sec_name == "Einordnung":
                        continue
                    sec_plain = strip_to_plain_text(lines)
                    sec_html  = to_html_block(sec_name, lines)
                    sec_embed = embed(sec_plain)
                    upsert_row(cur, {
                        "tenant_id": args.tenant,
                        "title": title,
                        "content": sec_plain,
                        "section": sec_name,
                        "story_id": sid,
                        "bullet_index": -1,
                        "content_text": sec_plain,
                        "content_html": sec_html,
                        "embedding_json": json.dumps(sec_embed),
                    })

        # Post-ingest sanity counts
        cur.execute("SELECT COUNT(*) AS cnt FROM documents WHERE tenant_id=%s AND story_id IS NOT NULL", (args.tenant,))
        print("[loader] rows for tenant", args.tenant, ":", cur.fetchone()['cnt'])

        cur.execute("SELECT COUNT(*) AS cnt FROM documents WHERE tenant_id=%s AND section ILIKE 'Einordnung' AND bullet_index >= 0", (args.tenant,))
        print("[loader] Einordnung bullet rows:", cur.fetchone()['cnt'])

        conn.commit()
    print("✅ Ingest complete.")


if __name__ == "__main__":
    main()
