# app.py — DynBot (tenant-scoped)

import os
import re
import uuid
import logging
import json
import random

from time import perf_counter
from contextlib import contextmanager

from flask import Flask, g, request, jsonify, render_template, abort
from sqlalchemy.sql import func, text
from openai import OpenAI

from db import SessionLocal, init_db
from rag_service import generate_rag_addendum
from utils.prompt_manager import build_messages_payload
from concurrent.futures import ThreadPoolExecutor
from models import (
    Tenant,
    Session as ChatSession,
    User,
    Case,
    Message,
    Document,
    Topic,
)

# ------------------------------------------------------------------------------
# App + DB init
# ------------------------------------------------------------------------------
# Flask application object; every route and request hook below registers on it.
app = Flask(__name__)
# Called once at import time, so it also runs under any WSGI server that imports
# this module. Presumably creates/initializes the DB schema (see db.init_db) —
# NOTE(review): confirm what init_db does.
init_db()

# ------------------------------------------------------------------------------
# Helpers for summary-turn detection and RAG payloads
# ------------------------------------------------------------------------------
def _is_summary_prompt(user_text: str) -> bool:
    t = (user_text or "").strip().lower()
    return t.startswith("jetzt fassen wir die konversation") or "zusammenfassung" in t

def _is_summary_turn(db, case_id: int, user_text: str) -> bool:
    """Decide whether the current turn should produce the session summary.

    Signals, checked in order:

    1. ``_is_summary_prompt(user_text)`` — the UI-driven prompt heuristic.
    2. Fallback: the case (scoped to ``g.tenant_id``) already has at least 3
       persisted ``role == "user"`` messages. Because the test is ``>= 3``,
       this is True for the 3rd user turn *and every later one*. Whether the
       current turn is included in the count depends on whether the caller
       stored it before calling.

    The count is best-effort. If the query fails, the error is logged and the
    session is rolled back, so the caller's later queries do not run inside a
    failed transaction. The function then returns False.
    """
    if _is_summary_prompt(user_text):
        return True
    try:
        user_turns = (
            db.query(Message)
              .join(Case, Message.case_id == Case.id)
              .filter(Case.id == case_id, Case.tenant_id == g.tenant_id, Message.role == "user")
              .count()
        )
    except Exception:
        app.logger.exception("[summary] counting user turns for case %s failed", case_id)
        db.rollback()
        return False
    return user_turns >= 3

def _build_full_cycle_summary(db, case_id: int) -> str:
    # your existing “full history → OpenAI summary” logic (kept inline in route below)
    ...

def _build_rag_payload(rag_seed: str, tenant_id: int):
    """Fetch the RAG addendum for *rag_seed* within *tenant_id* (top 3 hits).

    The result of ``generate_rag_addendum`` is returned unchanged. That is the
    same shape the ``mode=rag`` endpoint works with, so the frontend can pass
    it straight to ``renderRagInnerHTML(...)``. Example keys:
    ``{ reply, referenzstory_html, referenzstory_sections, ... }``.
    """
    retrieval_kwargs = {"tenant_id": tenant_id, "top_k": 3}
    return generate_rag_addendum(rag_seed, **retrieval_kwargs)

# ------------------------------------------------------------------------------
# Other Helpers
# ------------------------------------------------------------------------------
def current_tenant_id() -> int:
    """Return this process's tenant id from the ``TENANT_ID`` environment variable.

    This is the single source of truth for tenant binding, e.g. set via
    systemd ``Environment=TENANT_ID=4``. Falls back to ``4`` when the variable
    is unset. Raises ValueError if the value is not an integer.
    """
    raw_value = os.environ.get("TENANT_ID", "4")
    return int(raw_value)

@app.before_request
def bind_tenant_and_session():
    """Attach the tenant id and a fresh DB session to ``flask.g`` for this request.

    The session is closed again by ``teardown_session``.
    """
    tenant = current_tenant_id()
    g.tenant_id = tenant
    g.db = SessionLocal()
    app.logger.info("[TENANT] %s %s %s", tenant, request.method, request.path)

