import os
import openai
import logging
import sys
import uuid
import traceback

# Configure root logging once, before other modules log anything.
# ``force=True`` removes any handlers installed earlier (e.g. by imported
# libraries), so no separate manual handler reset is needed. Logging goes to
# stdout so it lands in the systemd journal when run as a service.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
    force=True,
)

from flask import Flask, request, jsonify, session, send_from_directory, render_template, url_for
from flask_cors import CORS
from openai import OpenAI
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
from prompt_manager import call_openai_prompt
from app_config import get_max_phase
from flask import request
from qa_system import (
    get_next_part,
    get_recommendation,
    get_answer,
    detect_app_id_from_port,
    get_prompt_messages_by_template_name  # ✅ Add this line
)

import MySQLdb
import MySQLdb.cursors
import re

# === OpenAI client setup ===
# NOTE(review): duplicates ``openai_client`` created next to the imports. It is
# kept in case other code refers to ``client`` by name — verify and remove.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# === App setup ===
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)
# Session cookies are signed with this key, so it must be secret and stable
# across restarts. Read it from the environment. The historical built-in value
# is only a fallback so existing deployments keep working.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or 'super_secret_key'
if app.secret_key == 'super_secret_key':
    logging.warning("FLASK_SECRET_KEY is not set; using the insecure built-in session secret key.")

# === Application-level constants ===
APP_ID = 5001  # could later be dynamic from port

# Value reported by get_max_phase() for APP_ID, fetched once at import time.
# Defaults to None when the lookup fails. This keeps MAX_PHASE always defined
# instead of leaving the name unbound, which would raise NameError on use.
MAX_PHASE = None
try:
    MAX_PHASE = get_max_phase(APP_ID)
    logging.info(f"MAX PHASE setting from DB: {MAX_PHASE}")
except Exception:
    logging.exception("Initialization error for MAX_PHASE")

def extract_sound_files(sound_tags):
    """Return the filenames named by ``<sound:filename>`` tags, in order.

    Falsy input (``None`` or ``""``) yields an empty list.

    >>> extract_sound_files("<sound:calm_1.mp4><sound:pause.mp4>")
    ['calm_1.mp4', 'pause.mp4']
    """
    tag_pattern = r"<sound:([^>]+)>"
    return re.findall(tag_pattern, sound_tags) if sound_tags else []

# Record the port the request arrived on as the session's app id.
@app.before_request
def set_app_id_from_port():
    """Store the request's port number in ``session['app_id']``.

    ``request.host`` looks like ``"example.com:5001"`` or ``"[::1]:5001"``,
    so the port is the text after the *last* colon. When the host has no
    explicit port, the session is left unchanged. That happens, for example,
    when the default port is used. Such requests are logged at DEBUG level
    rather than as an error on every request.
    """
    host = request.host or ""
    _, sep, port_text = host.rpartition(":")
    if not sep or not (port_text.isascii() and port_text.isdigit()):
        logging.debug(f"No explicit port in host {host!r}; session app_id left unchanged")
        return
    port = int(port_text)
    session['app_id'] = port
    logging.info(f"Detected APP_ID from port: {port}")

# === Utility ===
def parse_instructions(phase_text):
    """Split phase text into ordered text and sound segments.

    Args:
        phase_text: Text that may contain ``<sound:filename>`` tags. ``None``
            or an empty string yields no segments.

    Returns:
        A tuple ``(segments, sound_files)``:

        - ``segments`` is a list of ``{"type": "text" | "sound", "value": str}``
          dicts in document order. Whitespace-only text between tags is dropped.
        - ``sound_files`` lists each distinct filename once, in order of first
          appearance.
    """
    if not phase_text:
        return [], []

    segments = []
    sound_files = []
    # Same tag grammar as extract_sound_files(): everything up to the closing
    # '>'. This means filenames with extensions (e.g. "calm_1.mp4") are
    # recognised instead of causing the surrounding text to be dropped.
    for piece in re.split(r'(<sound:[^>]+>)', phase_text):
        tag = re.fullmatch(r'<sound:([^>]+)>', piece)
        if tag:
            filename = tag.group(1)
            segments.append({"type": "sound", "value": filename})
            if filename not in sound_files:
                sound_files.append(filename)
        else:
            text = piece.strip()
            if text:
                segments.append({"type": "text", "value": text})

    return segments, sound_files

