import os
import openai
import logging
import sys
import uuid
import traceback

# Configure root logging once, before other modules log anything.
# `force=True` (Python 3.8+) removes any handlers installed earlier, and
# records are written to stdout so they end up in the systemd journal.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
    force=True,
)

from flask import Flask, request, jsonify, session, send_from_directory, render_template, url_for
from flask_cors import CORS
from openai import OpenAI
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))  # shared OpenAI client (key from the environment); used by run_prompt
from prompt_manager import call_openai_prompt
from app_config import get_max_phase
from flask import request
from qa_system import (
    get_next_part,
    get_recommendation,
    get_answer,
    detect_app_id_from_port,
    get_prompt_messages_by_template_name  # ✅ Add this line
)

import MySQLdb
import MySQLdb.cursors
import re

# === OpenAI client setup ===
# Alias the shared `openai_client` instead of constructing a second client;
# the `client` name is kept for any code that still refers to it.
client = openai_client

# Feature flag intended to gate "forcePhase" jumps on /next-phase.
# Enabled unless ALLOW_FORCE_PHASE is set to something other than 1/true/True.
ALLOW_FORCE_PHASE = os.getenv("ALLOW_FORCE_PHASE", "1") in ("1", "true", "True")

# === App setup ===
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)
# Flask signs the session cookie with this key. Set FLASK_SECRET_KEY in the
# environment: the literal fallback is committed to source control and is only
# kept so existing deployments keep working.
app.secret_key = os.getenv("FLASK_SECRET_KEY", 'super_secret_key')
if app.secret_key == 'super_secret_key':
    logging.warning("FLASK_SECRET_KEY is not set; using the insecure built-in default")

# === Application-level constants ===
APP_ID = 3  # 5006  # could later be dynamic from port

# Bind MAX_PHASE first so a failed lookup leaves it as None rather than
# leaving the module without the name at all.
MAX_PHASE = None
try:
    MAX_PHASE = get_max_phase(APP_ID)
    logging.info(f"MAX PHASE setting from DB: {MAX_PHASE}")
except Exception as e:
    logging.exception(f"Initialization error for MAX_PHASE: {e}")

def extract_sound_files(sound_tags):
    """
    Return the filenames referenced by ``<sound:filename>`` tags, in order.

    Falsy input (None, "") yields an empty list.
    Example: "<sound:calm_1.mp4><sound:pause.mp4>" → ['calm_1.mp4', 'pause.mp4']
    """
    tag_pattern = re.compile(r"<sound:([^>]+)>")
    return tag_pattern.findall(sound_tags) if sound_tags else []

# Resolve the app_id for every request from the Host header (host:port).
@app.before_request
def set_app_id_from_port():
    """Store the app_id detected from ``request.host`` in the Flask session.

    Runs before every request. Detection failures are logged and otherwise
    ignored, so the request continues without an updated ``session['app_id']``.
    """
    try:
        detected_id = detect_app_id_from_port(request.host)
        session['app_id'] = detected_id
        logging.info(f"Detected APP_ID from host: {detected_id}")
    except Exception as exc:
        logging.error(f"Could not detect app_id from host: {exc}")

# === Utility ===
def parse_instructions(phase_text):
    segments = re.split(r'(<sound:[\w-]+>)', phase_text)
    sanitized_segments = []
    sound_files = set()

    for segment in segments:
        if "<sound:" in segment:
            sound_marker = re.search(r'<sound:([\w-]+)>', segment)
            if sound_marker:
                filename = sound_marker.group(1)
                sanitized_segments.append({"type": "sound", "value": filename})
                sound_files.add(filename)
        else:
            text_content = segment.strip()
            if text_content:
                sanitized_segments.append({"type": "text", "value": text_content})

    return sanitized_segments, list(sound_files)

# === Routes ===

@app.route('/')
def index():
    """Render ``index.html`` from the app's template folder for the site root."""
    page = 'index.html'
    return render_template(page)

