# DYNBOT APP 
import os
import re, unicodedata
import uuid
import logging
import json
import sys
import random
import pprint

from flask import Flask, request, jsonify, render_template
from sqlalchemy.sql import func

from openai import OpenAI
from rag_service import generate_rag_addendum
from db import SessionLocal, init_db
from utils.prompt_manager import build_messages_payload
from services.referenz import get_referenzstory
from models import (
    Session as ChatSession, User, Case, Message, Document, Topic, Tenant
)

from time import perf_counter
from contextlib import contextmanager

@contextmanager
def timed(label, bag):
    """Measure the wall-clock duration of the ``with`` body.

    On leaving the block (normally or through an exception) the elapsed time
    is written to ``bag[label]`` as whole milliseconds, truncated by ``int``.

    Args:
        label: Key under which the duration is stored.
        bag: Mutable mapping that receives the measurement.
    """
    started = perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (perf_counter() - started) * 1000
        bag[label] = int(elapsed_ms)

print("📍 being here 1", flush=True)

# --- Fallback API key (taken from the environment, never hard-coded) ---
# A literal key in source ends up in version control and in every copy of this
# file. Export FALLBACK_OPENAI_KEY (or, preferably, OPENAI_API_KEY, which
# _resolve_api_key() checks first). Empty string when unset.
FALLBACK_OPENAI_KEY = os.getenv("FALLBACK_OPENAI_KEY", "").strip()

# --- Summary threshold & canonical base ---
# First user turn (1-based count) whose reply is rendered as a summary.
# Override with the SUMMARY_AT environment variable; the default is 2.
SUMMARY_AT = int(os.environ.get("SUMMARY_AT", "2"))


def is_summary_turn(turn: int) -> bool:
    """Return True when *turn* has reached the summary threshold.

    The comparison is "at or past" rather than "exactly at", so a retried
    turn beyond SUMMARY_AT still counts as a summary turn.
    """
    return not turn < SUMMARY_AT


# Public base URL of the deployment, always without a trailing slash.
CANONICAL_BASE_URL = os.environ.get("CANONICAL_BASE_URL", "https://ai-1a.com").rstrip("/")

def _resolve_api_key():
    """Return the OpenAI API key to use.

    OPENAI_API_KEY from the environment takes precedence; otherwise
    FALLBACK_OPENAI_KEY is used. The chosen value is whitespace-stripped.

    Raises:
        RuntimeError: if the chosen value is empty after stripping.
    """
    raw_key = os.getenv("OPENAI_API_KEY") or FALLBACK_OPENAI_KEY or ""
    key = raw_key.strip()
    if key:
        return key
    raise RuntimeError("No API key available (OPENAI_API_KEY and FALLBACK_OPENAI_KEY are empty).")

_client = None  # lazily created OpenAI client, shared by all callers


def get_openai_client() -> OpenAI:
    """Return the module-wide OpenAI client, building it on first use.

    The key comes from _resolve_api_key(); only its last four characters are
    logged. Not synchronized: concurrent first calls may each construct a
    client, and the last assignment wins.
    """
    global _client
    if _client is not None:
        return _client
    key = _resolve_api_key()
    logging.getLogger().info("🔑 OPENAI_API_KEY ends with: ...%s", key[-4:])
    _client = OpenAI(api_key=key)
    return _client

print(f"🗄 DB user: {os.getenv('DB_USER')}")

# Warn about proxy variables (they can break API auth). Proxy URLs may embed
# credentials (scheme://user:password@host:port), so the userinfo part is
# masked before the value reaches the logs.
for var in ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"):
    proxy_url = os.getenv(var)
    if proxy_url:
        masked_url = re.sub(r"^((?:[a-zA-Z][a-zA-Z0-9+.-]*:)?//)?[^/]*@", r"\1***@", proxy_url)
        logging.getLogger().warning("Proxy var %s is set; may cause 401s: %s", var, masked_url)

app = Flask(__name__)
init_db()

# @app.route("/")
# def hello():
#    return "✅ Dynbot backend is running."

def _slug(s: str) -> str:
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    s = re.sub(r"[^a-zA-Z0-9]+", "-", s).strip("-").lower()
    return s or "section"