# === Routes ===

@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)

@app.route('/static/<path:path>')
def static_files(path):
    """Serve ``path`` from the ``static`` directory.

    NOTE(review): the app is created with ``static_folder='static'``, which
    already registers Flask's own ``/static/<path:filename>`` rule. This view
    may therefore never be reached — confirm and consider removing it.
    """
    static_dir = 'static'
    return send_from_directory(static_dir, path)


@app.route('/favicon.ico')
def favicon():
    """Serve ``static/favicon.ico`` with the Microsoft icon MIME type."""
    icon_dir = os.path.join(app.root_path, 'static')
    return send_from_directory(icon_dir, 'favicon.ico', mimetype='image/vnd.microsoft.icon')

# Prompt whose messages may contain <summary>/<rhythm_text> placeholders,
# filled from the session's stored feedback.
_SESSION_CONTEXT_PROMPT_ID = 4


def _load_prompt_messages(prompt_id):
    """Return the messages stored for ``prompt_id``, ordered by ``sequence``.

    Each row is a dict with ``sequence``, ``role`` and ``content`` keys
    (DictCursor). The cursor and connection are always closed.
    """
    db = get_db_connection()
    try:
        cur = db.cursor(MySQLdb.cursors.DictCursor)
        try:
            cur.execute("""
                SELECT sequence, role, content
                  FROM prompt_messages
                 WHERE prompt_id = %s
                 ORDER BY sequence ASC
            """, (prompt_id,))
            return [dict(row) for row in cur.fetchall()]
        finally:
            cur.close()
    finally:
        db.close()


def _replace_placeholders(messages, replacements):
    """Substitute every ``placeholder -> value`` pair into each message's content, in place.

    Pairs are applied in the mapping's insertion order. A NULL ``content`` is
    treated as an empty string.
    """
    for msg in messages:
        content = msg["content"] or ""
        for placeholder, value in replacements.items():
            content = content.replace(placeholder, value)
        msg["content"] = content


