# referenz_retrieval.py (fallback: no pgvector)
import os, json, math
import psycopg
from psycopg.rows import dict_row
from openai import OpenAI

# Embedding model for query vectors (override via the EMBED_MODEL env var).
# Query vectors are compared against stored embeddings, so this should match
# the model that produced the stored vectors.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")

def get_db_conn():
    """Open a new psycopg connection whose rows are returned as dicts.

    Connection settings are read from the DB_HOST, DB_PORT, DB_NAME, DB_USER
    and DB_PASSWORD environment variables. There is deliberately no
    hard-coded password fallback. When DB_PASSWORD is unset, ``password=None``
    is passed. psycopg omits None-valued parameters, so libpq falls back to its
    standard sources (PGPASSWORD, the password file, or a passwordless auth
    method).

    Returns:
        A ``psycopg.Connection`` configured with ``row_factory=dict_row``.
    """
    # NOTE(security): a real-looking password used to be committed here as
    # the DB_PASSWORD default; treat that credential as leaked and rotate it.
    return psycopg.connect(
        host=os.getenv("DB_HOST", "localhost"),
        port=os.getenv("DB_PORT", "5432"),
        dbname=os.getenv("DB_NAME", "dynbot_db"),
        user=os.getenv("DB_USER", "chatbot_user"),
        password=os.getenv("DB_PASSWORD"),
        row_factory=dict_row,
    )

# Lazily created OpenAI client shared by all embed_query calls in this module.
_openai = None


def embed_query(text: str) -> list[float]:
    """Embed *text* with EMBED_MODEL and return the vector.

    The OpenAI client is built on first use from OPENAI_API_KEY and cached in
    the module-level ``_openai``. ``None`` is treated as an empty string, and
    surrounding whitespace is stripped before the request.
    """
    global _openai
    client = _openai
    if client is None:
        client = _openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    cleaned = (text or "").strip()
    response = client.embeddings.create(model=EMBED_MODEL, input=cleaned)
    first_item = response.data[0]
    return first_item.embedding

def cosine_sim(a, b):
    """Return the cosine similarity of two numeric vectors.

    Args:
        a, b: Sequences of floats (e.g. embedding vectors).

    Returns:
        The cosine of the angle between *a* and *b*, in [-1, 1]. Returns 0.0
        when either vector has zero magnitude (similarity is undefined). Also
        returns 0.0 when the lengths differ, e.g. for a stored embedding
        produced by a different model than the query. Without that check,
        ``zip`` would silently compare only the common prefix and yield a
        meaningless score.
    """
    if len(a) != len(b):
        return 0.0
    dot = na = nb = 0.0
    for x, y in zip(a, b):
        dot += x * y
        na += x * x
        nb += y * y
    if na == 0 or nb == 0:
        return 0.0
    return dot / (math.sqrt(na) * math.sqrt(nb))

# A bullet scoring at or above this similarity counts as a "close hit".
_CLOSE_HIT_THRESHOLD = 0.60


def _parse_embedding(raw):
    """Return a stored embedding as a list, or [] if it is missing or unusable.

    ``embedding_json`` may arrive as text (text/varchar column) or already
    decoded into a Python list (psycopg 3 decodes json/jsonb columns itself).
    Malformed JSON or non-list values yield [], so one bad row does not abort
    the whole search.
    """
    if not raw:
        return []
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            raw = json.loads(raw)
        except ValueError:  # includes JSONDecodeError / UnicodeDecodeError
            return []
    return raw if isinstance(raw, list) else []


