from openai import OpenAI
import openai
import json
import time
import re
from pinecone import Pinecone
from datetime import datetime
import conexiones
import json
import a_env_vars
import mysql.connector
from mysql.connector import Error
import sys
import requests
import traceback
import random
import tiktoken
from dotenv import load_dotenv
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from clases import ChatResponse
import subprocess
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading

class ConfigManager:
    """Process-wide singleton holding the settings loaded from the ``.env`` file.

    The ``.env`` path comes from the ``ENV_PATH`` environment variable (with a
    hard-coded default). Every name in ``REQUIRED_KEYS`` becomes an attribute
    of the instance; all of them except ``SERVER_ENV`` must be non-empty.

    Raises:
        FileNotFoundError: if the ``.env`` file does not exist.
        ValueError: if any required key (other than ``SERVER_ENV``) is missing/empty.
    """

    _instance = None
    REQUIRED_KEYS = [
        'OPENAI_API_KEY',
        'PINECONE_API_KEY',
        'OPENAI_API_VALOR_MAS',
        'SERVER_ENV',
        'XI_API_KEY',
        'AUDIO_URL',
        'RUTA_AUDIO'
    ]

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the singleton only after loading succeeded. Assigning it
            # first would cache a half-configured object when _load_config
            # raises, and every later ConfigManager() would silently return it.
            instance._load_config()
            cls._instance = instance
        return cls._instance

    def _load_config(self):
        """Load ``.env``, copy REQUIRED_KEYS onto ``self`` and validate them."""
        # Path to the .env file (overridable through ENV_PATH).
        env_path = os.getenv('ENV_PATH', '/var/www/dev.catia.catastroantioquia-mas.com/valormas/.env')

        if not os.path.exists(env_path):
            raise FileNotFoundError(f"Archivo .env no encontrado en {env_path}")

        # Load environment variables from the .env file.
        load_dotenv(env_path)

        # Read every required key; SERVER_ENV is allowed to be empty.
        missing_keys = []
        for key in self.REQUIRED_KEYS:
            value = os.getenv(key)
            if not value and key != 'SERVER_ENV':
                missing_keys.append(key)
            setattr(self, key, value)

        if missing_keys:
            raise ValueError(
                f"Claves faltantes en .env: {', '.join(missing_keys)}\n"
                f"Archivo .env: {env_path}\n"
                "Por favor verifica las claves en el archivo .env."
            )

        # Normalize URL/path prefixes so file names can be appended directly.
        if self.AUDIO_URL and not self.AUDIO_URL.endswith('/'):
            self.AUDIO_URL += '/'
        if self.RUTA_AUDIO and not self.RUTA_AUDIO.endswith('/'):
            self.RUTA_AUDIO += '/'

# Create the single configuration instance; abort start-up if it cannot load.
try:
    config = ConfigManager()
except Exception as e:
    print(f"Error crítico de configuración: {e}", file=sys.stderr)
    raise SystemExit(1) from e

# Make sure /usr/bin is searchable (e.g. for the gcc check). Tolerate an unset
# PATH (the original `+=` raised KeyError) and avoid appending a duplicate entry.
_path_actual = os.environ.get('PATH', '')
if '/usr/bin' not in _path_actual.split(os.pathsep):
    os.environ['PATH'] = f"{_path_actual}{os.pathsep}/usr/bin" if _path_actual else '/usr/bin'