@app.route("/api/prompt/<int:prompt_id>", methods=["POST"])
def run_prompt(prompt_id):
    """Run a stored prompt template through OpenAI and return the model output.

    Request JSON:
        user_input (str, optional): substituted for ``<user_input>``.
        session_id (str, optional): for prompt 4 only, used to look up the
            session's latest summary and rhythm words. These are substituted
            for ``<summary>`` and ``<rhythm_text>``.

    Returns:
        - 200 ``{"output": str}`` on success.
        - 400 if the body is not a JSON object.
        - 404 if the prompt has no stored messages.
        - 500 ``{"error": str}`` if the OpenAI call fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    user_input = str(data.get("user_input") or "")
    session_id = str(data.get("session_id") or "")

    logging.info(f"[DEBUG] run_prompt(): prompt_id = {prompt_id}, session_id = {session_id}")

    messages = _load_prompt_messages(prompt_id)
    if not messages:
        logging.warning(f"No prompt_messages rows for prompt_id {prompt_id}")
        return jsonify({"error": f"No messages found for prompt {prompt_id}"}), 404

    # Only this prompt carries session context. The DB is only queried when a
    # placeholder is actually present.
    if prompt_id == _SESSION_CONTEXT_PROMPT_ID and session_id:
        needs_context = any(
            "<summary>" in (msg["content"] or "") or "<rhythm_text>" in (msg["content"] or "")
            for msg in messages
        )
        if needs_context:
            summary = get_latest_summary(session_id) or "[no summary found]"
            rhythm_text = get_rhythm_words(session_id) or "[no rhythm text]"
            _replace_placeholders(messages, {"<summary>": summary, "<rhythm_text>": rhythm_text})
            logging.info(f"[PROMPT {prompt_id}] Injected summary + rhythm_text → {summary} | {rhythm_text}")

    _replace_placeholders(messages, {"<user_input>": user_input})

    # The API accepts only role/content; the DB rows also carry "sequence".
    openai_messages = [{"role": m["role"], "content": m["content"]} for m in messages]

    logging.info(f"[OPENAI] Prompt ID: {prompt_id}, Session: {session_id}")
    for i, msg in enumerate(openai_messages):
        logging.info(f"[OPENAI][{i}] {msg['role']}: {msg['content']}")

    try:
        response = openai_client.chat.completions.create(
            model="gpt-4",
            messages=openai_messages,
            temperature=0.7,
        )
    except Exception as e:
        logging.exception(f"❌ OpenAI call failed for prompt_id {prompt_id}, session {session_id}")
        return jsonify({"error": str(e)}), 500

    # ``content`` can be None (e.g. refusals / tool calls), so guard before strip().
    output = (response.choices[0].message.content or "").strip()

    logging.info(f"[OPENAI] Response:\n{output}")

    return jsonify({"output": output})


@app.route('/reset-phase', methods=['POST'])
def reset_phase():
    """Clear the Flask session and restart the phase flow at phase 0.

    ``/next-phase`` tracks progress in ``session["current_phase"]``, so that
    key is reset explicitly. ``session["phase"]`` is still written in case
    other code reads the old key (NOTE(review): verify and drop if unused).
    """
    session.clear()
    session['current_phase'] = 0
    session['phase'] = 0
    return jsonify({"message": "Phase reset to 0"}), 200

def segment_instructions(instruction_text):
    """Split instruction text into ordered text/sound segments, line by line.

    Each line without a ``<sound:...>`` tag becomes one text segment holding
    the stripped line. Blank lines are kept as ``""`` segments.

    A line containing tags is split around them:
    - each tag becomes a sound segment (filename stripped);
    - surrounding non-blank text becomes text segments.

    This means several tags on one line, or text before or after a tag, are
    all preserved.

    Args:
        instruction_text: The raw instruction text. ``None`` is treated as ``""``.

    Returns:
        list[dict]: ``{"type": "text" | "sound", "value": str}`` entries.
    """
    segments = []

    for raw_line in (instruction_text or "").strip().split('\n'):
        line = raw_line.strip()
        pieces = re.split(r"(<sound:[^>]+>)", line)
        if len(pieces) == 1:
            # No tag on this line: keep it whole (including empty lines).
            segments.append({"type": "text", "value": line})
            continue

        for piece in pieces:
            tag = re.fullmatch(r"<sound:([^>]+)>", piece)
            if tag:
                segments.append({"type": "sound", "value": tag.group(1).strip()})
            elif piece.strip():
                segments.append({"type": "text", "value": piece.strip()})

    return segments


def _test_jump_phase():
    """Return the phase a phase-0 session should jump to, or None (debug aid).

    Set the TEST_JUMP_PHASE environment variable to an integer to skip
    straight to that phase from phase 0. Unset or invalid values disable the
    jump; previously it was hard-coded on for every user.
    """
    raw = os.environ.get("TEST_JUMP_PHASE", "").strip()
    if not raw:
        return None
    try:
        return int(raw)
    except ValueError:
        logging.warning(f"Ignoring non-integer TEST_JUMP_PHASE={raw!r}")
        return None


@app.route('/next-phase', methods=['POST'])
def next_phase():
    """Advance (or repeat) the session's phase and return that phase's content.

    Request JSON:
        choice (str, optional): ``"REPEAT"`` or any answer starting with
            ``"(B)"`` keeps the current phase. Anything else advances by one.

    ``session["current_phase"]`` is only updated once the target phase has
    been found. A missing phase therefore does not leave the session pointing
    past the end of the flow.

    Returns:
        200 with the phase payload, or 404 ``{"error": "Phase not found"}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    choice = data.get("choice")
    current_phase = session.get("current_phase", 0)

    logging.info(f"CURRENT PHASE in /next-phase: {current_phase}")

    repeat = isinstance(choice, str) and (choice == "REPEAT" or choice.startswith("(B)"))
    target_phase = current_phase if repeat else current_phase + 1

    jump_phase = _test_jump_phase()
    if current_phase == 0 and jump_phase is not None:
        logging.warning(f"🧪 TEST JUMP: going from phase 0 to phase {jump_phase}")
        target_phase = jump_phase

    logging.info(f"NEW PHASE in /next-phase: {target_phase}")

    app_id = detect_app_id_from_port(request.host)

    phase_data = get_next_part(app_id=app_id, phase=target_phase)
    if not phase_data:
        return jsonify({"error": "Phase not found"}), 404

    session["current_phase"] = target_phase

    # A missing key still yields the "UNKNOWN" label; a NULL column yields none.
    section_label = (phase_data.get('section_label', 'UNKNOWN') or '').strip()
    instruction_text = phase_data.get('instruction_text') or ''
    combined_text = f"{section_label}\n{instruction_text}" if section_label else instruction_text

    question_lines = [
        line for line in (phase_data.get("question_text") or "").split("\n") if line.strip()
    ]

    return jsonify({
        "app_id": app_id,
        "phase": target_phase,
        "segmentedInstructions": segment_instructions(combined_text),
        "feedbackRequest": phase_data.get("feedbackRequest_text") or "",
        "questions": [
            {"choice": line[:3].strip(), "text": line.strip()}
            for line in question_lines
        ],
        "soundFiles": extract_sound_files(combined_text),
        "ai_prompt_id": phase_data.get("ai_prompt_id"),
        "is_last_phase": (target_phase == get_max_phase(app_id)),
    })