@app.route('/static/<path:path>')
def static_files(path):
    """Serve the file at ``path`` from the ``static`` directory."""
    # NOTE(review): the Flask app is created with static_folder='static', which
    # already registers a /static/ route — confirm this handler is still needed.
    static_dir = 'static'
    return send_from_directory(static_dir, path)


@app.route('/favicon.ico')
def favicon():
    """Serve ``favicon.ico`` from ``<app root>/static`` with the icon MIME type."""
    icon_dir = os.path.join(app.root_path, 'static')
    icon_mime = 'image/vnd.microsoft.icon'
    return send_from_directory(icon_dir, 'favicon.ico', mimetype=icon_mime)

@app.route('/session', methods=['GET'])
def get_or_create_session():
    """Return ``{"session_id": ...}``, minting and storing a UUID4 if none is set."""
    if not session.get('session_id'):
        session['session_id'] = str(uuid.uuid4())
    return jsonify({'session_id': session['session_id']})

@app.route("/api/prompt/<int:prompt_id>", methods=["POST"])
def run_prompt(prompt_id):
    """Run the stored prompt ``prompt_id`` through OpenAI and return the reply.

    The prompt's messages are read from ``prompt_messages`` (ordered by
    ``sequence``). Placeholders are then filled in:

    * prompt 4 with a ``session_id``: ``<summary>`` gets the session's latest
      stored summary, ``<rhythm_text>`` / ``<rhythm_words>`` its rhythm words
      (readable fallbacks when nothing is stored);
    * every prompt: ``<user_input>`` gets the request's ``user_input``.

    Request JSON: ``user_input`` (str, default ""), ``session_id`` (str).

    Returns:
        ``{"output": <reply text>}``, or ``({"error": ...}, 500)`` when the
        OpenAI call fails.
    """
    # `or {}` so an empty/null JSON body does not crash on `.get`.
    data = request.get_json() or {}
    # `or ""` also covers explicit JSON nulls, which str.replace cannot take.
    user_input = data.get("user_input") or ""
    session_id = data.get("session_id") or ""

    logging.info(f"[DEBUG] run_prompt(): prompt_id = {prompt_id}, session_id = {session_id}")

    db = get_db_connection()
    try:
        cur = db.cursor(MySQLdb.cursors.DictCursor)
        cur.execute("""
            SELECT sequence, role, content
              FROM prompt_messages
             WHERE prompt_id = %s
             ORDER BY sequence ASC
        """, (prompt_id,))
        messages = cur.fetchall()
    finally:
        # Close even when the query fails, so connections are not leaked.
        db.close()

    # A NULL content column would break the substring checks below.
    for msg in messages:
        msg["content"] = msg["content"] or ""

    # Special case: prompt 4 gets the session's summary and rhythm words.
    if prompt_id == 4 and session_id:
        logging.info(f"[DEBUG] Injecting summary + rhythm for prompt {prompt_id}, session {session_id}")
        app_id = session.get("app_id") or detect_app_id_from_port(request.host)

        placeholders = ("<summary>", "<rhythm_text>", "<rhythm_words>")
        if any(p in msg["content"] for msg in messages for p in placeholders):
            summary = get_latest_summary(session_id) or "[no summary found]"
            rhythm = get_rhythm_words(session_id, app_id) or "[no rhythm words]"
            logging.info(f"🔁 Injecting summary='{summary[:120]}' | rhythm='{rhythm}'")

            # Applied in this fixed order for every message.
            replacements = (
                ("<summary>", summary),
                ("<rhythm_text>", rhythm),
                ("<rhythm_words>", rhythm),
            )
            for msg in messages:
                for placeholder, value in replacements:
                    msg["content"] = msg["content"].replace(placeholder, value)

            logging.info(f"[PROMPT {prompt_id}] Injection complete.")

    # Inject the user's input wherever the placeholder appears.
    for msg in messages:
        msg["content"] = msg["content"].replace("<user_input>", user_input)

    openai_messages = [{"role": m["role"], "content": m["content"]} for m in messages]

    # Log the exact request sent to OpenAI (once).
    logging.info(f"[OPENAI] Prompt ID: {prompt_id}, Session: {session_id}")
    for i, msg in enumerate(openai_messages):
        logging.info(f"[OPENAI][{i}] {msg['role']}: {msg['content']}")

    try:
        response = openai_client.chat.completions.create(
            model="gpt-4",
            messages=openai_messages,
            temperature=0.7
        )
    except Exception as e:
        logging.exception(f"❌ OpenAI call failed for prompt_id {prompt_id}, session {session_id}")
        return jsonify({"error": str(e)}), 500

    # message.content is Optional in the OpenAI SDK's response model.
    output = (response.choices[0].message.content or "").strip()

    logging.info(f"[OPENAI] Response:\n{output}")

    return jsonify({"output": output})


