import os
from openai import OpenAI
from dotenv import load_dotenv
import logging
import time
from datetime import datetime, timezone
from typing import Dict, Any, Optional

# Load environment variables
load_dotenv()

# Setup logging
logger = logging.getLogger(__name__)


class AIService:
    def __init__(self):
        """Initialize OpenAI client with API key from environment"""
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY")
        )
    
    def _log_ai_usage(
        self,
        company_id: Optional[int],
        service_type: str,
        feature: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        total_tokens: int,
        prompt_length: int,
        completion_length: int,
        cost: float,
        status: str,
        execution_time: float,
        error_message: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        response_data: Optional[Dict[str, Any]] = None
    ):
        """
        Log AI API usage to ai_usage_logs collection
        
        Args:
            company_id: Company ID (if available)
            service_type: 'openai', 'claude', etc.
            feature: 'resume_parsing', 'offer_letter', 'job_description', 'screening', etc.
            model: 'gpt-4', 'gpt-3.5-turbo', etc.
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens
            total_tokens: Total tokens used
            prompt_length: Length of prompt in characters
            completion_length: Length of completion in characters
            cost: Calculated cost based on pricing
            status: 'success' or 'failed'
            execution_time: Time taken in milliseconds
            error_message: Error message if failed
            request_data: Request details (optional, truncated for storage)
            response_data: Response details (optional, truncated for storage)
        """
        try:
            from services.database import get_database
            db = get_database()
            ai_usage_logs_collection = db['ai_usage_logs']
            
            # Prepare log data
            log_data = {
                "company_id": company_id,
                "service_type": service_type,
                "feature": feature,
                "model": model,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": total_tokens,
                "prompt_length": prompt_length,
                "completion_length": completion_length,
                "cost": round(cost, 6),  # Round to 6 decimal places
                "status": status,
                "execution_time": round(execution_time, 2),  # Milliseconds
                "created_at": datetime.utcnow()
            }
            
            # Add optional fields
            if error_message:
                log_data["error_message"] = error_message[:500]  # Limit error message length
            
            if request_data:
                # Store truncated request data
                log_data["request_data"] = {
                    k: (v[:200] if isinstance(v, str) else v) 
                    for k, v in request_data.items()
                }
            
            if response_data:
                # Store truncated response data
                log_data["response_data"] = {
                    k: (v[:200] if isinstance(v, str) else v)
                    for k, v in response_data.items()
                }
            
            # Insert log
            ai_usage_logs_collection.insert_one(log_data)
            logger.debug(f"Logged AI usage: {feature} - {total_tokens} tokens - ${cost}")
            
        except Exception as e:
            logger.error(f"Error logging AI usage: {str(e)}")
            # Never fail the main request because usage logging failed
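
    # Illustrative helper (not part of the original service): one way to query
    # the logs written above, assuming the same pymongo database handle used in
    # _log_ai_usage. The method name and return shape are sketches, not an
    # established API.
    def get_usage_summary(self, company_id: Optional[int] = None) -> list:
        """Sketch: aggregate call count, tokens, and cost per feature."""
        from services.database import get_database
        db = get_database()
        pipeline = [
            # Filter to one company when an ID is given; otherwise match all
            {"$match": {"company_id": company_id} if company_id is not None else {}},
            {"$group": {
                "_id": "$feature",
                "calls": {"$sum": 1},
                "total_tokens": {"$sum": "$total_tokens"},
                "total_cost": {"$sum": "$cost"},
            }},
            {"$sort": {"total_cost": -1}},
        ]
        return list(db["ai_usage_logs"].aggregate(pipeline))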
    
    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """
        Calculate cost based on model and token usage
        Prices as of 2024 (update as needed)
        
        Args:
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens
            
        Returns:
            Cost in USD
        """
        # Pricing per 1K tokens (update as needed)
        pricing = {
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
        }
        
        # Default to gpt-3.5-turbo pricing if model not found
        model_pricing = pricing.get(model, pricing["gpt-3.5-turbo"])
        
        input_cost = (input_tokens / 1000) * model_pricing["input"]
        output_cost = (output_tokens / 1000) * model_pricing["output"]
        
        return input_cost + output_cost
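        # Worked example: 1,200 input + 400 output tokens on gpt-3.5-turbo:
        # (1200/1000)*0.0005 + (400/1000)*0.0015 = 0.0006 + 0.0006 = $0.0012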
    
    def ai_response(
        self, 
        question: str, 
        system_prompt: str = "", 
        max_tokens: int = 1000,
        company_id: Optional[int] = None,
        feature: str = "general"
    ) -> str:
        """
        Generate AI response with usage logging
        
        Args:
            question: User question/prompt
            system_prompt: System prompt for context
            max_tokens: Maximum tokens for response
            company_id: Company ID for tracking (optional)
            feature: Feature name for tracking (e.g., 'screening', 'offer_letter')
            
        Returns:
            AI generated response text
        """
        start_time = time.time()
        model = os.getenv("OPENAI_API_MODEL", "gpt-3.5-turbo")
        
        # Base system prompt
        base_system_prompt = "You are a helpful assistant."
        
        # Combine system prompts
        if system_prompt:
            final_system_prompt = base_system_prompt + " " + system_prompt
        else:
            final_system_prompt = base_system_prompt
        
        # Set temperature based on feature - lower for consistency in screening/analysis
        if feature in ["screening", "indexing_keywords"]:
            temperature = 0.1  # Very low for consistent scoring/analysis
        else:
            temperature = 0.7  # Default for creative tasks
        
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": final_system_prompt},
                    {"role": "user", "content": question}
                ],
                max_tokens=max_tokens,
                temperature=temperature
            )
            
            # Extract response data
            answer = response.choices[0].message.content or ""  # content can be None (e.g., filtered); normalize to ""
            usage = response.usage
            
            # Calculate execution time
            execution_time = (time.time() - start_time) * 1000  # Convert to milliseconds
            
            # Calculate cost
            cost = self._calculate_cost(
                model=model,
                input_tokens=usage.prompt_tokens,
                output_tokens=usage.completion_tokens
            )
            
            # Log AI usage
            self._log_ai_usage(
                company_id=company_id,
                service_type="openai",
                feature=feature,
                model=model,
                input_tokens=usage.prompt_tokens,
                output_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                prompt_length=len(question) + len(final_system_prompt),
                completion_length=len(answer),
                cost=cost,
                status="success",
                execution_time=execution_time,
                request_data={
                    "question_preview": question[:200],
                    "system_prompt_preview": system_prompt[:200] if system_prompt else ""
                },
                response_data={
                    "answer_preview": answer[:200]
                }
            )
            
            return answer
        
        except Exception as e:
            # Calculate execution time even on error
            execution_time = (time.time() - start_time) * 1000
            
            # Log failed AI usage
            self._log_ai_usage(
                company_id=company_id,
                service_type="openai",
                feature=feature,
                model=model,
                input_tokens=0,
                output_tokens=0,
                total_tokens=0,
                prompt_length=len(question) + len(system_prompt or ""),
                completion_length=0,
                cost=0.0,
                status="failed",
                execution_time=execution_time,
                error_message=str(e),
                request_data={
                    "question_preview": question[:200],
                    "system_prompt_preview": system_prompt[:200] if system_prompt else ""
                }
            )
            
            return f"Error: {str(e)}"