@app.route('/store-rhythm-words', methods=['POST'])
def store_rhythm_words():
    """Persist the client's rhythm words onto the session's feedback rows.

    Request JSON:
        session_id (required).
        phase (required): phase 0 is accepted.
        rhythm_words.
        app_id (optional): defaults to the app detected from the request
            port, which is how ``/store-feedback-and-summary`` records app_id.

    Returns:
        ``{'status': 'success'}``, 400 when session_id/phase are missing, or
        500 when the database write fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    session_id = data.get('session_id')
    phase = data.get('phase')
    rhythm_words = data.get('rhythm_words')

    if not session_id or phase in (None, ""):
        return jsonify({'error': 'Missing session_id or phase'}), 400

    app_id = data.get('app_id')
    if app_id is None:
        # Without this, "WHERE app_id = NULL" would silently update nothing.
        app_id = detect_app_id_from_port(request.host)

    try:
        store_rhythm_words_in_db(session_id, app_id, rhythm_words)
    except Exception:
        logging.exception("Failed to store rhythm_words")
        return jsonify({'error': 'Failed to store rhythm_words'}), 500
    return jsonify({'status': 'success'})

@app.route('/feedback', methods=['POST'])
def handle_feedback():
    """Turn user feedback into a GPT-generated summary or recommendation.

    Request JSON:
        feedback: the user's feedback text.
        phase (int, default 1): may be sent as a numeric string.

    Returns:
        ``{"summary": ...}`` when phase is 2, otherwise ``{"recommendation": ...}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    feedback_text = data.get("feedback")
    app_id = detect_app_id_from_port(request.host)

    # Normalise so that e.g. "2" from a form/JS client selects the summary branch.
    try:
        phase = int(data.get("phase", 1))
    except (TypeError, ValueError):
        phase = 1

    logging.info(f"🧠 Getting GPT recommendation for phase {phase}, app_id {app_id}")

    result = get_chatgpt_recommendation(feedback_text, app_id, phase)

    if phase == 2:
        return jsonify({"summary": result})
    return jsonify({"recommendation": result})