@app.route('/reset-phase', methods=['POST'])
def reset_phase():
    """Clear the whole Flask session and restart the flow at phase 0.

    ``session.clear()`` also drops ``session_id`` and ``app_id``.
    ``current_phase`` is the key /next-phase reads, so it is reset explicitly.
    ``phase`` is still written for backward compatibility with any other
    reader of that key.
    """
    session.clear()
    session['current_phase'] = 0
    session['phase'] = 0
    return jsonify({"message": "Phase reset to 0"}), 200

def segment_instructions(instruction_text):
    """Split instruction text into ordered text/sound segments, line by line.

    For each line of the (outer-stripped) text:
      * a blank line yields an empty text segment ``{"type": "text", "value": ""}``;
      * each ``<sound:NAME>`` tag yields ``{"type": "sound", "value": NAME}``,
        with NAME stripped of surrounding whitespace;
      * any other text on the line yields a stripped text segment.

    Several tags on one line, or tags mixed with text, produce separate
    segments. The tag pattern (anything up to ``>``) matches
    ``extract_sound_files``, so both see the same sound files.

    Args:
        instruction_text: Newline-separated instructions; None is treated as "".

    Returns:
        A list of segment dicts in document order.
    """
    segments = []
    for raw_line in (instruction_text or "").strip().split('\n'):
        line = raw_line.strip()
        if not line:
            segments.append({"type": "text", "value": ""})
            continue
        for part in re.split(r"(<sound:[^>]+>)", line):
            sound_match = re.fullmatch(r"<sound:([^>]+)>", part)
            if sound_match:
                segments.append({
                    "type": "sound",
                    "value": sound_match.group(1).strip()
                })
            elif part.strip():
                segments.append({
                    "type": "text",
                    "value": part.strip()
                })

    return segments


