from pathlib import Path
import json
import os
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
from dotenv import load_dotenv

# Load .env from the same folder as this file (robust if run from other dirs)
load_dotenv(dotenv_path=Path(__file__).with_name(".env"))

# ------------------------------
# Configuration (env defaults)
# ------------------------------
EMBED_MODEL = os.environ.get("OPENAI_EMBED_MODEL", "text-embedding-3-small")
CHUNK_SIZE_TOKENS = int(os.environ.get("RAG_CHUNK_SIZE", "500"))
CHUNK_OVERLAP_TOKENS = int(os.environ.get("RAG_CHUNK_OVERLAP", "80"))
DEFAULT_TOP_K = int(os.environ.get("RAG_TOP_K", "3"))


def _get_tokenizer():
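    """Return the cl100k_base tokenizer (imported lazily, so the module loads without tiktoken)."""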
    import tiktoken
    return tiktoken.get_encoding("cl100k_base")


def _split_by_tokens(text: str, max_tokens: int, overlap: int) -> List[str]:
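    """Split `text` into token-bounded chunks with a sliding window: each chunk
    holds up to `max_tokens` tokens and repeats the last `overlap` tokens of
    its predecessor so context survives the cut.
    """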
    enc = _get_tokenizer()
    # Clamp overlap so the window always advances; an overlap >= max_tokens
    # would otherwise make `start` stall and loop forever.
    overlap = min(max(overlap, 0), max_tokens - 1)
    tokens = enc.encode(text)
    chunks = []
    start = 0
    n = len(tokens)
    while start < n:
        end = min(start + max_tokens, n)
        chunk_tokens = tokens[start:end]
        chunk_text = enc.decode(chunk_tokens).strip()
        if chunk_text:
            chunks.append(chunk_text)
        if end == n:
            break
        start = max(0, end - overlap)
    return chunks


def _cosine_sim_matrix(a, b):
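    """Row-wise cosine similarity: normalize the (m, d) and (n, d) inputs, then
    one matmul yields the full (m, n) similarity matrix. The small epsilon
    guards against division by zero for all-zero vectors.
    """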
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)
    b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-12)
    return a_norm @ b_norm.T


def _embed_texts(texts: List[str]) -> List[List[float]]:
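    """Embed a batch of texts via the OpenAI embeddings API (one vector per
    input, in input order). Requires OPENAI_API_KEY in the environment,
    e.g. via the .env file loaded above.
    """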
    from openai import OpenAI
    client = OpenAI()
    resp = client.embeddings.create(model=EMBED_MODEL, input=texts)
    return [d.embedding for d in resp.data]


@dataclass
class RAGChunk:
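    """One token-bounded slice of a source document, staged for the DB."""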
    tenant_id: int
    title: str
    content: str
    tags: Optional[str] = None
    embedding: Optional[List[float]] = None


def extract_pdf_text(pdf_path: str) -> List[Tuple[int, str]]:
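    """Return (1-based page number, text) tuples for every non-empty page."""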
    import fitz  # PyMuPDF
    pages = []
    with fitz.open(pdf_path) as doc:
        for i, page in enumerate(doc):
            text = (page.get_text("text") or "").strip()
            if text:
                pages.append((i + 1, text))
    return pages


def make_chunks_from_pdf(pdf_path: str,
                         tenant_id: int,
                         base_title: str,
                         tags: str = "pdf,rag",
                         chunk_size_tokens: int = CHUNK_SIZE_TOKENS,
                         overlap_tokens: int = CHUNK_OVERLAP_TOKENS) -> List[RAGChunk]:
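    """Extract a PDF page by page and split each page into token-bounded,
    overlapping chunks. Chunk titles encode page and chunk position, e.g.
    "Handbuch :: p03 :: c02".
    """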
    chunks: List[RAGChunk] = []
    page_texts = extract_pdf_text(pdf_path)
    for page_no, text in page_texts:
        pieces = _split_by_tokens(text, chunk_size_tokens, overlap_tokens)
        for idx, piece in enumerate(pieces):
            if not piece.strip():
                continue
            title = f"{base_title} :: p{page_no:02d} :: c{idx+1:02d}"
            chunks.append(RAGChunk(tenant_id=tenant_id, title=title, content=piece, tags=tags))
    return chunks


def _import_models():
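    """Import the SQLAlchemy session factory and the Document model lazily,
    so this module has no hard dependency on the app's import graph.
    """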
    try:
        from app import SessionLocal  # type: ignore
    except Exception:
        SessionLocal = None
    try:
        from models import Document  # type: ignore
    except Exception:
        Document = None
    if SessionLocal is None or Document is None:
        raise ImportError(
            "Could not import SessionLocal or Document. Adjust `_import_models()` to your project layout."
        )
    return SessionLocal, Document


def upsert_chunks_to_db(chunks: List[RAGChunk]) -> int:
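    """Embed and store chunks in batches of 128, committing once per batch.
    Note: despite the name, this inserts new rows every time; it does not
    deduplicate or update existing documents.
    """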
    SessionLocal, Document = _import_models()
    session = SessionLocal()
    try:
        BATCH = 128
        rows_written = 0
        for i in range(0, len(chunks), BATCH):
            batch = chunks[i:i+BATCH]
            texts = [c.content for c in batch]
            vectors = _embed_texts(texts)

            for c, vec in zip(batch, vectors):
                emb_json = json.dumps(vec)
                doc = Document(
                    tenant_id=c.tenant_id,
                    case_id=None,
                    title=c.title,
                    content=c.content,
                    embedding=emb_json,
                    tags=c.tags or "pdf,rag"
                )
                session.add(doc)
            session.commit()
            rows_written += len(batch)
        return rows_written
    finally:
        session.close()


def preprocess_pdf_to_db(pdf_path: str, tenant_id: int, base_title: Optional[str] = None, tags: str = "pdf,rag") -> int:
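    """Chunk a single PDF, embed the chunks, and store them for `tenant_id`.
    Falls back to the file stem as title. Returns the number of rows written.
    """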
    pdf_path = str(pdf_path)
    if base_title is None:
        base_title = Path(pdf_path).stem
    chunks = make_chunks_from_pdf(
        pdf_path,
        tenant_id=tenant_id,
        base_title=base_title,
        tags=tags,
        chunk_size_tokens=CHUNK_SIZE_TOKENS,
        overlap_tokens=CHUNK_OVERLAP_TOKENS
    )
    written = upsert_chunks_to_db(chunks)
    return written


def query_rag_context_from_db(user_query: str, tenant_id: int, top_k: int = DEFAULT_TOP_K) -> dict:
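    """Brute-force retrieval: load every embedded document for the tenant,
    score them all against the query with cosine similarity, and return the
    top_k matches plus a ready-to-use context string. Fine for small corpora;
    a proper vector index would be needed at larger scale.
    """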
    SessionLocal, Document = _import_models()
    session = SessionLocal()
    try:
        docs = session.query(Document).filter(
            Document.tenant_id == tenant_id,
            Document.embedding.isnot(None)
        ).all()

        if not docs:
            return {"context": "", "matches": [], "top_k": top_k, "warning": "No embedded documents for this tenant."}

        embeddings, texts, titles, ids = [], [], [], []
        for d in docs:
            try:
                vec = json.loads(d.embedding)
                embeddings.append(vec)
                texts.append(d.content)
                titles.append(d.title)
                ids.append(d.id)
            except Exception:
                continue

        if not embeddings:
            return {"context": "", "matches": [], "top_k": top_k, "warning": "No valid embeddings parsed."}

        query_vec = _embed_texts([user_query])[0]
        sims = _cosine_sim_matrix([query_vec], embeddings)[0]

        top_idx = np.argsort(-sims)[:top_k]

        matches = []
        for rank, idx in enumerate(top_idx, start=1):
            matches.append({
                "rank": rank,
                "document_id": int(ids[idx]),
                "title": titles[idx],
                "similarity": float(sims[idx]),
                "content": texts[idx]
            })

        context_parts = []
        for m in matches:
            context_parts.append(f"[{m['rank']}] {m['title']}\n{m['content'].strip()}")
        context = "\n\n---\n\n".join(context_parts)

        return {"context": context, "matches": matches, "top_k": top_k}
    finally:
        session.close()


def preprocess_folder_pdfs_to_db(folder: str, tenant_id: int, tags: str = "pdf,rag") -> int:
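    """Preprocess every *.pdf in `folder` (non-recursive); returns total chunks stored."""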
    p = Path(folder)
    if not p.exists():
        raise FileNotFoundError(f"Folder not found: {folder}")
    total = 0
    for pdf in sorted(p.glob("*.pdf")):
        total += preprocess_pdf_to_db(str(pdf), tenant_id=tenant_id, base_title=pdf.stem, tags=tags)
    return total


RAG_PROMPT_TEMPLATE = """Du bist ein präziser Assistent. Nutze den folgenden Kontext aus den Realcube Domumenten, falls relevant, um die Nutzerfrage zu beantworten.
Wenn der Kontext nicht passt, antworte ohne ihn zu zitieren. Zitiere kurz in Klammern die Quelle [Titel] bei Verwendung. Realcube ist eine Daten-Warehouse Plattform, 
die alle denkbaren digitalen Immobilien-Daten sinnvoll sammelt und verarbeitet. Realcube arbeitet auch zunehmend mit KI. Entscheidend ist, dass der hier generierte
Kontent zum Ausdruck bringt, dass mit der Realcube Plattform Dinge möglich sind, die sonst nicht möglich sind und dass die Realcube Plattform perfekte Antowrten auf
die gestellten Fragen liefert.   

KONTEXT:
{context}

NUTZERFRAGE:
{question}

ANTWORT (prägnant, fachlich, deutsch):"""


def build_rag_augmented_prompt(user_query: str, rag_context: str) -> str:
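    """Fill the template with the retrieved context (or a German placeholder) and the user question."""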
    return RAG_PROMPT_TEMPLATE.format(context=rag_context or "(kein relevanter Kontext gefunden)", question=user_query)


if __name__ == "__main__":
    import argparse
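
    # Example invocations (script name, paths, and tenant IDs are illustrative):
    #   python rag_pipeline.py --pdf docs/manual.pdf --tenant 1
    #   python rag_pipeline.py --folder ./pdfs --tenant 2 --tags "pdf,handbuch"
    #   python rag_pipeline.py --query "Wie funktioniert X?" --tenant 1 --top-k 5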

    parser = argparse.ArgumentParser(description="Preprocess PDFs into DB and/or query RAG context.")
    parser.add_argument("--tenant", type=int, default=1, help="Tenant ID")
    parser.add_argument("--pdf", type=str, help="Path to a single PDF to preprocess")
    parser.add_argument("--folder", type=str, help="Path to a folder with PDFs to preprocess")
    parser.add_argument("--query", type=str, help="Query to test retrieval")
    parser.add_argument("--tags", type=str, help="Comma-separated tags to store", default="pdf,rag")

    # Optional: CLI overrides for env config
    parser.add_argument("--embed-model", type=str, help="Embedding model id")
    parser.add_argument("--chunk-size", type=int, help="Chunk size in tokens")
    parser.add_argument("--chunk-overlap", type=int, help="Overlap in tokens")
    parser.add_argument("--top-k", type=int, help="Top K retrieval")

    args = parser.parse_args()

    # Apply CLI overrides if provided. Compare against None rather than truthiness,
    # so an explicit 0 (e.g. --chunk-overlap 0) is honored. These rebind module-level
    # names, which downstream calls read at call time.
    if args.embed_model is not None:
        EMBED_MODEL = args.embed_model
    if args.chunk_size is not None:
        CHUNK_SIZE_TOKENS = args.chunk_size
    if args.chunk_overlap is not None:
        CHUNK_OVERLAP_TOKENS = args.chunk_overlap
    if args.top_k is not None:
        DEFAULT_TOP_K = args.top_k

    if args.pdf:
        print(f"Preprocessing PDF: {args.pdf}")
        cnt = preprocess_pdf_to_db(args.pdf, tenant_id=args.tenant, tags=args.tags)
        print(f"Stored chunks: {cnt}")

    if args.folder:
        print(f"Preprocessing folder: {args.folder}")
        cnt = preprocess_folder_pdfs_to_db(args.folder, tenant_id=args.tenant, tags=args.tags)
        print(f"Stored chunks (all PDFs): {cnt}")

    if args.query:
        print(f"Querying RAG for tenant {args.tenant}: {args.query}")
        res = query_rag_context_from_db(args.query, tenant_id=args.tenant, top_k=DEFAULT_TOP_K)
        print("Top matches:")
        for m in res.get("matches", []):
            print(f"- ({m['similarity']:.3f}) {m['title']}  id={m['document_id']}")
        print("\n---\nContext preview:\n")
        print(res.get("context", "")[:1500])