def store_rhythm_words_in_db(session_id, app_id, rhythm_words):
    """Set ``rhythm_words`` on every session_feedback row for (session_id, app_id).

    The connection is closed even if the UPDATE fails. Any uncommitted work
    is discarded in that case.

    Returns:
        int: the row count MySQL reports as affected by the UPDATE.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute("""
            UPDATE session_feedback
               SET rhythm_words = %s
             WHERE session_id = %s AND app_id = %s
        """, (rhythm_words, session_id, app_id))
        affected = cur.rowcount
        cur.close()
        conn.commit()
        return affected
    finally:
        conn.close()

def get_db_connection():
    """Open a new MySQL connection to the prompts database.

    Settings come from the PROMPTS_DB_HOST, PROMPTS_DB_USER,
    PROMPTS_DB_PASSWORD and PROMPTS_DB_NAME environment variables. The
    previous built-in values are kept as fallbacks so existing deployments
    keep working.

    TODO: remove the password fallback and rotate the credential; it is
    embedded in this source file.

    The caller owns the returned connection and must close it.
    """
    return MySQLdb.connect(
        host=os.environ.get("PROMPTS_DB_HOST", "localhost"),
        user=os.environ.get("PROMPTS_DB_USER", "openai_user"),
        passwd=os.environ.get("PROMPTS_DB_PASSWORD", "IOyg76H2l%252BewRX2xhsDJAo7qnfVDHtx9RB%253D%"),
        db=os.environ.get("PROMPTS_DB_NAME", "openai_prompts"),
        charset="utf8mb4",
    )

def store_feedback(tenant_id, user_id, session_id, app_id, phase_number, feedbacktext):
    """Insert one session_feedback row for the given session and phase.

    Uses the shared get_db_connection() settings instead of a duplicated copy
    of the credentials. The connection is closed even if the INSERT fails.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO session_feedback (tenant_id, user_id, session_id, app_id, phase_number, feedbacktext)
            VALUES (%s, %s, %s, %s, %s, %s)
        """, (tenant_id, user_id, session_id, app_id, phase_number, feedbacktext))
        cursor.close()
        conn.commit()
    finally:
        conn.close()

def store_summary(session_id, phase_number, summary):
    """Set ``summary`` on the newest (highest id) session_feedback row for this session/phase.

    Uses the shared get_db_connection() settings. The connection is closed
    even if the UPDATE fails.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        # MySQL allows ORDER BY/LIMIT on single-table UPDATE: touch only the latest row.
        cursor.execute("""
            UPDATE session_feedback
            SET summary = %s
            WHERE session_id = %s AND phase_number = %s
            ORDER BY id DESC
            LIMIT 1
        """, (summary, session_id, phase_number))
        cursor.close()
        conn.commit()
    finally:
        conn.close()
    

def get_chatgpt_recommendation(feedback_text, app_id=1, phase=1):
    """Generate a GPT recommendation for ``feedback_text``.

    The phase's prompt template is looked up via get_next_part() and
    get_prompt_messages_by_template_name(). It supplies the leading messages,
    and the feedback is appended as the final user message.

    - ``app_id`` is honoured as given. The old hard-coded ``2`` ignored it,
      unlike ``/next-phase``, which uses the same port-detected id.
    - The phase comes from ``session["current_phase"]`` (maintained by
      ``/next-phase``) when present, and from ``phase`` otherwise.

    Falls back to the rule-based ``get_recommendation(feedback_text)`` in any
    of these cases:
    - the phase data is missing;
    - the template name is missing;
    - the template has no messages;
    - the OpenAI call fails (or anything else raises).
    """
    try:
        phase = session.get("current_phase", phase)
        logging.info(f"🧠 Getting GPT recommendation for phase {phase}, app_id {app_id}")

        phase_data = get_next_part(app_id, phase)
        logging.info(f"🔎 Phase data fetched: {phase_data}")
        if not phase_data:
            logging.error(f"❌ No phase data found for app_id={app_id}, phase={phase}")
            return get_recommendation(feedback_text)

        prompt_name = phase_data.get("prompt_template_name")
        if not prompt_name:
            logging.warning("⚠️ Missing prompt_template_name. Falling back to rule-based recommendation.")
            return get_recommendation(feedback_text)

        template_messages = get_prompt_messages_by_template_name(prompt_name)
        if not template_messages:
            logging.warning(f"⚠️ No messages found for prompt template '{prompt_name}'. Falling back.")
            return get_recommendation(feedback_text)

        logging.info(f"📨 Prompt messages before adding user input:\n{template_messages}")

        # Build a new list rather than appending to the helper's return value,
        # in case that list is shared or cached by the helper.
        messages = [*template_messages, {"role": "user", "content": feedback_text}]

        logging.debug(f"📡 Sending messages to OpenAI:\n{messages}")

        completion = openai_client.chat.completions.create(
            model="gpt-4",
            messages=messages,
        )

        response = (completion.choices[0].message.content or "").strip()
        logging.info("✅ GPT recommendation received.")
        return response

    except Exception:
        logging.exception("🔥 OpenAI error in get_chatgpt_recommendation")
        return get_recommendation(feedback_text)