@app.route('/next-phase', methods=['POST'])
def next_phase():
    """Move the session to its next phase and return that phase's content.

    Request JSON (all optional):
        choice: ``"REPEAT"`` or a string starting with ``"(B)"`` repeats the
            current phase; anything else advances by one.
        forcePhase: jump straight to this phase. It is honored only when
            ``ALLOW_FORCE_PHASE`` is enabled and ``1 <= forcePhase <= max phase``.

    The chosen phase is stored in ``session["current_phase"]``.

    Returns:
        JSON with the phase's segmented instructions, questions, sound files,
        ``ai_prompt_id``, ``is_last_phase`` and ``honored_force``; or
        ``({"error": "Phase not found"}, 404)``.
    """
    payload = request.get_json(silent=True) or {}
    app.logger.info("next-phase payload: %r", payload)

    choice = payload.get("choice")
    # forcePhase may arrive as a string or number; anything unparsable is ignored.
    force_phase = payload.get("forcePhase")
    try:
        force_phase = int(force_phase) if force_phase is not None else None
    except (TypeError, ValueError):
        force_phase = None

    current_phase = session.get("current_phase", 0)
    app_id = detect_app_id_from_port(request.host)
    max_phase = get_max_phase(app_id)

    app.logger.info(f"CURRENT PHASE in /next-phase: {current_phase, app_id}")

    # Forced jumps are gated by the ALLOW_FORCE_PHASE environment flag.
    honored_force = bool(
        ALLOW_FORCE_PHASE
        and force_phase is not None
        and 1 <= force_phase <= max_phase
    )
    if honored_force:
        target_phase = force_phase
    elif choice == "REPEAT" or (isinstance(choice, str) and choice.startswith("(B)")):
        target_phase = current_phase  # repeat same
    else:
        target_phase = current_phase + 1

    app.logger.info(
        "next-phase decision: current=%s choice=%s force=%s (valid=%s, max=%s) → target=%s",
        current_phase, choice, force_phase, honored_force, max_phase, target_phase
    )

    # persist phase
    session["current_phase"] = target_phase

    phase_data = get_next_part(app_id=app_id, phase=target_phase)
    if not phase_data:
        return jsonify({"error": "Phase not found"}), 404

    # NULL columns are treated as empty text instead of crashing .strip()/.split().
    section_label = (phase_data.get('section_label', 'UNKNOWN') or '').strip()
    instruction_text = phase_data.get('instruction_text') or ''
    combined_text = f"{section_label}\n{instruction_text}" if section_label else instruction_text
    question_text = phase_data.get("question_text") or ''

    return jsonify({
        "app_id": app_id,
        "phase": target_phase,
        "segmentedInstructions": segment_instructions(combined_text),
        "feedbackRequest": phase_data.get("feedbackRequest_text") or "",
        "questions": [
            {"choice": line[:3].strip(), "text": line.strip()}
            for line in question_text.split("\n") if line.strip()
        ],
        "soundFiles": extract_sound_files(combined_text),
        "ai_prompt_id": phase_data["ai_prompt_id"],
        "is_last_phase": (target_phase == max_phase),
        # lets the client confirm the jump was honored
        "honored_force": honored_force
    })

@app.route('/store-rhythm-words', methods=['POST'])
def store_rhythm_words():
    """Persist the caller's rhythm words for the current session/app.

    Request JSON: ``rhythm_words`` (required, non-empty after strip) and
    ``phase`` (required). Optional: ``session_id`` and ``app_id``; each falls
    back to the Flask session, and app_id finally to host detection. A new
    session id is minted and stored in the session if none is available.

    Returns:
        ``{"status": "success", "rows_affected": n, "session_id": ...}``;
        ``({"error": ...}, 400)`` for missing input; or
        ``({"error": ...}, 500)`` when the database write fails.
    """
    data = request.get_json() or {}
    session_id = data.get('session_id') or session.get('session_id')
    app_id = data.get('app_id') or session.get('app_id') or detect_app_id_from_port(request.host)
    phase = data.get('phase')
    rhythm_words = (data.get('rhythm_words') or '').strip()

    if not session_id:
        # Neither the body nor the session had one: mint and remember one so
        # subsequent calls are consistent.
        session_id = str(uuid.uuid4())
        session["session_id"] = session_id

    if not phase:
        return jsonify({'error': 'Missing phase'}), 400
    if not rhythm_words:
        return jsonify({'error': 'Missing rhythm_words'}), 400

    try:
        rows = store_rhythm_words_in_db(session_id, app_id, phase, rhythm_words)
    except Exception:
        logging.exception("Failed to store rhythm_words")
        return jsonify({'error': 'Failed to store rhythm_words'}), 500

    logging.info(f"[DB] rhythm_words upsert session={session_id} app={app_id} phase={phase} rows={rows} value='{rhythm_words}'")
    # Debug read-back of the same session_id that was written. It is
    # best-effort: a failure here must not turn a successful store into a 500.
    try:
        logging.info(f"[DB] rhythm_words now = {get_rhythm_words(session_id, app_id)} for session {session_id}")
    except Exception:
        logging.exception("Failed to read back rhythm_words")

    return jsonify({'status': 'success', 'rows_affected': rows, 'session_id': session_id})