def render_referenzstory_html(match: dict) -> str:
    """Render one retrieved Referenzstory match as an HTML fragment.

    Expected input shape (as produced by retrieval)::

        {"story_id": int, "title": "Referenzstory_XX", "score": float,
         "sections": [{"section": "Titel", "content_html": "<p>…</p>"}, ...]}

    Output is a single ``<div class="ref-story">`` containing one
    ``<section class="ref-section ref-<slug>">`` per section, in the order
    given. ``content_html`` is inserted verbatim (not escaped), so it must
    come from a trusted source. Returns "" when *match* is falsy or has no
    sections.
    """
    sections = match.get("sections") if match else None
    if not sections:
        return ""
    pieces = []
    for entry in sections:
        css_suffix = _slug(entry.get("section") or "")
        inner_html = entry.get("content_html") or ""
        pieces.append(f'<section class="ref-section ref-{css_suffix}">{inner_html}</section>')
    return '<div class="ref-story">' + "".join(pieces) + "</div>"

@app.route("/")
def index():
    """Serve the frontend page, index.html from the app's template folder."""
    page = "index.html"
    return render_template(page)

@app.route("/init-session", methods=["POST"])
def init_session():
    """Start a new chat session and open a case for it.

    Request JSON (all keys optional):
        tenant_id: defaults to 1.
        user_id: defaults to 1. If no User with this id exists, a placeholder
            "Default User" row is created first.
        session_token: a random UUID4 string is generated when the key is
            missing, null or empty.

    Returns:
        JSON ``{"session_id", "session_token", "case_id"}``.
    """
    # Tolerate a JSON ``null`` body the same way post_message does.
    data = request.get_json() or {}
    tenant_id = data.get("tenant_id", 1)
    user_id = data.get("user_id", 1)
    # A null/empty token would otherwise be stored as-is; generate one instead.
    session_token = data.get("session_token") or str(uuid.uuid4())

    db = SessionLocal()
    try:
        # Ensure user exists
        if db.query(User).filter_by(id=user_id).first() is None:
            db.add(User(id=user_id, tenant_id=tenant_id, username="Default User", email="default@example.com"))
            db.commit()

        # Always create a new session
        new_session = ChatSession(
            tenant_id=tenant_id,
            user_id=user_id,
            session_token=session_token,
        )
        db.add(new_session)
        db.commit()
        db.refresh(new_session)

        # The case references the session, so the session is committed first to obtain its id.
        case = Case(
            tenant_id=tenant_id,
            user_id=user_id,
            session_id=new_session.id,
        )
        db.add(case)
        db.commit()
        db.refresh(case)

        # Read attributes while the DB session is still open.
        response = {
            "session_id": new_session.id,
            "session_token": session_token,
            "case_id": case.id,
        }
    finally:
        # Release the connection even when a query or commit raises.
        db.close()
    return jsonify(response)


@app.route("/cases", methods=["POST"])
def create_case():
    """Create a case and return its id with HTTP 201.

    Request JSON (all keys optional): ``tenant_id`` (default 1) and
    ``title`` (default "").
    """
    # ``or {}`` tolerates a JSON null body instead of crashing on ``.get``.
    data = request.get_json() or {}
    db = SessionLocal()
    try:
        new = Case(tenant_id=data.get("tenant_id", 1), title=data.get("title", ""))
        db.add(new)
        db.commit()
        db.refresh(new)
        # Read the id while the DB session is still open.
        case_id = new.id
    finally:
        # Release the connection even when the insert fails.
        db.close()
    return jsonify({"case_id": case_id}), 201

@app.route("/cases/<int:case_id>/history", methods=["GET"])
def get_history(case_id):
    """Return all messages of a case ordered by timestamp.

    Response JSON: ``{"messages": [{"role", "content", "timestamp"}, ...]}``
    where ``timestamp`` is an ISO-8601 string, or null if the row has none.
    """
    db = SessionLocal()
    try:
        msgs = db.query(Message).filter_by(case_id=case_id).order_by(Message.timestamp).all()
        # If you later add FollowupQuestion to models, import it and use it here.
        # Serialize while the DB session is still open.
        messages = [
            {
                "role": m.role,
                "content": m.content,
                "timestamp": m.timestamp.isoformat() if m.timestamp is not None else None,
            }
            for m in msgs
        ]
    finally:
        db.close()
    return jsonify({"messages": messages})