def find_matching_referenzstories(challenge_text: str, tenant_id: int, top_k: int = 3):
    """Find the reference stories whose Einordnung best matches a challenge.

    Python-side similarity search (fallback for databases without pgvector):

    1. Embed ``challenge_text`` and fetch every Einordnung bullet
       (``bullet_index >= 0``) for the tenant.
    2. Score each bullet with cosine similarity. Per story, keep the best
       bullet score and count the bullets scoring >= 0.60 ("close hits").
    3. Look up each story's full Einordnung embedding (``bullet_index = -1``)
       and compute ``final = 0.7 * max(best_bullet, full) + 0.3 * full``.
    4. Sort by ``(final, close_hits)`` descending and keep the first
       ``max(1, top_k)`` stories.
    5. Return those stories with all their sections (HTML), excluding the
       individual Einordnung bullets.

    Args:
        challenge_text: Free-text challenge to match. Blank or None yields [].
        tenant_id: Tenant whose documents are searched.
        top_k: Number of stories to return; values below 1 are treated as 1.

    Returns:
        A list of dicts with keys ``story_id``, ``title``, ``score`` (the
        ``final`` value as float) and ``sections``. ``sections`` is a list of
        ``{"section", "content_html"}`` dicts in display order. The result is
        empty if the challenge text is blank or the tenant has no Einordnung
        bullets.
    """
    # The embeddings API rejects empty input, and there is nothing to match.
    if not (challenge_text or "").strip():
        return []
    qvec = embed_query(challenge_text)

    FETCH_BULLETS_SQL = """
    SELECT id, story_id, title, content_text, embedding_json
    FROM documents
    WHERE tenant_id = %s
      AND section ILIKE 'Einordnung'
      AND bullet_index >= 0
    """
    # One query for all full-Einordnung embeddings of the tenant, instead of
    # one query per candidate story.
    FETCH_FULL_SQL = """
    SELECT story_id, embedding_json
    FROM documents
    WHERE tenant_id = %s
      AND section ILIKE 'Einordnung'
      AND bullet_index = -1
    """

    FETCH_SECTIONS_SQL = """
    SELECT section, bullet_index, content_html
    FROM documents
    WHERE tenant_id = %(tenant_id)s
    AND title     = %(title)s
    AND story_id  = %(story_id)s
    AND NOT (section ILIKE 'Einordnung' AND bullet_index >= 0)
    ORDER BY
    CASE
        WHEN section ILIKE 'Titel' THEN 0
        WHEN section ILIKE 'Subtitel' THEN 1
        WHEN section ILIKE 'Branche' THEN 2
        WHEN section ILIKE 'Beschreibung-Titel' THEN 3
        WHEN section ILIKE 'Beschreibung' THEN 4
        WHEN section ILIKE 'Beschreibung-Highlight' THEN 5
        WHEN section ILIKE 'Beschreibung-Finish' THEN 6
        WHEN section ILIKE 'Einordnung' THEN 7
        WHEN section ILIKE 'Beitrag' THEN 8
        WHEN section ILIKE 'Zitat' THEN 9
        ELSE 99
    END, section;
    """

    with get_db_conn() as conn, conn.cursor() as cur:
        cur.execute(FETCH_BULLETS_SQL, (tenant_id,))
        bullets = cur.fetchall()

        # 1) Score bullets and roll them up per story.
        per_story = {}  # story_id -> {"title": str, "best": float, "close_hits": int}
        for row in bullets:
            emb = _parse_embedding(row["embedding_json"])
            sim = cosine_sim(qvec, emb) if emb else 0.0
            # The first bullet seen for a story fixes its title.
            rec = per_story.setdefault(
                row["story_id"], {"title": row["title"], "best": sim, "close_hits": 0}
            )
            rec["best"] = max(rec["best"], sim)
            if sim >= _CLOSE_HIT_THRESHOLD:
                rec["close_hits"] += 1

        if not per_story:
            return []

        # 2) Blend in the similarity of each story's full Einordnung text.
        cur.execute(FETCH_FULL_SQL, (tenant_id,))
        full_by_story = {}
        for row in cur.fetchall():
            # Keep the first row per story (equivalent to a per-story LIMIT 1).
            full_by_story.setdefault(row["story_id"], row["embedding_json"])

        ranked = []
        for sid, rec in per_story.items():
            fvec = _parse_embedding(full_by_story.get(sid))
            full_sim = cosine_sim(qvec, fvec) if fvec else 0.0
            ranked.append({
                "story_id": sid,
                "title": rec["title"],
                "bullet_sim": rec["best"],
                "full_sim": full_sim,
                "close_hits": rec["close_hits"],
                "final": 0.7 * max(rec["best"], full_sim) + 0.3 * full_sim,
            })

        # close_hits only breaks ties between equal final scores.
        ranked.sort(key=lambda r: (r["final"], r["close_hits"]), reverse=True)
        top = ranked[:max(1, top_k)]

        # 3) Fetch display sections for the winning stories.
        results = []
        for r in top:
            cur.execute(
                FETCH_SECTIONS_SQL,
                {"tenant_id": tenant_id, "title": r["title"], "story_id": r["story_id"]},
            )
            clean_secs = [
                {"section": s["section"], "content_html": s["content_html"]}
                for s in cur.fetchall()
                # Defensive duplicate of the SQL filter; NULL-safe on section.
                if (s["section"] or "").lower() != "einordnung" or s["bullet_index"] == -1
            ]
            results.append({
                "story_id": r["story_id"],
                "title": r["title"],
                "score": float(r["final"]),
                "sections": clean_secs,
            })

        return results