@app.route('/feedback', methods=['POST'])
def handle_feedback():
    """Generate the model's reply to user feedback for one phase.

    Request JSON: ``feedback`` (text), ``phase`` (int, default 1) and an
    optional ``session_id``. The session id falls back to the Flask session,
    else a new UUID is minted.

    Side effects:
      * stores session_id and app_id in the Flask session;
      * phase 6 with non-empty feedback: saves the feedback as rhythm_words;
      * phase 2 with a model result: inserts a feedback row and stores the
        result as its summary.
    Database failures in these steps are logged, not raised.

    Returns:
        ``{"summary": ...}`` for phase 2, ``{"recommendation": ...}``
        otherwise, or ``({"error": "Invalid phase"}, 400)`` when phase is not
        an integer.
    """
    data = request.get_json() or {}
    feedback_text = (data.get("feedback") or "").strip()
    app_id = detect_app_id_from_port(request.host)
    try:
        phase = int(data.get("phase", 1))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid phase"}), 400

    # Use the already-parsed body rather than re-reading request.json.
    session_id = data.get("session_id") or session.get("session_id") or str(uuid.uuid4())
    session["session_id"] = session_id
    # Remember app_id for later lookups.
    session["app_id"] = app_id

    logging.info(f"🧠 /feedback phase={phase} app_id={app_id} session_id={session_id}")

    # Phase 6: persist rhythm_words directly (backend is source of truth).
    if phase == 6 and feedback_text:
        try:
            rows = store_rhythm_words_in_db(session_id, app_id, phase, feedback_text)
            logging.info(f"[DB] upsert rhythm_words rows={rows} session={session_id} app={app_id} phase={phase} value='{feedback_text}'")
        except Exception:
            logging.exception("Failed to store rhythm_words")

    # Ask the model (placeholder injection happens inside the helper).
    result = get_chatgpt_recommendation(feedback_text, app_id, phase)

    # Phase 2: persist feedback + summary server-side (don't rely on a follow-up client call).
    if phase == 2 and result:
        try:
            store_feedback(tenant_id=session.get("tenant_id", 1),
                           user_id=session.get("user_id", 1),
                           session_id=session_id,
                           app_id=app_id,
                           phase_number=phase,
                           feedbacktext=feedback_text)
            store_summary(session_id, app_id, phase, result)
            logging.info(f"[DB] stored summary for session={session_id} phase={phase}")
        except Exception:
            logging.exception("Failed to store summary/feedback")

    # Response shape: phase 2 returns a summary, all other phases a recommendation.
    response_key = "summary" if phase == 2 else "recommendation"
    return jsonify({response_key: result})


def store_rhythm_words_in_db(session_id, app_id, phase_number, rhythm_words):
    """Save rhythm_words on the newest session_feedback row for (session_id, app_id).

    If that session/app pair has no row yet, a seed row is inserted with
    ``phase_number``. Its tenant_id and user_id come from the Flask session,
    defaulting to 1.

    Returns:
        The cursor rowcount of the UPDATE or INSERT that ran. An UPDATE that
        writes an identical value reports 0.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        # Find the target row explicitly. MySQL's UPDATE rowcount counts
        # *changed* rows by default (no CLIENT_FOUND_ROWS), so relying on it
        # would insert a duplicate seed row whenever the words were unchanged.
        cur.execute("""
            SELECT id
              FROM session_feedback
             WHERE session_id = %s AND app_id = %s
             ORDER BY id DESC
             LIMIT 1
        """, (session_id, app_id))
        latest = cur.fetchone()
        if latest:
            cur.execute("""
                UPDATE session_feedback
                   SET rhythm_words = %s
                 WHERE id = %s
            """, (rhythm_words, latest[0]))
        else:
            # Seed if nothing exists yet
            cur.execute("""
                INSERT INTO session_feedback
                    (tenant_id, user_id, session_id, app_id, phase_number, rhythm_words)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (session.get("tenant_id", 1), session.get("user_id", 1),
                  session_id, app_id, phase_number, rhythm_words))
        rows = cur.rowcount
        conn.commit()
        cur.close()
        return rows
    finally:
        # Always release the connection; uncommitted work is discarded on close.
        conn.close()