# OPTIMIZACIÓN 1: Pool de modelos para evitar recargas
class ModelManager:
    """Process-wide singleton that loads the local Llama model once and runs generations.

    The first instantiation loads the tokenizer and model from ``model_dir``,
    runs a short warm-up generation and creates ``executor``, a 4-worker thread
    pool for running ``generate_sync`` off the event loop. Later instantiations
    return the same object.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking. The new object is published to
        # ``cls._instance`` only after ``_initialize`` returns, so:
        #  * a thread taking the unlocked fast path never receives an object
        #    whose model is still being loaded, and
        #  * if loading raises, nothing is cached and a later call can retry.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialize()
                    cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Set up device, worker pool and model (runs once per singleton)."""
        self.model_dir = "/data/llama-model-base-crm"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="llama-inference")
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model in float16, switch to eval mode and warm up.

        Raises:
            Exception: whatever the loading code raises (logged, then re-raised).
        """
        try:
            logging.info("Cargando modelo Llama...")

            # Use the fast (Rust) tokenizer when one is available.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                use_fast=True
            )

            # Some causal LMs ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                load_in_4bit=False,
                load_in_8bit=False,
                use_cache=True,  # keep the KV cache on for faster decoding
            )

            # Drop any quantization settings stored in the checkpoint config.
            if hasattr(self.model.config, "quantization_config"):
                self.model.config.quantization_config = None

            self.model.eval()

            # Run one tiny generation so the first real request is not slowed
            # down by one-time initialization work.
            self._warmup()

            logging.info(f"Modelo cargado exitosamente en {self.device}")

        except Exception as e:
            logging.error(f"Error cargando modelo: {e}")
            raise

    def _warmup(self):
        """Run a 5-token greedy generation; failures are only logged as warnings."""
        try:
            dummy_prompt = "### Instruction:\nHola\n\n### Input:\nTest\n\n### Response:\n"
            inputs = self.tokenizer([dummy_prompt], return_tensors="pt").to(self.device)

            with torch.no_grad():
                _ = self.model.generate(
                    **inputs,
                    max_new_tokens=5,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Release cached GPU memory allocated by the warm-up.
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            logging.info("Modelo pre-calentado correctamente")
        except Exception as e:
            logging.warning(f"Error en pre-calentamiento: {e}")

    def generate_sync(self, prompt_text):
        """Generate a sampled completion for *prompt_text* (blocking).

        Returns:
            dict with ``response_text`` (decoded prompt + completion, special
            tokens skipped), ``input_tokens``, ``output_tokens`` (newly generated
            tokens only) and ``inference_time_ms``.

        Raises:
            Exception: any tokenizer/model error (logged, then re-raised).
        """
        try:
            with torch.no_grad():
                inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.device)
                input_tokens = inputs['input_ids'].shape[1]

                start_time = time.time()

                # `early_stopping` was dropped: it only affects beam search and
                # with sampling (num_beams=1) it merely triggers a warning.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    min_new_tokens=10,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                    repetition_penalty=1.1
                )

                end_time = time.time()

                # outputs[0] holds prompt + completion; callers parse the
                # completion out of the full text.
                response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                output_tokens = outputs.shape[1] - input_tokens
                inference_time_ms = int((end_time - start_time) * 1000)

                # Release cached GPU memory after each request.
                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()

                return {
                    'response_text': response_text,
                    'input_tokens': input_tokens,
                    'output_tokens': output_tokens,
                    'inference_time_ms': inference_time_ms
                }

        except Exception as e:
            logging.error(f"Error en generación: {e}")
            raise

# OPTIMIZACIÓN 2: Sistema de colas para controlar concurrencia
class RequestQueue:
    """Async admission control for expensive requests.

    At most ``max_concurrent`` awaitables run at once (semaphore), and at most
    ``max_queue_size`` may be admitted in total (running + waiting). Requests
    beyond that limit are rejected immediately.
    """

    def __init__(self, max_concurrent=6, max_queue_size=30):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.queue_size = 0
        self.max_queue_size = max_queue_size
        self._lock = asyncio.Lock()

    async def process_request(self, coro):
        """Await *coro* under the concurrency limit and return its result.

        Raises:
            RuntimeError: if ``max_queue_size`` requests are already admitted.
                The rejected coroutine is closed so it is not leaked.
        """
        async with self._lock:
            if self.queue_size >= self.max_queue_size:
                # The coroutine will never be awaited; close it so its frame is
                # released and Python does not warn "coroutine ... was never awaited".
                cerrar = getattr(coro, "close", None)
                if cerrar is not None:
                    cerrar()
                raise RuntimeError("Cola de requests llena. Intenta más tarde.")
            self.queue_size += 1

        try:
            async with self.semaphore:
                return await coro
        finally:
            async with self._lock:
                self.queue_size -= 1

# Check if gcc is installed
def check_gcc():
    """Verify that ``gcc --version`` runs successfully.

    Raises:
        SystemExit: (code 1) when gcc is not found on PATH or exits with an error.
    """
    try:
        subprocess.run(["gcc", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, OSError):
        # A missing executable raises FileNotFoundError (an OSError), not
        # CalledProcessError; the original let that case escape unhandled.
        logging.error("gcc compiler is not available at /usr/bin/gcc. Please install gcc.")
        raise SystemExit(1)
    logging.info("gcc compiler is installed and available at /usr/bin/gcc.")

# Check GCC during app startup. This is best effort: check_gcc() signals failure
# with SystemExit, which is caught explicitly here (instead of a bare except)
# so start-up continues with a warning while KeyboardInterrupt still propagates.
try:
    check_gcc()
except (SystemExit, Exception):
    logging.warning("GCC no disponible, continuando sin verificación")

# OPTIMIZATION 3: module-level shared instances.
# On success, `model_manager` and `request_queue` hold the instances and the
# fallback globals below (`model_dir`, `tokenizer`, `device`, `model`) are NOT
# defined at all.
# NOTE(review): `request_queue` does not appear to be used anywhere in this
# file — confirm whether requests were meant to go through it.
try:
    model_manager = ModelManager()
    request_queue = RequestQueue()
    logging.info("Sistema inicializado correctamente")
except Exception as e:
    logging.error(f"Error en inicialización: {e}")
    # Fallback: load the model directly into module globals when the
    # ModelManager path fails. `model_manager` / `request_queue` are set to None
    # so callers can detect this mode (generar_respuesta checks `model_manager`).
    # NOTE(review): this reloads the same model directory that just failed; if it
    # fails again the exception is not caught here and aborts the module import.
    # Unlike ModelManager, no pad token is configured and no warm-up is done.
    model_dir = "/data/llama-model-base-crm"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        load_in_4bit=False,
        load_in_8bit=False
    )
    # Drop any quantization settings stored in the checkpoint config.
    if hasattr(model.config, "quantization_config"):
        model.config.quantization_config = None
    model.eval()
    model_manager = None
    request_queue = None

# OPTIMIZACIÓN 4: Pool de conexiones DB
class DBPool:
    """Factory for MySQL connections built from the settings in ``conexiones``.

    Despite the name, no pooling happens: every call opens a fresh connection,
    which the caller must close.
    """

    _pools = {}  # NOTE(review): currently unused
    _lock = threading.Lock()  # NOTE(review): currently unused

    # Options applied on top of the per-database settings from `conexiones`.
    _OPCIONES_CONEXION = {
        'autocommit': True,
        'use_unicode': True,
        'charset': 'utf8mb4',
        'connect_timeout': 10,
        'read_timeout': 30,
        'write_timeout': 30,
    }

    @classmethod
    def get_connection(cls, config_name='config'):
        """Open a connection using ``conexiones.config`` (for ``'config'``) or
        ``conexiones.config2`` (any other name).

        Returns:
            The new connection, or ``None`` if MySQL reported an error (logged).
        """
        config_data = conexiones.config if config_name == 'config' else conexiones.config2

        # Merge into a single dict instead of passing `**config_data` together
        # with explicit keywords: if the settings dict already defined one of
        # these keys (e.g. 'charset'), that call raised "TypeError: got multiple
        # values for keyword argument", which the except below did not catch.
        parametros = {**config_data, **cls._OPCIONES_CONEXION}

        try:
            return mysql.connector.connect(**parametros)
        except mysql.connector.Error as err:
            logging.error(f"Error conexión DB: {err}")
            return None

# función para conectar a la base de datos en local - OPTIMIZADA
def conectar_db():
    """Open a new connection with the primary settings (``conexiones.config``).

    Returns the connection, or ``None`` when it could not be established.
    """
    nueva_conexion = DBPool.get_connection(config_name='config')
    return nueva_conexion

def conectar_db2():
    """Open a new connection with the secondary settings (``conexiones.config2``).

    Returns the connection, or ``None`` when it could not be established.
    """
    nueva_conexion = DBPool.get_connection(config_name='config2')
    return nueva_conexion

# Función para crear un hilo de trabajo - OPTIMIZADA
def crear_hilo(prompt: str, idcliente: int, id_asistente: str):
    """Insert a new row into ``conversaciones`` and assign it a thread id.

    The thread id is ``f"{row_id}{id_asistente}{idcliente}"``, where ``row_id``
    is the auto-increment id of the inserted row.

    Returns:
        The thread id on success, or an error message string when the
        connection or an SQL statement fails.
    """
    connection = conectar_db2()

    if connection is None:
        return "Error al conectar con la base de datos."

    # Pre-bind so the finally block is safe even if cursor() itself raises
    # (the original hit UnboundLocalError there, masking the real error).
    cursor = None
    try:
        cursor = connection.cursor()
        modelo = "Llama CRM"
        llegada = "WEB"

        insert_query = """
        INSERT INTO conversaciones (id_asistente, modelo, llegada, fecha_hora, pregunta, usuario)
        VALUES (%s, %s, %s, %s, %s, %s)
        """
        cursor.execute(insert_query, (id_asistente, modelo, llegada, datetime.now(), prompt, idcliente))

        last_inserted_id = cursor.lastrowid

        # NOTE(review): plain concatenation can be ambiguous (e.g. id 1 + "23"
        # vs id 12 + "3"); the format is kept because callers depend on it.
        hilo_conversacion = f"{last_inserted_id}{id_asistente}{idcliente}"

        update_query = """
        UPDATE conversaciones 
        SET hilo_conversacion = %s 
        WHERE id = %s
        """
        cursor.execute(update_query, (hilo_conversacion, last_inserted_id))

        return hilo_conversacion

    except mysql.connector.Error as err:
        logging.error(f"Error creando hilo: {err}")
        return f"Error al insertar en la base de datos: {err}"

    finally:
        if connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()

# OPTIMIZADA            
def actualizar_respuesta(hilo_conversacion, respuesta, input_tokens, output_tokens):
    """Store the answer and token counts on the conversation rows of a thread.

    Returns:
        ``True`` on success, or an error message string when the connection or
        the UPDATE fails.
    """
    connection = conectar_db2()

    if connection is None:
        return "Error al conectar con la base de datos."

    # Pre-bind so the finally block is safe even if cursor() itself raises.
    cursor = None
    try:
        cursor = connection.cursor()

        query = """
        UPDATE conversaciones
        SET respuesta = %s, input = %s, output = %s
        WHERE hilo_conversacion = %s
        """
        cursor.execute(query, (respuesta, input_tokens, output_tokens, hilo_conversacion))
        return True

    except mysql.connector.Error as err:
        logging.error(f"Error actualizando respuesta: {err}")
        return f"Error al actualizar la respuesta en la base de datos: {err}"

    finally:
        if connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()

def limpiar_texto(texto: str) -> str:
    """Fix a fixed set of mis-encoded accented characters and collapse whitespace.

    Each two-character key below is replaced with its single accented letter;
    then every run of whitespace (including newlines) becomes a single space
    and the result is stripped.
    """
    # The keys look like the UTF-8 bytes of each letter decoded as Windows-1252
    # (e.g. "á" = C3 A1 -> "Ã¡").
    # NOTE(review): the key mapped to "Í" appears to contain the U+FFFD
    # replacement character; other letters whose second UTF-8 byte is undefined
    # in Windows-1252 (e.g. "Á") may produce the same sequence — confirm.
    reemplazos = {
        "Ã¡": "á", "Ã©": "é", "Ã­": "í", "Ã³": "ó", "Ãº": "ú",
        "Ã±": "ñ", "Ã‘": "Ñ", "Ã¼": "ü", "Ã‰": "É", "Ãš": "Ú",
        "Ã“": "Ó", "Ã�": "Í", "Ã€": "À"
    }

    # Replace the mis-encoded sequences.
    for incorrecto, correcto in reemplazos.items():
        texto = texto.replace(incorrecto, correcto)

    # Collapse line breaks and repeated whitespace into single spaces.
    texto = re.sub(r"\s+", " ", texto).strip()

    return texto

# OPTIMIZADA - Con timeout y mejor manejo de errores
def obtener_configuracion_voz_y_generar_audio(id_asistente, texto):
    """Synthesize *texto* with the assistant's ElevenLabs voice and save it as MP3.

    Reads the voice settings of assistant ``id_asistente`` from the database,
    sends the cleaned text to the ElevenLabs text-to-speech endpoint and writes
    the audio under ``config.RUTA_AUDIO``.

    Returns:
        ``None`` if no ElevenLabs API key is configured; the public URL
        (``config.AUDIO_URL`` + file name) on success; otherwise a dict with an
        ``"error"`` key describing the failure.
    """
    import uuid

    if not config.XI_API_KEY:
        return None

    texto_limpio = limpiar_texto(texto)
    connection = conectar_db()

    if connection is None:
        return {"error": "Error al conectar con la base de datos."}

    query_voz = """
            SELECT 
                asistentes.id_voz,
                voces.nombre_voz,
                voces.id_eleven,
                asistentes_voces.similarity_boost,
                asistentes_voces.stability,
                asistentes_voces.style,
                asistentes_voces.use_speaker_boost
            FROM asistentes
            INNER JOIN asistentes_voces ON asistentes.id_voz = asistentes_voces.id
            INNER JOIN voces ON asistentes_voces.id_modelo_voz = voces.id
            WHERE asistentes.id = %s
        """

    # Phase 1: read the voice settings, then release the DB connection right
    # away so it is not held open during the (up to 30 s) HTTP call below.
    cursor = None  # pre-bound so the finally block is safe if cursor() raises
    try:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(query_voz, (id_asistente,))
        asistente = cursor.fetchone()
    except Exception as e:
        logging.error(f"Error generando audio: {e}")
        return {"error": f"Error generando audio: {str(e)}"}
    finally:
        if connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()

    if not asistente:
        return {"error": "No se encontró un asistente con el ID proporcionado."}

    # Phase 2: call ElevenLabs and store the audio file.
    try:
        payload = {
            "model_id": "eleven_multilingual_v2",
            "text": texto_limpio,
            "voice_settings": {
                "similarity_boost": float(asistente['similarity_boost']),
                "stability": float(asistente['stability']),
                "style": float(asistente['style']),
                "use_speaker_boost": bool(asistente['use_speaker_boost'])
            }
        }

        api_url = f"https://api.elevenlabs.io/v1/text-to-speech/{asistente['id_eleven']}?optimize_streaming_latency=0&output_format=mp3_44100_128"

        headers = {
            "Content-Type": "application/json",
            "xi-api-key": config.XI_API_KEY
        }

        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            timeout=30
        )

        if response.status_code != 200:
            logging.error(f"Error ElevenLabs: {response.status_code} - {response.text}")
            return {
                "success": False,
                "error": f"Error en ElevenLabs: {response.status_code} - {response.text}"
            }

        # A random UUID avoids the collisions a 5-digit random number allowed,
        # which could silently overwrite an earlier user's audio file.
        nombre_archivo = f"audio_{uuid.uuid4().hex}_{id_asistente}.mp3"
        ruta_audio = f"{config.RUTA_AUDIO}{nombre_archivo}"
        audio_url = f"{config.AUDIO_URL}{nombre_archivo}"

        os.makedirs(os.path.dirname(ruta_audio), exist_ok=True)

        with open(ruta_audio, "wb") as audio_file:
            audio_file.write(response.content)

        return audio_url

    except Exception as e:
        logging.error(f"Error generando audio: {e}")
        return {"error": f"Error generando audio: {str(e)}"}

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected early.
_tareas_en_segundo_plano = set()


def _finalizar_tarea_en_segundo_plano(tarea):
    """Drop the finished task's reference and log any exception it raised."""
    _tareas_en_segundo_plano.discard(tarea)
    if not tarea.cancelled() and tarea.exception() is not None:
        logging.error(f"Error en tarea en segundo plano: {tarea.exception()}")


