import os
from dotenv import load_dotenv
from sqlalchemy import create_engine, event
from sqlalchemy.engine import URL
from sqlalchemy.orm import sessionmaker, Session as SASession, with_loader_criteria

# Flask request helpers: the tenant guards below read g.tenant_id during web requests
from flask import has_request_context, g

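# Import Base from the models module, which also defines the mapped classes, so those
# classes exist on Base.metadata before init_db() calls create_all()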
from models import Base, Case, Document, Prompt, Topic, User, Session as ChatSession

load_dotenv()  # load DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME from a .env file

# Connection URL assembled from environment variables. DB_PORT falls back to
# 5432 when it is unset *or* set to an empty string (e.g. "DB_PORT=" in .env);
# passing the empty string straight to int() would crash at import time.
url = URL.create(
    drivername="postgresql+psycopg2",
    username=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    host=os.getenv("DB_HOST"),
    port=int(os.getenv("DB_PORT") or 5432),
    database=os.getenv("DB_NAME"),
)

# pool_pre_ping checks a pooled connection before handing it out, which helps
# after long idle periods when the server may have dropped the connection.
engine = create_engine(url, pool_pre_ping=True)

# Session factory bound to the engine; autoflush is disabled, so pending
# changes are only flushed explicitly or on commit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def init_db():
    """Create the tables declared on ``Base.metadata`` using the module engine.

    ``create_all`` is called with its default ``checkfirst=True``, so tables
    that already exist in the database are left untouched.
    """
    print("📦 Initializing database...")
    metadata = Base.metadata
    metadata.create_all(bind=engine)


# ---------- MULTI-TENANT GUARDS (SELECT + INSERT) ----------

@event.listens_for(SASession, "do_orm_execute")
def _enforce_tenant_on_select(execute_state):
    """Limit ORM SELECTs on tenant-scoped models to the current request's tenant.

    Registered for every Session. For each ORM SELECT, a ``with_loader_criteria``
    option adding ``tenant_id = <g.tenant_id>`` is attached for every model in
    ``tenant_models`` (aliased entities included).

    The statement passes through unchanged when it is not a SELECT, when there
    is no Flask request context, or when ``g.tenant_id`` is missing or falsy.
    ORM UPDATE/DELETE statements are therefore not filtered by this hook.
    """
    skip = (
        not execute_state.is_select
        or not has_request_context()
        or not hasattr(g, "tenant_id")
    )
    if skip:
        return

    tenant_id = int(g.tenant_id or 0)
    if tenant_id == 0:
        return

    # Mapped entities that carry a tenant_id column.
    tenant_models = (Case, Document, Prompt, Topic, User, ChatSession)

    # One lambda serves every model: with_loader_criteria calls it with each
    # target class as ``cls``; ``tenant_id`` is picked up from the closure.
    criteria_options = [
        with_loader_criteria(
            model,
            lambda cls: cls.tenant_id == tenant_id,
            include_aliases=True,
        )
        for model in tenant_models
    ]
    execute_state.statement = execute_state.statement.options(*criteria_options)


@event.listens_for(SASession, "before_flush")
def _autofill_tenant_on_insert(session, flush_context, instances):
    """Stamp the current request's tenant onto pending objects that lack one.

    Runs before every flush of every Session. Only objects in ``session.new``
    are inspected. An object is updated when it has a ``tenant_id`` attribute
    whose value is ``None`` or ``0``. Nothing happens outside a Flask request
    context or when ``g.tenant_id`` is missing or falsy.

    NOTE(review): a new object that already carries a non-zero tenant_id that
    differs from g.tenant_id is kept as-is, not rejected. Confirm that this is
    intended for a tenant-isolation guard.
    """
    if not has_request_context() or not hasattr(g, "tenant_id"):
        return

    tenant_id = int(g.tenant_id or 0)
    if tenant_id == 0:
        return

    candidates = (obj for obj in session.new if hasattr(obj, "tenant_id"))
    for obj in candidates:
        if obj.tenant_id in (None, 0):
            obj.tenant_id = tenant_id