@app.route("/cases/<int:case_id>/messages", methods=["POST"])
def post_message(case_id):
    """Process one user message for a case and return the bot's answer.

    Query parameter ``mode``:

    * ``"rag"``: nothing is stored. Returns the RAG addendum for the text
      (``reply``), the best matching Referenzstory rendered as HTML
      (``referenzstory_html``) and its metadata (``referenzstory``). A failing
      retrieval step degrades to empty output and sets ``rag_error`` or
      ``referenz_error`` in ``timings``.
    * anything else (default ``"primary"``): stores the user message, asks the
      chat model for a reply (it may call the ``extract_followups`` tool),
      keeps at most three numbered items and at most three ``**bold**``
      choices, then branches:

      - off-topic (not a summary turn and no choices): the stored user message
        is deleted again and the reply is returned without choices;
      - summary turn (see ``is_summary_turn``): the tenant's
        ``summary_intro_1`` document is prepended and a random ``quote``
        document is returned separately as ``footer_quote``;
      - otherwise: the reply is returned together with its choices.

      Except for off-topic turns the (possibly enriched) reply is stored.

    Request JSON: ``content`` (required), ``button_choice`` (optional),
    ``tenant_id`` (default 1).

    Successful responses include ``timings`` (milliseconds per step) and a
    ``Server-Timing`` header. Returns 400 when ``content`` or ``tenant_id`` is
    missing, and 500 ``{"error": "server_error", "detail": ...}`` on any
    unexpected exception, which is also logged with its traceback.
    """
    mode = request.args.get("mode", "primary")
    log = logging.getLogger(__name__)
    db = SessionLocal()
    timings = {}
    t_all = perf_counter()

    def _respond(payload):
        """jsonify *payload* and attach a Server-Timing header (every step + total)."""
        resp = jsonify(payload)
        total = int((perf_counter() - t_all) * 1000)
        entries = [f"{k};dur={v}" for k, v in timings.items()]
        entries.append(f"total;dur={total}")
        resp.headers["Server-Timing"] = ", ".join(entries)
        return resp

    try:
        data = request.get_json() or {}
        user_text = data.get("content")
        button_choice = data.get("button_choice", "")
        tenant_id = data.get("tenant_id", 1)

        if not user_text or not tenant_id:
            return jsonify({"error": "bad_request", "detail": "Missing 'content' or 'tenant_id'"}), 400

        if mode == "rag":
            # 1) RAG addendum for the user text
            with timed("rag_addendum", timings):
                try:
                    rag_res = generate_rag_addendum(user_text, tenant_id=tenant_id, top_k=3)
                    rag_addendum = (rag_res.get("addendum") or "").strip()
                except Exception:
                    # Best effort: degrade to an empty addendum, but keep the cause in the logs.
                    log.exception("RAG addendum failed for case %s", case_id)
                    rag_addendum = ""
                    timings["rag_error"] = 1

            # 2) Best matching Referenzstory
            referenzstory_html = ""
            referenzstory_meta = {}
            with timed("referenzstory", timings):
                try:
                    matches = get_referenzstory(user_text, tenant_id=tenant_id, top_k=1)
                    if matches:
                        top = matches[0]
                        referenzstory_html = render_referenzstory_html(top)
                        # Optional metadata the client may display
                        referenzstory_meta = {
                            "story_id": top.get("story_id"),
                            "title": top.get("title"),
                            "score": top.get("score"),
                        }
                except Exception:
                    log.exception("Referenzstory lookup failed for case %s", case_id)
                    referenzstory_html = ""
                    timings["referenz_error"] = 1

            # 3) Respond with both. This header lists only the two retrieval
            #    steps plus the total (not the *_error flags).
            payload = {
                "reply": rag_addendum,                    # RAG text
                "referenzstory_html": referenzstory_html, # drop-in HTML for the story
                "referenzstory": referenzstory_meta,      # optional metadata
                "timings": timings
            }
            resp = jsonify(payload)
            total = int((perf_counter() - t_all) * 1000)
            resp.headers["Server-Timing"] = (
                f"rag_addendum;dur={timings.get('rag_addendum',0)}, "
                f"referenzstory;dur={timings.get('referenzstory',0)}, "
                f"total;dur={total}"
            )
            return resp, 200

        # PRIMARY
        with timed("db_store_user", timings):
            user_msg = Message(case_id=case_id, role="user", content=user_text, button_choice=button_choice)
            db.add(user_msg)
            db.commit()
            db.refresh(user_msg)

        with timed("build_payload", timings):
            messages_payload = build_messages_payload(db, case_id, tenant_id)

        openai_client = get_openai_client()
        with timed("openai", timings):
            resp = openai_client.chat.completions.create(
                model="gpt-4o-mini",
                temperature=1.2,
                messages=messages_payload,
                tools=[{
                    "type": "function",
                    "function": {
                        "name": "extract_followups",
                        "description": "Suggest 6–8 follow-up questions based on the assistant's reply.",
                        "parameters": {"type": "object", "properties": {"questions": {"type":"array","items":{"type":"string"}}}, "required": ["questions"]}
                    }
                }],
                tool_choice="auto"
            )

        choice = resp.choices[0]
        msg = getattr(choice, "message", None)
        finish = getattr(choice, "finish_reason", None)

        followups = []
        if finish == "tool_calls" and msg and getattr(msg, "tool_calls", None):
            tool_call = msg.tool_calls[0]
            raw_args = getattr(tool_call.function, "arguments", "{}") or "{}"
            try:
                args = json.loads(raw_args)
            except json.JSONDecodeError:
                # The user message is already committed; malformed (optional)
                # follow-ups must not turn the whole turn into a 500.
                log.warning("Ignoring malformed extract_followups arguments for case %s", case_id)
                args = {}
            questions = args.get("questions", []) if isinstance(args, dict) else []
            if isinstance(questions, list):
                followups = [q for q in questions if isinstance(q, str)]
            ai_text = msg.content or "Vielen Dank für Ihre Auswahl. Hier sind mögliche nächste Schritte oder Zusammenfassungen."
        else:
            ai_text = (msg.content if msg else "") or "(no text reply provided)"

        # ---------- Build assistant text & choices (at most 3 of each) ----------
        ai_text = truncate_numbered_blocks(ai_text, max_items=3)
        choices = re.findall(r"\*\*(.*?)\*\*", ai_text)
        if len(choices) > 3:
            choices = random.sample(choices, 3)

        # ---------- Turn counting BEFORE deciding summary/off-topic ----------
        past_messages = db.query(Message).filter_by(case_id=case_id).order_by(Message.timestamp).all()
        turn = len([m for m in past_messages if m.role == "user"])
        is_summary = is_summary_turn(turn)

        # ---------- Quick-win OFF-TOPIC gate ----------
        # Heuristic: no follow-up **choices** AND not summary ⇒ off-topic (e.g., "Warum ist der Himmel blau?")
        off_topic = (not is_summary) and (len(choices) == 0)

        if off_topic:
            # Remove the just-saved user turn so it doesn't advance iteration/summary logic
            try:
                with timed("db_delete_offtopic_user", timings):
                    db.delete(user_msg)
                    db.commit()
            except Exception:
                # Best effort: still answer, but keep the failure visible.
                log.exception("Could not delete off-topic user message for case %s", case_id)
                db.rollback()

            return _respond({
                "reply": ai_text,      # the generic model reply already generated
                "followups": [],       # none in off-topic
                "choices": [],         # none -> frontend shows start topics again
                "is_summary": False,
                "timings": timings
            }), 200

        # ---------- Normal path (PM-relevant or still useful) ----------
        quote_text = ""

        if is_summary:
            with timed("summary_prep", timings):
                intro_doc = db.query(Document).filter_by(tenant_id=tenant_id, title="summary_intro_1").first()
                quotes = db.query(Document).filter_by(tenant_id=tenant_id, title="quote").all()
                intro_text = intro_doc.content if intro_doc else ""
                quote_text = random.choice(quotes).content if quotes else ""

                # Rows stored without a button choice hold None, which would render as "None".
                button_choices = [m.button_choice or "" for m in past_messages if m.role == "user"]
                try:
                    intro_filled = (intro_text or "").format(
                        iteration1=(button_choices[0] if len(button_choices) > 0 else ""),
                        iteration2=(button_choices[1] if len(button_choices) > 1 else ""),
                        iteration3=(button_choices[2] if len(button_choices) > 2 else "")
                    )
                except (KeyError, IndexError, ValueError, AttributeError):
                    # Template has braces/placeholders other than iteration1..3:
                    # show it unformatted instead of failing the turn.
                    log.warning("summary_intro_1 for tenant %s is not a valid format template", tenant_id)
                    intro_filled = intro_text or ""

                # Keep the quote OUT of ai_text; the client renders it after RAG.
                ai_text = f"{intro_filled}\n\n{ai_text}"

            choices = []  # no inline choices in summary

        # Store assistant *after* enrichment so DB matches UI
        with timed("db_store_ai", timings):
            ai_msg = Message(case_id=case_id, role="assistant", content=ai_text)
            db.add(ai_msg)
            db.commit()

        return _respond({
            "reply": ai_text,
            "followups": followups,
            "choices": choices,            # already [] on summary turns
            "is_summary": is_summary,
            "footer_quote": quote_text,    # "" unless this is a summary turn
            "timings": timings
        })

    except Exception as e:
        log.exception("post_message failed (case_id=%s, mode=%s)", case_id, mode)
        return jsonify({"error": "server_error", "detail": str(e)}), 500
    finally:
        db.close()