def _generar_con_modelo_base(prompt_text):
    """Generate with the fallback module globals ``tokenizer``/``model``/``device``.

    Those globals exist only when ModelManager initialization failed. Returns
    the same dict shape as ``ModelManager.generate_sync``.
    """
    inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
    input_tokens = inputs['input_ids'].shape[1]

    start_time = time.time()
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    end_time = time.time()

    return {
        'response_text': tokenizer.decode(outputs[0], skip_special_tokens=True),
        'input_tokens': input_tokens,
        # generate() returns prompt + new tokens. Count only the new ones, as
        # ModelManager.generate_sync does; the original counted the prompt twice
        # in total_tokens.
        'output_tokens': outputs.shape[1] - input_tokens,
        'inference_time_ms': int((end_time - start_time) * 1000),
    }


# MAIN FUNCTION - keeps the original interface with internal improvements.
async def generar_respuesta(hilo_conversacion: str, prompt: str, idcliente: int, id_asistente: str, volume_up: str):
    """Answer *prompt* with the local model and record it in the conversation thread.

    Args:
        hilo_conversacion: existing thread id, or ``"nada"`` to create a new one.
        prompt: user question inserted into the instruction template.
        idcliente: user id stored with a newly created thread.
        id_asistente: assistant id (thread creation and voice lookup).
        volume_up: ``"si"`` to also synthesize the answer as audio.

    Returns:
        ChatResponse with ``Exito="Exito"`` and the extracted answer, or
        ``Exito="Error"`` when the answer cannot be extracted or anything fails.
    """
    prompt_text = (
        f"### Instruction:\n"
        f"Eres Catia, un asistente amable que responde de forma breve, clara y precisa a preguntas sobre trámites de catastro antioquia.\n\n"
        f"### Input:\n{prompt}\n\n"
        f"### Response:\n"
    )

    try:
        loop = asyncio.get_running_loop()

        # Both paths run in worker threads. The original fallback called
        # model.generate() directly and blocked the event loop.
        if model_manager and hasattr(model_manager, 'generate_sync'):
            generation_result = await loop.run_in_executor(
                model_manager.executor,
                model_manager.generate_sync,
                prompt_text
            )
        else:
            generation_result = await loop.run_in_executor(
                None, _generar_con_modelo_base, prompt_text
            )

        response_text = generation_result['response_text']
        input_tokens = generation_result['input_tokens']
        output_tokens = generation_result['output_tokens']
        inference_time_ms = generation_result['inference_time_ms']
        total_tokens = input_tokens + output_tokens

        # Create the thread if needed (blocking DB work runs off the event loop).
        hilo_conversacion = await loop.run_in_executor(
            None, verificar, hilo_conversacion, prompt, idcliente, id_asistente
        )

        # The decoded text contains the prompt too; keep what follows it.
        match = re.search(r'### Response:\n(.*?)(### Instruction:|$)', response_text, re.DOTALL)

        if not match:
            return ChatResponse(
                Exito="Error",
                Respuesta="No se pudo extraer la respuesta del modelo.",
                Tarea_Creada=hilo_conversacion,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
                inference_time_ms=inference_time_ms,
                thread_id=hilo_conversacion,
            )

        respuesta_limpia = match.group(1).strip()

        # Persist the answer in the background without delaying the reply; keep
        # a reference so the task is not collected before it finishes.
        tarea = asyncio.create_task(
            asyncio.to_thread(actualizar_respuesta, hilo_conversacion, respuesta_limpia, input_tokens, output_tokens)
        )
        _tareas_en_segundo_plano.add(tarea)
        tarea.add_done_callback(_finalizar_tarea_en_segundo_plano)

        ruta_audio = None
        if volume_up == "si":
            ruta_audio = await loop.run_in_executor(
                None, obtener_configuracion_voz_y_generar_audio, id_asistente, respuesta_limpia
            )

        return ChatResponse(
            Exito="Exito",
            Respuesta=respuesta_limpia,
            Tarea_Creada=hilo_conversacion,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            inference_time_ms=inference_time_ms,
            thread_id=hilo_conversacion,
            ruta_audio=ruta_audio,
        )

    except Exception as e:
        logging.error(f"Error en generar_respuesta: {e}")
        return ChatResponse(
            Exito="Error",
            Respuesta=f"Error inesperado: {str(e)}",
            Tarea_Creada=hilo_conversacion,
            input_tokens=0,
            output_tokens=0,
            total_tokens=0,
            inference_time_ms=0,
            thread_id=hilo_conversacion,
        )

def verificar(hilo_conversacion, prompt, idcliente, id_asistente):
    """Return the given thread id, creating a new thread when it is ``"nada"``.

    ``"nada"`` is the sentinel meaning "no thread yet"; in that case the value
    returned by ``crear_hilo`` (a new thread id or an error message) is returned.
    """
    if hilo_conversacion != "nada":
        return hilo_conversacion
    return crear_hilo(prompt, idcliente, id_asistente)