def get_db_connection():
    """Open a new MySQL connection to the prompts database.

    Settings can be overridden with the DB_HOST, DB_USER, DB_PASSWORD and
    DB_NAME environment variables. When they are unset, the previous built-in
    values are used, so existing deployments keep working. The caller owns the
    returned connection and must close it.
    """
    # NOTE(review): the fallback password is committed to source control. Set
    # DB_PASSWORD in the environment, rotate the credential, then remove the literal.
    return MySQLdb.connect(
        host=os.getenv("DB_HOST", "localhost"),
        user=os.getenv("DB_USER", "openai_user"),
        passwd=os.getenv("DB_PASSWORD", "IOyg76H2l%252BewRX2xhsDJAo7qnfVDHtx9RB%253D%"),
        db=os.getenv("DB_NAME", "openai_prompts"),
        charset="utf8mb4"
    )

def store_feedback(tenant_id, user_id, session_id, app_id, phase_number, feedbacktext):
    """Insert one session_feedback row holding the user's feedback text.

    ``tenant_id``, ``user_id``, ``app_id`` and ``phase_number`` are coerced with
    ``int()``, and ``session_id`` with ``str()``, before binding.

    Raises:
        ValueError / TypeError: if a numeric argument is not int-convertible
            (raised before any connection is opened).
        MySQLdb errors from the INSERT; the connection is still closed.
    """
    sql = (
        "INSERT INTO session_feedback "
        "(tenant_id, user_id, session_id, app_id, phase_number, feedbacktext) "
        "VALUES (%(tenant_id)s, %(user_id)s, %(session_id)s, %(app_id)s, %(phase_number)s, %(feedbacktext)s)"
    )
    params = {
        "tenant_id": int(tenant_id),
        "user_id": int(user_id),
        "session_id": str(session_id),
        "app_id": int(app_id),
        "phase_number": int(phase_number),
        "feedbacktext": feedbacktext,
    }

    # Optional debug if it ever breaks again
    logging.info("store_feedback params types=%s",
                 {k: type(v).__name__ for k, v in params.items()})

    # Same connection settings as everywhere else, via the shared helper.
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(sql, params)
        conn.commit()
        cursor.close()
    finally:
        conn.close()


def store_summary(session_id, app_id, phase_number, summary):
    """Write ``summary`` onto the newest session_feedback row for the session/app/phase.

    Nothing is inserted. If no matching row exists, the UPDATE affects no rows.
    The connection is closed even if the UPDATE fails.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            UPDATE session_feedback
               SET summary = %s
             WHERE session_id = %s
               AND app_id     = %s
               AND phase_number = %s
             ORDER BY id DESC
             LIMIT 1
        """, (summary, session_id, app_id, phase_number))
        conn.commit()
        cursor.close()
    finally:
        conn.close()