def truncate_numbered_blocks(text: str, max_items: int = 3) -> str:
    """Keep only the first *max_items* items of the numbered list in *text*.

    The list starts at the first line beginning (after optional whitespace)
    with ``<digits>.`` followed by whitespace; everything before it is kept as
    a head paragraph. An item runs until the next line with such a marker or
    the end of the text, so trailing text after the list belongs to the last
    item. Kept items are renumbered 1..N and joined with blank lines.

    Args:
        text: Reply text that may contain a numbered list.
        max_items: Maximum number of list items to keep.

    Returns:
        The shortened text, or *text* unchanged when no numbered list is found.
    """
    # Find the start of the numbered list.
    m = re.search(r'(?m)^\s*\d+\.\s+', text)
    if not m:
        return text

    head = text[:m.start()].rstrip("\n")
    body = text[m.start():]

    # Blocks: "N. ..." up to the next item marker or the end of the text.
    # The lookahead must demand whitespace after the dot exactly like the item
    # pattern does; otherwise a line such as "2.5 % more" would end the
    # previous item without starting a new one, and its text would be lost.
    blocks = re.findall(r'(?ms)^\s*(\d+)\.\s+(.*?)(?=^\s*\d+\.\s|\Z)', body)
    if not blocks:
        return text

    # Keep the first max_items blocks, renumbered 1..N.
    out = [
        f"{idx}. {content.strip()}"
        for idx, (_num, content) in enumerate(blocks[:max_items], start=1)
    ]
    return (head + "\n\n" if head else "") + "\n\n".join(out)

    


