from pathlib import Path
import json
import os
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np

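# Optional / heavy dependencies (tiktoken, openai, PyMuPDF, and the project's app/models modules)
# are imported lazily inside the functions that use them, so importing this module does not
# require them to be installed or configured.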


# ------------------------------
# Configuration
# ------------------------------

def _int_setting(name: str, default: str) -> int:
    """Read environment variable ``name`` and parse it as ``int`` (``default`` when unset)."""
    return int(os.environ.get(name, default))


EMBED_MODEL = os.environ.get("OPENAI_EMBED_MODEL", "text-embedding-3-small")  # model passed to the embeddings API
CHUNK_SIZE_TOKENS = _int_setting("RAG_CHUNK_SIZE", "500")  # max tokens per chunk
CHUNK_OVERLAP_TOKENS = _int_setting("RAG_CHUNK_OVERLAP", "80")  # tokens shared by consecutive chunks
DEFAULT_TOP_K = _int_setting("RAG_TOP_K", "3")  # default number of retrieved matches


def _get_tokenizer():
    """Return the ``cl100k_base`` tiktoken encoding.

    tiktoken is imported here rather than at module level so the module can be
    imported without it installed.
    """
    from tiktoken import get_encoding

    encoding_name = "cl100k_base"
    return get_encoding(encoding_name)


def _split_by_tokens(text: str, max_tokens: int, overlap: int) -> List[str]:
    enc = _get_tokenizer()
    tokens = enc.encode(text)
    chunks = []
    start = 0
    n = len(tokens)
    while start < n:
        end = min(start + max_tokens, n)
        chunk_tokens = tokens[start:end]
        chunks.append(enc.decode(chunk_tokens))
        if end == n:
            break
        start = end - overlap
        if start < 0:
            start = 0
    return chunks


def _cosine_sim_matrix(a, b):
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)
    b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-12)
    return a_norm @ b_norm.T


def _embed_texts(texts: List[str]) -> List[List[float]]:
    from openai import OpenAI
    client = OpenAI()
    resp = client.embeddings.create(model=EMBED_MODEL, input=texts)
    vectors = [d.embedding for d in resp.data]
    return vectors


@dataclass
class RAGChunk:
    """A single text chunk prepared for storage as a ``Document`` row.

    Instances are produced by ``make_chunks_from_pdf``; ``upsert_chunks_to_db``
    reads ``tenant_id``, ``title``, ``content`` and ``tags`` from them.
    """

    tenant_id: int  # owning tenant; copied to Document.tenant_id
    title: str  # e.g. "<base_title> :: p01 :: c01" when built by make_chunks_from_pdf
    content: str  # the chunk text that gets embedded and stored
    tags: Optional[str] = None  # appears to be a comma-separated tag string ("pdf,rag"); upsert falls back to "pdf,rag" when falsy
    embedding: Optional[List[float]] = None  # not set by make_chunks_from_pdf; upsert_chunks_to_db embeds `content` itself and does not read this field


def extract_pdf_text(pdf_path: str) -> List[Tuple[int, str]]:
    """Extract plain text from every page of the PDF at ``pdf_path``.

    Uses PyMuPDF (imported lazily as ``fitz``).

    Returns:
        ``(page_number, text)`` pairs with 1-based page numbers, in page order.
        Pages with no text, or only whitespace, are omitted. Such pages (for
        example image-only pages) would otherwise become meaningless chunks that
        still get embedded and stored.
    """
    import fitz  # PyMuPDF

    pages: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as doc:
        for page_no, page in enumerate(doc, start=1):
            text = page.get_text("text")
            if text and text.strip():
                pages.append((page_no, text))
    return pages


def make_chunks_from_pdf(pdf_path: str,
                         tenant_id: int,
                         base_title: str,
                         tags: str = "pdf,rag",
                         chunk_size_tokens: int = CHUNK_SIZE_TOKENS,
                         overlap_tokens: int = CHUNK_OVERLAP_TOKENS) -> List[RAGChunk]:
    """Turn a PDF into token-bounded ``RAGChunk`` objects (not yet embedded).

    Each page's text is split independently with ``_split_by_tokens``, so a chunk
    never spans two pages. Titles have the form
    ``"<base_title> :: pNN :: cNN"``, where ``NN`` is the 1-based page number and
    the 1-based chunk index within that page, zero-padded to two digits.
    """
    return [
        RAGChunk(
            tenant_id=tenant_id,
            title=f"{base_title} :: p{page_no:02d} :: c{chunk_no:02d}",
            content=piece,
            tags=tags,
        )
        for page_no, page_text in extract_pdf_text(pdf_path)
        for chunk_no, piece in enumerate(
            _split_by_tokens(page_text, chunk_size_tokens, overlap_tokens), start=1
        )
    ]


def _import_models():
    try:
        from app import SessionLocal  # type: ignore
    except Exception:
        SessionLocal = None
    try:
        from models import Document  # type: ignore
    except Exception:
        Document = None
    if SessionLocal is None or Document is None:
        raise ImportError(
            "Could not import SessionLocal or Document. Adjust `_import_models()` to your project layout."
        )
    return SessionLocal, Document


def upsert_chunks_to_db(chunks: List[RAGChunk], batch_size: int = 128) -> int:
    """Embed ``chunks`` and insert each one as a ``Document`` row.

    Chunks are processed in batches of ``batch_size``. Each batch is embedded with
    one ``_embed_texts`` call and committed on its own, so a failure part-way
    leaves the earlier batches committed.

    NOTE: despite the name, this only inserts (``session.add``). Existing rows are
    never looked up or updated, so re-running on the same PDF presumably creates
    duplicates unless the schema enforces uniqueness.

    Args:
        chunks: Chunks to store. Their ``embedding`` field is ignored; every chunk
            is embedded afresh.
        batch_size: Number of chunks per embedding call / commit; must be >= 1.

    Returns:
        Number of rows written.

    Raises:
        ValueError: if ``batch_size < 1``.
        RuntimeError: if the embeddings call returns a different number of vectors
            than inputs. Without this check ``zip`` would silently drop chunks and
            the returned count would be wrong.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    SessionLocal, Document = _import_models()
    session = SessionLocal()
    try:
        rows_written = 0
        for start in range(0, len(chunks), batch_size):
            batch = chunks[start:start + batch_size]
            vectors = _embed_texts([c.content for c in batch])
            if len(vectors) != len(batch):
                raise RuntimeError(
                    f"Embedding API returned {len(vectors)} vectors for {len(batch)} chunks."
                )

            for chunk, vec in zip(batch, vectors):
                session.add(Document(
                    tenant_id=chunk.tenant_id,
                    case_id=None,
                    title=chunk.title,
                    content=chunk.content,
                    embedding=json.dumps(vec),  # stored as a JSON-encoded list of floats
                    tags=chunk.tags or "pdf,rag",
                ))
            session.commit()
            rows_written += len(batch)
        return rows_written
    finally:
        session.close()


def preprocess_pdf_to_db(pdf_path: str, tenant_id: int, base_title: Optional[str] = None, tags: str = "pdf,rag") -> int:
    """Chunk, embed and store one PDF; return the number of rows written.

    ``pdf_path`` may be any path-like value (it is converted with ``str``). When
    ``base_title`` is ``None``, the file name without its extension is used as the
    title prefix of every chunk.
    """
    path_str = str(pdf_path)
    title_prefix = Path(path_str).stem if base_title is None else base_title
    return upsert_chunks_to_db(
        make_chunks_from_pdf(path_str, tenant_id=tenant_id, base_title=title_prefix, tags=tags)
    )


def _decode_embedded_docs(docs) -> List[Tuple[object, str, str, List[float]]]:
    """Return ``(id, title, content, vector)`` for every doc whose ``embedding`` is a non-empty JSON list.

    Rows whose ``embedding`` is not a string/bytes, is not valid JSON, or does not
    decode to a non-empty list are skipped rather than aborting the whole query.
    """
    rows = []
    for d in docs:
        try:
            vec = json.loads(d.embedding)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
            continue
        if isinstance(vec, list) and vec:
            rows.append((d.id, d.title, d.content, vec))
    return rows


def query_rag_context_from_db(user_query: str, tenant_id: int, top_k: int = DEFAULT_TOP_K) -> dict:
    """Retrieve the ``top_k`` stored chunks most similar to ``user_query`` for one tenant.

    Loads every ``Document`` row of ``tenant_id`` with a non-NULL ``embedding``,
    decodes the JSON-encoded vectors, embeds the query and ranks the rows by cosine
    similarity.

    Returns:
        A dict with
        * ``context``: the matches rendered as ``"[rank] title\\ncontent"`` blocks
          joined by ``"\\n\\n---\\n\\n"``;
        * ``matches``: list of dicts with ``rank``, ``document_id``, ``title``,
          ``similarity``, ``content``, best first;
        * ``top_k``: the requested value, echoed back;
        * ``warning``: present only when nothing could be ranked.

    Rows whose stored vector length differs from the query embedding's (e.g.
    written with a different ``EMBED_MODEL``) are ignored instead of crashing the
    ranking. A ``top_k`` of 0 or less yields no matches.
    """
    SessionLocal, Document = _import_models()
    session = SessionLocal()
    try:
        docs = session.query(Document).filter(
            Document.tenant_id == tenant_id,
            Document.embedding.isnot(None)
        ).all()
        n_docs = len(docs)
        # Copy the needed columns out while the session is open, so it can be
        # closed before the network-bound query-embedding call below.
        candidates = _decode_embedded_docs(docs)
    finally:
        session.close()

    if n_docs == 0:
        return {"context": "", "matches": [], "top_k": top_k, "warning": "No embedded documents for this tenant."}
    if not candidates:
        return {"context": "", "matches": [], "top_k": top_k, "warning": "No valid embeddings parsed."}

    query_vec = _embed_texts([user_query])[0]
    # Vectors of a different length cannot be compared with the query and would
    # make the stacked matrix ragged, which numpy rejects.
    dim = len(query_vec)
    candidates = [c for c in candidates if len(c[3]) == dim]
    if not candidates:
        return {"context": "", "matches": [], "top_k": top_k,
                "warning": "No stored embeddings match the query embedding dimension."}

    sims = _cosine_sim_matrix([query_vec], [c[3] for c in candidates])[0]
    # Clamp at 0: a negative slice bound would select all but the last |top_k| rows.
    top_idx = np.argsort(-sims)[:max(top_k, 0)]

    matches = []
    for rank, idx in enumerate(top_idx, start=1):
        doc_id, title, content, _ = candidates[idx]
        matches.append({
            "rank": rank,
            "document_id": int(doc_id),
            "title": title,
            "similarity": float(sims[idx]),
            "content": content,
        })

    context = "\n\n---\n\n".join(
        f"[{m['rank']}] {m['title']}\n{(m['content'] or '').strip()}" for m in matches
    )
    return {"context": context, "matches": matches, "top_k": top_k}


def preprocess_folder_pdfs_to_db(folder: str, tenant_id: int, tags: str = "pdf,rag") -> int:
    """Preprocess every PDF directly inside ``folder``; return the total rows written.

    Regular files whose suffix is ``.pdf`` in any letter case (so ``report.PDF`` is
    included) are processed in sorted path order. Subdirectories are not searched,
    and a directory that happens to be named ``*.pdf`` is ignored. Each file's stem
    is used as its ``base_title``.

    Raises:
        FileNotFoundError: if ``folder`` does not exist.
        NotADirectoryError: if ``folder`` exists but is not a directory. Without
            this check a file path would silently process nothing.
    """
    p = Path(folder)
    if not p.exists():
        raise FileNotFoundError(f"Folder not found: {folder}")
    if not p.is_dir():
        raise NotADirectoryError(f"Not a folder: {folder}")
    pdfs = sorted(f for f in p.iterdir() if f.is_file() and f.suffix.lower() == ".pdf")
    total = 0
    for pdf in pdfs:
        total += preprocess_pdf_to_db(str(pdf), tenant_id=tenant_id, base_title=pdf.stem, tags=tags)
    return total


# German-language prompt template, filled via str.format in build_rag_augmented_prompt
# with `context` (the retrieved chunks) and `question` (the user query).
# English gist: "You are a precise assistant. Use the following context from the
# House of PM best practices, if relevant, to answer the user question. If the
# context does not fit, answer without citing it. When you use it, briefly cite
# the source [title] in brackets." It ends with "ANSWER (concise, professional, German):".
# NOTE: any literal braces added to this template must be doubled ({{ }}) because
# of str.format; the substituted values themselves may contain braces freely.
RAG_PROMPT_TEMPLATE = """Du bist ein präziser Assistent. Nutze den folgenden Kontext aus den House of PM Best-Practices, falls relevant, um die Nutzerfrage zu beantworten.
Wenn der Kontext nicht passt, antworte ohne ihn zu zitieren. Zitiere kurz in Klammern die Quelle [Titel] bei Verwendung.

KONTEXT:
{context}

NUTZERFRAGE:
{question}

ANTWORT (prägnant, fachlich, deutsch):"""


def build_rag_augmented_prompt(user_query: str, rag_context: str) -> str:
    """Fill ``RAG_PROMPT_TEMPLATE`` with the retrieved context and the user's question.

    A falsy ``rag_context`` (empty string or ``None``) is replaced by a German
    "no relevant context found" placeholder, so the prompt never has an empty
    context section.
    """
    context_text = rag_context if rag_context else "(kein relevanter Kontext gefunden)"
    return RAG_PROMPT_TEMPLATE.format(question=user_query, context=context_text)


if __name__ == "__main__":
    # Small CLI: preprocess a PDF and/or a folder of PDFs, then optionally run a test query.
    import argparse

    parser = argparse.ArgumentParser(description="Preprocess PDFs into DB and/or query RAG context.")
    parser.add_argument("--tenant", type=int, default=1, help="Tenant ID")
    parser.add_argument("--pdf", type=str, help="Path to a single PDF to preprocess")
    parser.add_argument("--folder", type=str, help="Path to a folder with PDFs to preprocess")
    parser.add_argument("--query", type=str, help="Query to test retrieval")
    parser.add_argument("--top-k", type=int, default=DEFAULT_TOP_K,
                        help="Number of matches to retrieve for --query")
    args = parser.parse_args()

    # Without any action flag the script would otherwise exit silently having done nothing.
    if not (args.pdf or args.folder or args.query):
        parser.error("nothing to do: pass at least one of --pdf, --folder or --query")

    if args.pdf:
        print(f"Preprocessing PDF: {args.pdf}")
        cnt = preprocess_pdf_to_db(args.pdf, tenant_id=args.tenant)
        print(f"Stored chunks: {cnt}")

    if args.folder:
        print(f"Preprocessing folder: {args.folder}")
        cnt = preprocess_folder_pdfs_to_db(args.folder, tenant_id=args.tenant)
        print(f"Stored chunks (all PDFs): {cnt}")

    if args.query:
        print(f"Querying RAG for tenant {args.tenant}: {args.query}")
        res = query_rag_context_from_db(args.query, tenant_id=args.tenant, top_k=args.top_k)
        # Surface why nothing matched instead of printing an empty list without explanation.
        if res.get("warning"):
            print(f"Warning: {res['warning']}")
        print("Top matches:")
        for m in res.get("matches", []):
            print(f"- ({m['similarity']:.3f}) {m['title']}  id={m['document_id']}")
        print("\n---\nContext preview:\n")
        print(res.get("context", "")[:1500])