def get_chatgpt_recommendation(feedback_text, app_id, phase):
    """Ask OpenAI for a reply to ``feedback_text`` using the phase's prompt template.

    Steps:
      1. look up the phase (``get_next_part``) for its ``prompt_template_name``;
      2. load that template's messages;
      3. replace ``<summary>``, ``<rhythm_text>`` and ``<rhythm_words>`` with
         the values stored for the current Flask session;
      4. append the feedback as the final user message and call gpt-4.

    A missing phase, template or message list, or any exception, falls back
    to ``get_recommendation(feedback_text)``.

    Args:
        feedback_text: Text the user submitted.
        app_id: App whose phase content and rhythm words are used. When None,
            the session's app_id (or the one detected from the host) is used
            for the rhythm lookup.
        phase: Phase number whose prompt template is used.

    Returns:
        The model's reply (stripped), or the fallback recommendation.
    """
    try:
        logging.info(f"🧠 Getting GPT recommendation for phase {phase}, app_id {app_id}")

        phase_data = get_next_part(app_id, phase)
        if not phase_data:
            logging.error(f"❌ No phase data for app_id={app_id}, phase={phase}")
            return get_recommendation(feedback_text)

        prompt_name = phase_data.get("prompt_template_name")
        if not prompt_name:
            logging.warning("⚠️ No prompt_template_name. Falling back.")
            return get_recommendation(feedback_text)

        template = get_prompt_messages_by_template_name(prompt_name) or []
        if not template:
            logging.warning(f"⚠️ No messages for template '{prompt_name}'. Falling back.")
            return get_recommendation(feedback_text)

        # Work on copies so substitutions and the appended user turn never
        # mutate the list/dicts returned by get_prompt_messages_by_template_name.
        messages = [dict(m) for m in template]

        placeholders = ("<summary>", "<rhythm_text>", "<rhythm_words>")
        found = [
            p for p in placeholders
            if any(isinstance(m.get("content"), str) and p in m["content"] for m in messages)
        ]
        if found:
            # Honor the caller's app_id; fall back to the session/host only when missing.
            if app_id is None:
                app_id = session.get("app_id") or detect_app_id_from_port(request.host)
            sess_id = session.get("session_id")
            summary = get_latest_summary(sess_id) if sess_id else None
            rhythm = get_rhythm_words(sess_id, app_id) if sess_id else None
            # Readable fallbacks when nothing is stored yet.
            replacements = {
                "<summary>": summary or "[no summary found]",
                "<rhythm_text>": rhythm or "[no rhythm words]",
                "<rhythm_words>": rhythm or "[no rhythm words]",
            }
            logging.info(f"🔁 Injecting placeholders (sess={sess_id}) {found}")
            for m in messages:
                content = m.get("content")
                if isinstance(content, str):
                    for p in placeholders:
                        content = content.replace(p, replacements[p])
                    m["content"] = content

        # Append the user's feedback after substitution so placeholder-like
        # text typed by the user is sent verbatim.
        messages.append({"role": "user", "content": feedback_text})

        # Log the full prompt payload (post-injection).
        logging.info(f"📤 OpenAI call with model=gpt-4, messages={messages}")

        # Reuse the module-level client instead of constructing one per call.
        completion = openai_client.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0.7
        )
        # message.content is Optional in the OpenAI SDK's response model.
        result = (completion.choices[0].message.content or "").strip()
        logging.info(f"📥 OpenAI response (truncated to 200 chars): {result[:200]}")
        return result

    except Exception as e:
        # Deliberate best-effort: any failure degrades to the non-templated recommendation.
        logging.error(f"🔥 OpenAI error: {e}", exc_info=True)
        return get_recommendation(feedback_text)


@app.route('/phases', methods=['GET'])
def get_phase_labels():
    """Return ``[{"phase_number": ..., "section_label": ...}, ...]`` from phase_content.

    Rows are ordered by phase_number. On any error the exception is logged
    and ``([], 500)`` is returned. The connection is always closed.
    """
    # NOTE(review): uses the fixed APP_ID (3), not detect_app_id_from_port(request.host);
    # confirm whether per-host detection is wanted here.
    app_id = APP_ID
    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(MySQLdb.cursors.DictCursor)
        cursor.execute("""
            SELECT phase_number, section_label
            FROM phase_content
            WHERE app_id = %s
            ORDER BY phase_number ASC
        """, (app_id,))
        results = cursor.fetchall()
        cursor.close()
        return jsonify(results)

    except Exception as e:
        logging.exception(f"Error fetching phase labels: {e}")
        return jsonify([]), 500
    finally:
        if conn is not None:
            conn.close()
 