@app.route('/phases', methods=['GET'])
def get_phase_labels():
    """Return this app's phases as ``[{"phase_number", "section_label"}, ...]``.

    The list is ordered by phase_number. The DB connection is always closed;
    previously it leaked on every request. Responds with ``[]`` and status
    500 if the lookup fails.
    """
    try:
        app_id = detect_app_id_from_port(request.host)

        conn = get_db_connection()
        try:
            cursor = conn.cursor(MySQLdb.cursors.DictCursor)
            cursor.execute("""
                SELECT phase_number, section_label
                FROM phase_content
                WHERE app_id = %s
                ORDER BY phase_number ASC
            """, (app_id,))
            results = list(cursor.fetchall())
            cursor.close()
        finally:
            conn.close()

        return jsonify(results)

    except Exception as e:
        logging.error(f"Error fetching phase labels: {e}")
        return jsonify([]), 500
 

@app.route('/store-feedback-and-summary', methods=['POST'])
def store_feedback_and_summary():
    """Store the user's feedback for a phase and attach the confirmed summary.

    Request JSON: phase, summary, feedback.

    A session_id is created in the Flask session on first use. The response
    returns it so clients can pass it to ``/api/prompt/<id>``.

    Returns:
        200 with ``{"status", "session_id"}``, 400 if the body is not a JSON
        object, or 500 if the database write fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    phase = data.get("phase")
    summary = data.get("summary")
    feedbacktext = data.get("feedback")

    logging.info(f"🟢 PHASE: {phase}")
    logging.info(f"🟢 FEEDBACKTEXT: {feedbacktext}")
    logging.info(f"🟢 SUMMARY: {summary}")

    app_id = detect_app_id_from_port(request.host)
    # Only mint a new id when the session does not have one yet.
    session_id = session.get("session_id") or str(uuid.uuid4())
    session["session_id"] = session_id

    tenant_id = session.get("tenant_id", 1)
    user_id = session.get("user_id", 1)

    logging.info(f"✅ Executing /store-feedback-and-summary: data={data}, phase={phase}, summary={summary}, feedback={feedbacktext}, app_id={app_id}, session_id={session_id}")

    if not feedbacktext:
        logging.warning("⚠️ No feedback text provided.")
    if not summary:
        logging.warning("⚠️ No summary provided.")

    try:
        store_feedback(tenant_id, user_id, session_id, app_id, phase, feedbacktext)
        store_summary(session_id, phase, summary)
    except Exception:
        logging.exception("Failed to store feedback and summary")
        return jsonify({"error": "Failed to store feedback and summary"}), 500

    return jsonify({
        "status": "summary confirmed",
        "session_id": session_id
    }), 200


def get_latest_summary(session_id: str) -> str | None:
    """Return the latest non-NULL ``summary`` stored for ``session_id``.

    "Latest" means the row with the highest phase_number, with ties broken by
    the highest id.

    Returns:
        The summary string, or None if no summary has been stored.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor(MySQLdb.cursors.DictCursor)
        cursor.execute("""
            SELECT summary
              FROM session_feedback
             WHERE session_id = %s
               AND summary IS NOT NULL
             ORDER BY phase_number DESC, id DESC
             LIMIT 1
        """, (session_id,))
        row = cursor.fetchone()
        cursor.close()
    finally:
        conn.close()

    if not row:
        logging.info(f"[DB] No summary found for session {session_id}")
        return None

    summary = row['summary']
    logging.info(f"[DB] Loaded summary for session {session_id}: {summary}")
    return summary


def get_rhythm_words(session_id: str) -> str | None:
    """Return the most recently inserted (highest id) non-NULL ``rhythm_words`` for the session.

    Returns None when nothing has been stored. The connection is closed even
    if the query fails.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT rhythm_words
              FROM session_feedback
             WHERE session_id = %s
               AND rhythm_words IS NOT NULL
             ORDER BY id DESC
             LIMIT 1
        """, (session_id,))
        row = cur.fetchone()
        cur.close()
    finally:
        conn.close()

    return row[0] if row else None


# === Main entrypoint ===
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be exposed on 0.0.0.0 by default. Opt in with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(
        host='0.0.0.0',
        port=5001,
        ssl_context=(
            '/var/www/html/decompression/certs/fullchain.pem',
            '/var/www/html/decompression/certs/privkey.pem'
        ),
        debug=debug_enabled,
    )
    