@app.route("/tenants", methods=["POST"])
def create_tenant():
    """Create a tenant and return its id with HTTP 201.

    Request JSON: ``name`` (required). Responds 400 when it is missing or empty.
    """
    data = request.get_json() or {}
    name = data.get("name")
    if not name:
        return jsonify({"error": "bad_request", "detail": "Missing 'name'"}), 400

    db = SessionLocal()
    try:
        # `Tenant` is imported directly from models; no `models` name is bound here.
        tenant = Tenant(name=name)
        db.add(tenant)
        db.commit()
        db.refresh(tenant)
        # Read the id while the DB session is still open.
        tenant_id = tenant.id
    finally:
        db.close()
    return jsonify({"tenant_id": tenant_id}), 201


@app.route("/topics/random", methods=["GET"])
def get_random_topics():
    """Return up to three randomly ordered topics of a tenant.

    Query parameter ``tenant_id`` (int, optional) selects the tenant; it
    defaults to 1, and non-integer values also fall back to 1.

    Returns:
        JSON list of ``{"topic", "question"}`` where every ``{topic}``
        placeholder in the question is replaced by the topic text.
    """
    tenant_id = request.args.get("tenant_id", 1, type=int)
    db = SessionLocal()
    try:
        raw_topics = (
            db.query(Topic)
            .filter_by(tenant_id=tenant_id)
            .order_by(func.random())
            .limit(3)
            .all()
        )
        # NULL columns are treated as "" so a single bad row cannot crash the endpoint.
        results = [
            {
                "topic": t.topic,
                "question": (t.question or "").replace("{topic}", t.topic or ""),
            }
            for t in raw_topics
        ]
    finally:
        db.close()
    return jsonify(results)


if __name__ == '__main__':
    # Werkzeug's interactive debugger lets anyone who can reach the port run
    # arbitrary Python. Since the server binds to 0.0.0.0, debug mode is
    # opt-in only: set FLASK_DEBUG to 1/true/yes/on to enable it.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(
        host='0.0.0.0',
        port=5002,
        ssl_context=(
            '/var/www/html/decompression/certs/fullchain.pem',
            '/var/www/html/decompression/certs/privkey.pem'
        ),
        debug=debug_enabled,
    )