@app.route('/store-feedback-and-summary', methods=['POST'])
def store_feedback_and_summary():
    """Save feedback text and/or a summary on the session's newest feedback row.

    Request JSON: ``phase`` (int, default 0), ``summary`` and ``feedback``.
    An empty string leaves the stored column unchanged (NULLIF/COALESCE).
    When the session/app pair has no row yet, a seed row is inserted.

    Returns:
        ``({"status": "ok", "rows_affected": n, "session_id": ...}, 200)``,
        or ``({"error": "Invalid phase"}, 400)`` for a non-integer phase.
    """
    data = request.get_json() or {}
    try:
        phase = int(data.get("phase", 0))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid phase"}), 400
    summary = (data.get("summary") or "").strip()
    feedbacktext = (data.get("feedback") or "").strip()

    app_id = detect_app_id_from_port(request.host)
    session_id = session.get("session_id") or str(uuid.uuid4())
    session["session_id"] = session_id
    session["app_id"] = app_id

    logging.info(f"🟢 /store-feedback-and-summary sess={session_id} app={app_id} phase={phase}")

    conn = get_db_connection()
    try:
        cur = conn.cursor()
        # Find the target row explicitly. MySQL's UPDATE rowcount counts
        # *changed* rows by default, so relying on it would insert a duplicate
        # seed row whenever the update leaves both columns unchanged.
        cur.execute("""
            SELECT id
              FROM session_feedback
             WHERE session_id = %s AND app_id = %s
             ORDER BY id DESC
             LIMIT 1
        """, (session_id, app_id))
        latest = cur.fetchone()

        if latest:
            cur.execute("""
                UPDATE session_feedback
                   SET feedbacktext = COALESCE(NULLIF(%s,''), feedbacktext),
                       summary      = COALESCE(NULLIF(%s,''), summary)
                 WHERE id = %s
            """, (feedbacktext, summary, latest[0]))
        else:
            # If none exists yet, INSERT one seed row
            cur.execute("""
                INSERT INTO session_feedback
                    (tenant_id, user_id, session_id, app_id, phase_number, feedbacktext, summary)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
            """, (session.get("tenant_id", 1), session.get("user_id", 1),
                  session_id, app_id, phase, feedbacktext, summary))
        rows = cur.rowcount

        conn.commit()
        cur.close()
    finally:
        conn.close()

    return jsonify({"status": "ok", "rows_affected": rows, "session_id": session_id}), 200


def get_latest_summary(session_id: str) -> str | None:
    """
    Fetch the most recent non-NULL ``summary`` for ``session_id`` from
    ``session_feedback``. The highest phase_number wins, then the highest id.
    Returns the summary string, or None if no row has one.
    """
    # Shared connection helper, with a DictCursor for name-based column access.
    conn = get_db_connection()
    try:
        cursor = conn.cursor(MySQLdb.cursors.DictCursor)
        cursor.execute("""
            SELECT summary
              FROM session_feedback
             WHERE session_id = %s
               AND summary IS NOT NULL
             ORDER BY phase_number DESC, id DESC
             LIMIT 1
        """, (session_id,))

        row = cursor.fetchone()
        cursor.close()
        if row:
            summary = row['summary']
            logging.info(f"[DB] Loaded summary for session {session_id}: {summary}")
            return summary
        logging.info(f"[DB] No summary found for session {session_id}")
        return None
    finally:
        conn.close()


def get_rhythm_words(session_id: str, app_id: int | None = None) -> str | None:
    """Return the newest non-NULL rhythm_words stored for ``session_id``.

    When ``app_id`` is given, only rows for that app are considered.
    Returns None when no such row exists. The connection is closed even if
    the query fails.
    """
    # Build one parameterized query instead of two near-identical copies.
    sql = (
        "SELECT rhythm_words FROM session_feedback "
        "WHERE session_id = %s AND rhythm_words IS NOT NULL"
    )
    params = [session_id]
    if app_id is not None:
        sql += " AND app_id = %s"
        params.append(app_id)
    sql += " ORDER BY id DESC LIMIT 1"

    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute(sql, tuple(params))
        row = cur.fetchone()
        cur.close()
    finally:
        conn.close()
    return row[0] if row else None


# Log the registered routes once at import time, through the configured
# logging handlers rather than bare print().
logging.info("==== URL MAP ====\n%s\n==== END MAP ====", app.url_map)

# === Main entrypoint ===
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # and this server listens on all interfaces. Debug mode is therefore
    # opt-in via FLASK_DEBUG=1 (or "true"/"True") instead of always on.
    app.run(
        host='0.0.0.0',
        port=5006,
        ssl_context=(
            '/var/www/html/decompression/certs/fullchain.pem',
            '/var/www/html/decompression/certs/privkey.pem'
        ),
        debug=os.getenv("FLASK_DEBUG", "0") in ("1", "true", "True"),
    )
    