@app.teardown_request
def teardown_session(exc):
    """Release the request's DB session, rolling back first if the request raised.

    *exc* is the unhandled exception (or None) that Flask passes to teardown hooks.
    The session is always closed, even if the rollback itself fails.
    """
    session = getattr(g, "db", None)
    if session is None:
        return
    try:
        if exc:
            session.rollback()
    finally:
        session.close()

def ensure_case(case_id: int) -> Case:
    """Return case *case_id* if it belongs to the current tenant; otherwise abort with 404."""
    found = (
        g.db.query(Case)
        .filter_by(id=case_id, tenant_id=g.tenant_id)
        .first()
    )
    if not found:
        abort(404)
    return found

def get_openai_client() -> OpenAI:
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    The key is only ever read from the environment and is never hardcoded.
    A new client is created on every call. The last four characters of the
    key are logged at INFO level on the root logger.

    Raises:
        RuntimeError: if ``OPENAI_API_KEY`` is unset or blank.
    """
    key = os.getenv("OPENAI_API_KEY") or ""
    key = key.strip()
    if key == "":
        raise RuntimeError("OPENAI_API_KEY not set")
    root_logger = logging.getLogger()
    root_logger.info("🔑 OPENAI_API_KEY ends with …%s", key[-4:])
    return OpenAI(api_key=key)

@contextmanager
def timed(label, bag):
    """Time the ``with`` body and store the duration in ``bag[label]``.

    The value is whole milliseconds (truncated toward zero). It is written
    even when the body raises.
    """
    start = perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (perf_counter() - start) * 1000
        bag[label] = int(elapsed_ms)

def truncate_numbered_blocks(text: str, max_items: int = 3) -> str:
    """Keep only the first *max_items* numbered list items ("1.", "2.", ...).

    This avoids overlong replies. Text before the first item is kept as a
    header. Each kept item is renumbered 1..N, and items are separated by a
    blank line. Text after the last numbered item counts as part of that item,
    so it is dropped when that item is cut.

    An item starts on a line made of optional whitespace, digits, a period
    and whitespace (e.g. "2. Foo"). A line that merely starts with a decimal
    number (e.g. "3.5% p.a.") is *not* an item boundary and stays inside the
    current item. Previously such lines ended the item and their text was lost.

    Empty/None text, or text without any numbered item, is returned unchanged.
    """
    if not text:
        return text
    m = re.search(r"(?m)^\s*\d+\.\s+", text)
    if not m:
        return text
    head = text[:m.start()].rstrip("\n")
    body = text[m.start():]
    # The boundary lookahead requires whitespace after the period, matching
    # the item-start pattern, so decimals like "3.5" do not split items.
    blocks = re.findall(r"(?ms)^\s*(\d+)\.\s+(.*?)(?=^\s*\d+\.\s|\Z)", body)
    if not blocks:
        return text
    items = [
        f"{idx}. {content.strip()}"
        for idx, (_num, content) in enumerate(blocks[:max_items], start=1)
    ]
    return (head + "\n\n" if head else "") + "\n\n".join(items)

def _server_timing_header(timings: dict, total_ms: int) -> str:
    parts = [f"{k};dur={v}" for k, v in timings.items()]
    parts.append(f"total;dur={total_ms}")
    return ", ".join(parts)

# ------------------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------------------
@app.route("/")
def index():
    """Serve the chat front-end page (``templates/index.html``)."""
    page = "index.html"
    return render_template(page)

@app.get("/__whoami")
def whoami():
    """Diagnostic endpoint: report the tenant id bound to this request."""
    body = {"tenant_id": g.tenant_id}
    return body

@app.route("/tenants", methods=["POST"])
def create_tenant():
    """Create a tenant from a JSON body ``{"name": "<name>"}``.

    Returns ``{"tenant_id": <id>}`` with status 201. Returns 400
    ``{"error": "name_required"}`` when the body is not a JSON object or
    ``name`` is missing, not a string, or blank. Previously these cases
    surfaced as a 500 from ``request.json["name"]``.

    NOTE(review): this endpoint performs no authentication/authorization, even
    though the rest of the app is tenant-scoped. Confirm it is protected upstream.
    """
    data = request.get_json(silent=True)
    name = data.get("name") if isinstance(data, dict) else None
    if not isinstance(name, str) or not name.strip():
        return jsonify({"error": "name_required"}), 400

    db = g.db
    tenant = Tenant(name=name.strip())
    db.add(tenant)
    db.commit()
    db.refresh(tenant)
    return jsonify({"tenant_id": tenant.id}), 201

@app.route("/topics/random", methods=["GET"])
def get_random_topics():
    """Return up to 3 random topics of the current tenant as a JSON list.

    Each item is ``{"topic": ..., "question": ...}``, where ``{topic}`` inside
    the stored question template is replaced by the topic text. A missing
    question becomes ``""``.

    If the ORM query fails, the error is logged, the session is rolled back
    and an equivalent raw SQL query is used instead.
    """
    db = g.db
    tenant_id = g.tenant_id
    try:
        # --- Primary (ORM) path ---
        topics = (
            db.query(Topic)
              .filter(Topic.tenant_id == tenant_id)    # explicit, not filter_by
              .order_by(func.random())
              .limit(3)
              .all()
        )
    except Exception:
        app.logger.exception("[topics/random] ORM query failed; falling back to raw SQL")
        # A DB-level failure can leave the transaction aborted (PostgreSQL does
        # this), which would make the fallback query fail as well.
        db.rollback()

        # --- Fallback (raw SQL) path ---
        rows = db.execute(
            text("""
                SELECT topic, question
                FROM topics
                WHERE tenant_id = :tid
                ORDER BY RANDOM()
                LIMIT 3
            """),
            {"tid": tenant_id},
        ).fetchall()
    else:
        rows = [(t.topic, t.question) for t in topics]

    # Both paths yield (topic, question) pairs; normalize them the same way.
    return jsonify([
        {"topic": topic, "question": (question or "").replace("{topic}", topic or "")}
        for topic, question in rows
    ])


def _clean_text(value, default=None):
    """Return *value* stripped, or *default* if it is not a string or is blank after stripping."""
    if isinstance(value, str):
        value = value.strip()
        if value:
            return value
    return default


@app.route("/init-session", methods=["POST"])
def init_session():
    """Get or create the (user, chat session, case) triple for the current tenant.

    Optional JSON body fields: ``session_token``, ``username``, ``email``,
    ``case_title``. Blank or non-string values fall back to defaults:

    - a fresh UUID4 token,
    - ``"Default User"``,
    - ``"default@example.com"``,
    - no case title.

    - The user is looked up by ``(tenant_id, username)``. ``email`` is only
      used when the user is created. NOTE: all clients that send no username
      share the single "Default User" record.
    - The chat session is looked up by ``(tenant_id, session_token)``, so
      repeating the call with the same token returns the same session.
    - One case per session: the oldest existing case is reused, otherwise one
      is created.

    Returns ``{"session_id", "session_token", "case_id"}``.
    """
    db = g.db
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    tenant_id = g.tenant_id

    # Clean *before* defaulting. Previously "   " was truthy and stripped to
    # "", creating an empty-string token that every such client would share.
    session_token = _clean_text(data.get("session_token")) or str(uuid.uuid4())
    username = _clean_text(data.get("username"), "Default User")
    email = _clean_text(data.get("email"), "default@example.com")
    case_title = _clean_text(data.get("case_title"))

    # --- get or create user (identified by tenant + username) ---
    user = (
        db.query(User)
          .filter_by(tenant_id=tenant_id, username=username)
          .first()
    )
    if not user:
        user = User(tenant_id=tenant_id, username=username, email=email)
        db.add(user)
        db.commit()
        db.refresh(user)

    # --- get or create session (idempotent by token + tenant) ---
    chat_sess = (
        db.query(ChatSession)
          .filter_by(tenant_id=tenant_id, session_token=session_token)
          .first()
    )
    if not chat_sess:
        chat_sess = ChatSession(
            tenant_id=tenant_id,
            user_id=user.id,
            session_token=session_token,
        )
        db.add(chat_sess)
        db.commit()
        db.refresh(chat_sess)

    # --- get or create case (one case per session) ---
    case = (
        db.query(Case)
          .filter_by(tenant_id=tenant_id, session_id=chat_sess.id)
          .order_by(Case.id.asc())
          .first()
    )
    if not case:
        case = Case(
            tenant_id=tenant_id,
            user_id=user.id,
            session_id=chat_sess.id,
            title=case_title,
        )
        db.add(case)
        db.commit()
        db.refresh(case)

    return jsonify({
        "session_id": chat_sess.id,
        "session_token": session_token,
        "case_id": case.id,
    })


@app.route("/cases", methods=["POST"])
def create_case():
    """Create a case for the current tenant and return ``{"case_id": <id>}`` (201).

    Optional JSON body: ``{"title": str}``. The title is truncated to 255
    characters. It defaults to ``""`` when it is missing or not a string, or
    when the body is not a JSON object. Previously a non-JSON or non-object
    body caused a 500 from ``request.json.get``.
    """
    data = request.get_json(silent=True)
    raw_title = data.get("title") if isinstance(data, dict) else None
    title = raw_title if isinstance(raw_title, str) else ""

    db = g.db
    new_case = Case(tenant_id=g.tenant_id, title=title[:255])
    db.add(new_case)
    db.commit()
    db.refresh(new_case)
    return jsonify({"case_id": new_case.id}), 201


@app.route("/cases/<int:case_id>/history", methods=["GET"])
def get_history(case_id):
    """Return every message of *case_id*, ordered by timestamp.

    Responds 404 (via ``ensure_case``) when the case is not owned by the
    current tenant. Response body::

        {"messages": [{"role": ..., "content": ..., "timestamp": "<ISO 8601>"}, ...]}
    """
    session = g.db
    ensure_case(case_id)
    # Message has no tenant_id of its own, so scope it through its Case.
    history_query = (
        session.query(Message)
        .join(Case, Message.case_id == Case.id)
        .filter(Case.id == case_id, Case.tenant_id == g.tenant_id)
        .order_by(Message.timestamp)
    )
    serialized = []
    for message in history_query.all():
        serialized.append(
            {
                "role": message.role,
                "content": message.content,
                "timestamp": message.timestamp.isoformat(),
            }
        )
    return jsonify({"messages": serialized})


# System prompt prepended to every primary (non-summary) turn.
_MATH_RULES = (
    "Format all mathematical expressions as LaTeX. "
    "Use $$...$$ for display equations and \\(...\\) for inline. "
    "Use \\times (not ×), \\cdot, \\frac{a}{b}, exponents as ^{…}, subscripts as _{…}. "
    "Do not output plain-text formulas like 'Pn=Pn-1×(1+r)-M'."
)

# System prompt for the session-level summary call.
_SUMMARY_SYSTEM_PROMPT = (
    "You are writing a concise, session-level executive summary. "
    "Touch each of the selected topics in order, connect them logically, "
    "and end with 3–5 crisp recommendations."
)

# Tool schema offered to the primary model so it can return follow-up questions.
_FOLLOWUPS_TOOL = {
    "type": "function",
    "function": {
        "name": "extract_followups",
        "description": "Suggest 6–8 follow-up questions based on the assistant's reply.",
        "parameters": {
            "type": "object",
            "properties": {"questions": {"type": "array", "items": {"type": "string"}}},
            "required": ["questions"],
        },
    },
}


def _stripped_str(value) -> str:
    """Return *value* stripped if it is a string, otherwise ``""``."""
    return value.strip() if isinstance(value, str) else ""


def _timed_response(payload: dict, timings: dict, t_all: float, status: int = 200):
    """JSON-encode *payload* and attach a ``Server-Timing`` header.

    ``total`` is the elapsed time in ms since *t_all*, a ``perf_counter()`` reading.
    """
    resp = jsonify(payload)
    total = int((perf_counter() - t_all) * 1000)
    resp.headers["Server-Timing"] = _server_timing_header(timings, total)
    return resp, status


def _handle_rag_turn(db, case_id: int, user_text: str, tenant_id: int, timings: dict, t_all: float):
    """``mode=rag``: return only the RAG addendum for *user_text*. Nothing is persisted."""
    # Double-guard: the frontend shouldn't call mode=rag on summary turns,
    # because the summary response already embeds the RAG payload.
    if _is_summary_turn(db, case_id, user_text):
        skip_ms = int((perf_counter() - t_all) * 1000)
        skip_timings = {"rag_skip": skip_ms}
        return _timed_response({"reply": "", "timings": skip_timings}, skip_timings, t_all)

    with timed("rag_addendum", timings):
        try:
            rag_res = _build_rag_payload(user_text, tenant_id)
            rag_addendum = (rag_res.get("addendum") or rag_res.get("reply") or "").strip()
        except Exception:
            app.logger.exception("[rag] addendum generation failed for case %s", case_id)
            rag_addendum = ""
            timings["rag_error"] = 1

    return _timed_response({"reply": rag_addendum, "timings": timings}, timings, t_all)


def _load_summary_intro_and_quote(db, tenant_id: int, button_choices: list) -> tuple:
    """Pick a random ``summary_intro`` and ``quote`` document of *tenant_id*.

    Returns ``(intro_filled, quote_text)``. The intro template is filled with
    ``{iteration1}``..``{iteration3}`` from *button_choices*; missing entries
    become ``""``. Failures degrade to empty strings rather than failing the
    summary:

    - a DB error is logged, the session is rolled back (so the assistant turn
      can still be stored), and ``("", "")`` is returned;
    - a template that cannot be filled only drops the intro; the quote is kept.
    """
    try:
        quotes = (
            db.query(Document)
              .filter(Document.tenant_id == tenant_id, Document.kind == "quote")
              .order_by(func.random())
              .limit(5)
              .all()
        )
        intro_doc = (
            db.query(Document)
              .filter(Document.tenant_id == tenant_id, Document.kind == "summary_intro")
              .order_by(func.random())
              .limit(1)
              .first()
        )
    except Exception:
        app.logger.exception("[summary] loading intro/quote documents failed")
        db.rollback()
        return "", ""

    quote_text = (random.choice(quotes).content if quotes else "") or ""
    intro_text = (intro_doc.content if intro_doc else "") or ""

    def _choice(i):
        return button_choices[i] if len(button_choices) > i else ""

    try:
        intro_filled = intro_text.format(
            iteration1=_choice(0),
            iteration2=_choice(1),
            iteration3=_choice(2),
        )
    except (KeyError, IndexError, ValueError, AttributeError):
        # e.g. an unknown "{placeholder}" or an unbalanced brace in the stored template
        app.logger.warning("[summary] summary_intro template could not be filled; omitting intro")
        intro_filled = ""
    return intro_filled, quote_text


def _handle_summary_turn(db, case_id: int, user_text: str, tenant_id: int,
                         past_messages: list, timings: dict, t_all: float):
    """Summary turn: run the summary LLM call and the RAG lookup in parallel.

    The assistant reply (intro + summary) is persisted, and the response
    carries both the summary and the RAG payload (``rag_included: True``).
    *past_messages* is the case history, including the current user turn.
    """
    with timed("summary_prep", timings):
        # Reconstruct the user's path from the persisted turns.
        user_turns = [m for m in past_messages if m.role == "user"]
        original_question = user_turns[0].content if user_turns else user_text
        button_choices = [m.button_choice for m in user_turns if m.button_choice]

        free_texts = []
        for m in user_turns[1:]:
            txt = (m.content or "").strip()
            # Skip turns whose text is just the clicked button label.
            if txt and txt != m.button_choice:
                free_texts.append(txt)

        with timed("tenant_intro", timings):
            intro_filled, quote_text = _load_summary_intro_and_quote(db, tenant_id, button_choices)

        parts = [
            f"Originale Frage: {original_question}",
            f"Auswahl/Themenpfad: {', '.join(button_choices) or '(keine)'}",
        ]
        if free_texts:
            parts.append("Zusätzliche Hinweise/Präzisierungen:\n- " + "\n- ".join(free_texts))
        summary_user = "\n\n".join(parts)
        rag_seed = " ".join([original_question] + button_choices + free_texts)
        model_for_summary = os.getenv("OPENAI_MODEL_SUMMARY", "gpt-4o-mini")

    def _run_summary():
        t0 = perf_counter()
        client = get_openai_client()
        resp = client.chat.completions.create(
            model=model_for_summary,
            temperature=0.5,
            messages=[
                {"role": "system", "content": _SUMMARY_SYSTEM_PROMPT},
                {"role": "user", "content": summary_user},
            ],
        )
        txt = (resp.choices[0].message.content or "").strip()
        timings["openai_summary"] = int((perf_counter() - t0) * 1000)
        return txt or "(leere Zusammenfassung)"

    def _run_rag():
        t0 = perf_counter()
        try:
            res = _build_rag_payload(rag_seed, tenant_id)
        except Exception:
            app.logger.exception("[summary] RAG addendum failed for case %s", case_id)
            res = {"reply": "", "error": "rag_failed"}
        timings["rag_addendum"] = int((perf_counter() - t0) * 1000)
        return res

    # Both calls wait on remote services, so two threads overlap their latency.
    # A summary failure propagates (-> 500); a RAG failure degrades to an empty payload.
    with ThreadPoolExecutor(max_workers=2) as ex:
        fut_summary = ex.submit(_run_summary)
        fut_rag = ex.submit(_run_rag)
        summary_text = fut_summary.result()
        rag_payload = fut_rag.result()

    # Join non-empty parts only, so no leading blank paragraph is stored
    # when the tenant has no intro.
    ai_text = "\n\n".join(part for part in (intro_filled, summary_text) if part)

    with timed("db_store_ai", timings):
        ai_msg = Message(case_id=case_id, role="assistant", content=ai_text)
        db.add(ai_msg); db.commit(); db.refresh(ai_msg)

    payload = {
        "reply": ai_text,
        "followups": [],
        "choices": [],
        "is_summary": True,
        "footer_quote": quote_text,
        # IMPORTANT: the frontend sees this and does NOT call mode=rag.
        "rag_included": True,
        "rag": rag_payload,
        "timings": timings,
    }
    return _timed_response(payload, timings, t_all)


def _parse_primary_reply(resp) -> tuple:
    """Extract ``(reply_text, followups)`` from a chat completion response.

    Follow-ups come from the first tool call, when it is ``extract_followups``.
    At most 8 string entries are kept. Malformed tool arguments yield no
    follow-ups instead of an error.
    """
    choice = resp.choices[0]
    msg_obj = getattr(choice, "message", None)
    finish = getattr(choice, "finish_reason", None)
    tool_calls = getattr(msg_obj, "tool_calls", None) if msg_obj else None

    if finish != "tool_calls" or not tool_calls:
        return (msg_obj.content if msg_obj else "") or "(no text reply provided)", []

    # NOTE(review): when the model answers only with a tool call, content may
    # be empty. The turn then has no **choices** and hits the off-topic gate.
    ai_text = (msg_obj.content or "").strip()
    tool = tool_calls[0]
    if tool.function.name != "extract_followups":
        return ai_text, []
    try:
        tool_args = json.loads(tool.function.arguments or "{}")
    except (TypeError, ValueError):
        tool_args = {}
    questions = tool_args.get("questions", []) if isinstance(tool_args, dict) else []
    if not isinstance(questions, list):
        return ai_text, []
    return ai_text, [q for q in questions if isinstance(q, str)][:8]


def _handle_primary_turn(db, case_id: int, tenant_id: int, user_msg, timings: dict, t_all: float):
    """Normal turn: build the prompt, call the model, extract choices/follow-ups, persist."""
    with timed("build_payload", timings):
        messages_payload = build_messages_payload(db, case_id, tenant_id)
    messages_payload.insert(0, {"role": "system", "content": _MATH_RULES})

    openai_client = get_openai_client()
    with timed("openai", timings):
        resp = openai_client.chat.completions.create(
            model=os.getenv("OPENAI_MODEL_PRIMARY", "gpt-4o-mini"),
            temperature=1.2,
            messages=messages_payload,
            tools=[_FOLLOWUPS_TOOL],
            tool_choice="auto",
        )

    ai_text, followups = _parse_primary_reply(resp)

    # Trim numbered lists to 3 items, then offer up to 3 **bold** phrases as choices.
    ai_text = truncate_numbered_blocks(ai_text, max_items=3)
    choices = re.findall(r"\*\*(.*?)\*\*", ai_text)
    if len(choices) > 3:
        choices = random.sample(choices, 3)

    if not choices:
        # Off-topic gate: a reply without choices removes the stored user turn
        # and is returned without being persisted.
        try:
            with timed("db_delete_offtopic_user", timings):
                db.delete(user_msg); db.commit()
        except Exception:
            app.logger.exception("[messages] deleting off-topic user turn failed (case %s)", case_id)
            db.rollback()
        payload = {
            "reply": ai_text,
            "followups": [],
            "choices": [],
            "is_summary": False,
            "timings": timings,
        }
        return _timed_response(payload, timings, t_all)

    with timed("db_store_ai", timings):
        ai_msg = Message(case_id=case_id, role="assistant", content=ai_text)
        db.add(ai_msg); db.commit(); db.refresh(ai_msg)

    payload = {
        "reply": ai_text,
        "followups": followups,
        "choices": choices,
        "is_summary": False,
        "timings": timings,
    }
    return _timed_response(payload, timings, t_all)


@app.route("/cases/<int:case_id>/messages", methods=["POST"])
def post_message(case_id):
    """Handle one chat turn for *case_id*; the case is tenant-checked via ``ensure_case``.

    Query parameter ``mode``:

    - ``rag``: return only the RAG addendum for the text; nothing is persisted.
    - anything else (default ``primary``): persist the user message, then
      produce either the session summary (with an embedded RAG payload) or a
      normal LLM reply with ``choices`` (bold ``**...**`` phrases) and
      ``followups``.

    JSON body: ``{"content": str, "button_choice": str (optional)}``.
    Returns 400 ``{"error": "empty"}`` for missing or blank content (or a
    non-object body), and 500 ``{"error": "server_error", "detail": ...}`` on
    unexpected errors. Successful responses carry a ``Server-Timing`` header.
    """
    mode = request.args.get("mode", "primary")  # "primary" | "rag"
    db = g.db
    ensure_case(case_id)  # outside the try: its 404 must not become a 500

    timings = {}
    t_all = perf_counter()

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        user_text = _stripped_str(data.get("content"))
        button_choice = _stripped_str(data.get("button_choice"))
        tenant_id = g.tenant_id

        if not user_text:
            return jsonify({"error": "empty"}), 400

        if mode == "rag":
            return _handle_rag_turn(db, case_id, user_text, tenant_id, timings, t_all)

        # 1) Persist the user turn.
        with timed("db_store_user", timings):
            user_msg = Message(case_id=case_id, role="user",
                               content=user_text, button_choice=button_choice)
            db.add(user_msg); db.commit(); db.refresh(user_msg)

        # 2) Decide the summary phase before any OpenAI call. The full history
        #    is only needed to rebuild the path for a summary, so skip it otherwise.
        with timed("history_fetch", timings):
            is_summary = _is_summary_turn(db, case_id, user_text)
            past_messages = []
            if is_summary:
                past_messages = (
                    db.query(Message)
                      .join(Case, Message.case_id == Case.id)
                      .filter(Case.id == case_id, Case.tenant_id == tenant_id)
                      .order_by(Message.timestamp)
                      .all()
                )

        if is_summary:
            return _handle_summary_turn(db, case_id, user_text, tenant_id,
                                        past_messages, timings, t_all)

        return _handle_primary_turn(db, case_id, tenant_id, user_msg, timings, t_all)

    except Exception as e:
        app.logger.exception("Error in /messages")
        # NOTE(review): "detail" exposes internal exception text (DB/API errors)
        # to the client; consider dropping it outside debug mode.
        return jsonify({"error": "server_error", "detail": str(e)}), 500


# ------------------------------------------------------------------------------
# Boot
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # and this server binds to all interfaces. Keep it off unless explicitly
    # requested with FLASK_DEBUG=1 (or true/yes/on).
    _debug = os.getenv("FLASK_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}
    app.run(
        host="0.0.0.0",
        port=5004,
        ssl_context=(
            "/var/www/html/decompression/certs/fullchain.pem",
            "/var/www/html/decompression/certs/privkey.pem",
        ),
        debug=_debug,
    )
