import sys
import os
from pathlib import Path
import torch
from typing import Dict, Any, List, Optional, Tuple, Union
from dotenv import load_dotenv
import logging
import tempfile
import gc
import uuid
import json
from datetime import datetime
from memories.models.base_model import BaseModel
from memories.models.api_connector import get_connector
# Load environment variables from a .env file (python-dotenv). This runs as a
# side effect of importing this module, before LoadModel is defined.
load_dotenv()
class LoadModel:
    """Front-end for text generation and chat completion.

    Depending on ``deployment_type`` requests are routed either to the
    ``BaseModel`` instance ("local") or to a provider connector obtained from
    ``get_connector`` ("api"). The generation methods (``get_response``,
    ``get_response_with_context``, ``chat_completion``) report failures through
    an ``"error"`` entry in the returned dict instead of raising.
    """

    def __init__(self,
                 use_gpu: bool = True,
                 model_provider: Optional[str] = None,
                 deployment_type: Optional[str] = None,  # "local" or "api"
                 model_name: Optional[str] = None,
                 api_key: Optional[str] = None,
                 endpoint: Optional[str] = None,
                 device: Optional[str] = None):
        """Initialize the model loader from arguments and ``config/model_config.json``.

        Any of ``model_provider``, ``deployment_type`` and ``model_name`` that is
        not given is filled in from the configuration's ``default_model`` entry.

        Args:
            use_gpu: Whether to use a GPU; only honoured if CUDA is available.
            model_provider: Provider name (e.g. "deepseek", "azure-ai", "mistral").
                Must be in the config's ``supported_providers`` unless it is
                "azure-ai", which instead requires ``endpoint``.
            deployment_type: One of the config's ``deployment_types``
                ("local" or "api").
            model_name: Short model name from ``BaseModel.MODEL_MAPPINGS``.
            api_key: Provider API key; required when ``deployment_type == "api"``.
            endpoint: Provider endpoint URL; required for "azure-ai".
            device: GPU device in the form "cuda:N". Ignored (the loader falls
                back to "cpu") when the GPU is not used.

        Raises:
            ValueError: Unknown deployment type or provider, missing endpoint or
                API key, or a malformed / unavailable ``device``.
            RuntimeError: If the local model fails to initialize.
            KeyError: If the configuration lacks required entries
                (``_load_config`` returns ``{}`` when the file cannot be read).
        """
        # Setup logging
        self.instance_id = str(uuid.uuid4())
        self.logger = logging.getLogger(__name__)

        # Load configuration
        self.config = self._load_config()

        # Fill any missing selector from the configured default model.
        if not all([model_provider, deployment_type, model_name]):
            default_model = self.config["default_model"]
            default_config = self.config["models"][default_model]
            model_provider = model_provider or default_config["provider"]
            deployment_type = deployment_type or default_config["type"]
            model_name = model_name or default_model

        # Validate inputs
        if deployment_type not in self.config["deployment_types"]:
            raise ValueError(f"deployment_type must be one of: {self.config['deployment_types']}")
        # azure-ai is accepted without being listed, but needs an explicit endpoint.
        if model_provider == "azure-ai":
            if not endpoint:
                raise ValueError("endpoint is required for azure-ai provider")
        elif model_provider not in self.config["supported_providers"]:
            raise ValueError(f"model_provider must be one of: {self.config['supported_providers']}")
        if deployment_type == "api" and not api_key:
            raise ValueError("api_key is required for API deployment type")

        # Store configuration
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.model_provider = model_provider
        self.deployment_type = deployment_type
        self.model_name = model_name
        self.api_key = api_key
        self.endpoint = endpoint
        self.device = self._resolve_device(device)

        # Initialize appropriate model interface
        if deployment_type == "local":
            self.base_model = BaseModel.get_instance()
            # Pass the *effective* settings: use_gpu may have been downgraded
            # because CUDA is unavailable, and device may have been defaulted.
            success = self.base_model.initialize_model(
                model=model_name,
                use_gpu=self.use_gpu,
                device=self.device,
            )
            if not success:
                raise RuntimeError(f"Failed to initialize model: {model_name}")
        else:  # api
            self.api_connector = get_connector(model_provider, api_key, endpoint)

    def _resolve_device(self, device: Optional[str]) -> str:
        """Return the device string to use, validating an explicit ``device``.

        Returns "cpu" when the GPU is not used, "cuda:0" when no device is
        given, and ``device`` itself when it is a valid "cuda:N" for a GPU that
        ``torch.cuda.device_count()`` reports.

        Raises:
            ValueError: If ``device`` is not of the form "cuda:N" (N a
                non-negative integer) or N is not an available GPU index.
        """
        if not self.use_gpu:
            return "cpu"
        if not device:
            return "cuda:0"  # Default to first GPU
        prefix, _, index_text = device.partition(":")
        if prefix != "cuda" or not index_text.isdecimal():
            raise ValueError("Device must be in format 'cuda:N' where N is the GPU index")
        gpu_count = torch.cuda.device_count()
        if int(index_text) >= gpu_count:
            raise ValueError(f"Device {device} not available. Maximum device index is {gpu_count-1}")
        return device

    def _load_config(self) -> Dict[str, Any]:
        """Load ``config/model_config.json`` located next to this module.

        Returns:
            The parsed JSON configuration, or an empty dict if the file cannot
            be read or decoded (the error is logged).
        """
        config_path = Path(__file__).parent / "config" / "model_config.json"
        try:
            with config_path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            self.logger.error(f"Error loading config: {str(e)}")
            return {}

    def _generate_with_retries(self, call: Any, max_retries: int,
                               metadata: Dict[str, Any]) -> Tuple[Any, Optional[str]]:
        """Invoke the zero-argument ``call`` until it returns a truthy value.

        At most ``max_retries`` attempts are made. ``metadata["attempt"]`` is set
        to the 1-based number of each attempt. Exceptions raised by ``call`` are
        logged and the attempt counts as failed.

        Returns:
            ``(response, error)``. ``response`` is the first truthy result, or
            otherwise the last value returned (``None`` if every attempt
            raised or no attempt was made). ``error`` is the message of the most
            recent exception, or ``None`` if no attempt raised.
        """
        response = None
        error = None
        for attempt in range(max_retries):
            metadata["attempt"] = attempt + 1
            try:
                response = call()
            except Exception as e:
                error = str(e)
                self.logger.warning(
                    f"Attempt {attempt + 1} failed: {error}",
                    exc_info=True
                )
                continue
            if response:
                break
        return response, error

    def get_response(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a response using either the local model or the API.

        Args:
            prompt: The input prompt (non-empty string).
            **kwargs: Generation parameters forwarded to the backend, e.g.
                max_length, temperature, top_p, top_k, num_beams. Two are
                handled here: ``max_retries`` (default 3) is consumed and not
                forwarded; ``timeout`` (default 30) is forwarded explicitly.

        Returns:
            Dict[str, Any] with keys:
                text: The generated text (``None`` on failure).
                metadata: Generation metadata (attempt count, plus whatever the
                    backend reports under ``"metadata"``).
                error: Error message, or ``None`` on success.
        """
        if not prompt or not isinstance(prompt, str):
            return {
                "error": "Invalid prompt - must be non-empty string",
                "text": None,
                "metadata": None
            }
        try:
            self.logger.info(f"Generating response for prompt: {prompt[:100]}...")
            self.logger.debug(f"Full prompt: {prompt}")
            self.logger.info(f"Using deployment type: {self.deployment_type}")
            self.logger.debug(f"Generation parameters: {kwargs}")

            max_retries = kwargs.pop('max_retries', 3)
            timeout = kwargs.pop('timeout', 30)
            metadata = {
                "attempt": 0,
                "total_tokens": 0,
                "generation_time": 0
            }

            def attempt_generation():
                if self.deployment_type == "local":
                    self.logger.info("Using base model for generation")
                    backend = self.base_model
                else:
                    self.logger.info(f"Using {self.model_provider} API connector")
                    backend = self.api_connector
                return backend.generate(prompt, timeout=timeout, **kwargs)

            response, error = self._generate_with_retries(attempt_generation, max_retries, metadata)

            if not response:
                error_msg = error or "Failed to generate response after retries"
                self.logger.error(error_msg)
                return {
                    "text": None,
                    "metadata": metadata,
                    "error": error_msg
                }

            text = response
            if isinstance(response, dict):
                # Backend metadata (possibly None) overrides the defaults; a dict
                # without a "text" key is passed through unchanged as the text.
                metadata.update(response.get('metadata') or {})
                text = response.get('text', response)
            if text is None:
                error_msg = "Response contained no text"
                self.logger.error(error_msg)
                return {
                    "text": None,
                    "metadata": metadata,
                    "error": error_msg
                }

            self.logger.info(
                f"Response generated successfully. Length: {len(str(text))}"
            )
            return {
                "text": text,
                "metadata": metadata,
                "error": None
            }
        except Exception as e:
            self.logger.error(
                f"Unexpected error in get_response: {str(e)}",
                exc_info=True
            )
            return {
                "text": None,
                "metadata": {"attempt": 1},
                "error": f"Unexpected error: {str(e)}"
            }

    def cleanup(self):
        """Release model resources and free cached CUDA memory.

        For local deployments this calls ``cleanup()`` on the object returned by
        ``BaseModel.get_instance()``. NOTE(review): that accessor name suggests a
        shared instance, so this may affect other LoadModel objects; confirm.
        """
        if self.deployment_type == "local" and hasattr(self, 'base_model'):
            self.base_model.cleanup()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        self.logger.info("Model resources cleaned up")

    def get_response_with_context(self, prompt: str, context_data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """
        Generate a response after prepending formatted context to the prompt.

        Args:
            prompt: The input prompt (non-empty string).
            context_data: Non-empty dict of contextual information. Recognised
                keys: location_info, raw_data_summary, analysis_results,
                scenario_projections, historical_trends. Every key is
                scored for usage in the response.
            **kwargs: Passed through to ``get_response`` (max_length,
                temperature, top_p, top_k, num_beams, max_retries, timeout, ...).

        Returns:
            The ``get_response`` result. On success its metadata additionally
            contains context_used (per-key usage scores), context_keys,
            prompt_length and context_integration_timestamp.
        """
        if not prompt or not isinstance(prompt, str):
            return {
                "error": "Invalid prompt - must be non-empty string",
                "text": None,
                "metadata": None
            }
        if not context_data or not isinstance(context_data, dict):
            return {
                "error": "Invalid context_data - must be non-empty dictionary",
                "text": None,
                "metadata": None
            }
        try:
            self.logger.info(f"Generating context-aware response for prompt: {prompt[:100]}...")
            self.logger.debug(f"Full prompt: {prompt}")
            self.logger.debug(f"Context data keys: {list(context_data.keys())}")
            self.logger.info(f"Using deployment type: {self.deployment_type}")

            formatted_prompt = self._format_prompt_with_context(prompt, context_data)
            self.logger.debug(f"Formatted prompt with context: {formatted_prompt[:200]}...")

            response = self.get_response(formatted_prompt, **kwargs)
            if response.get("error"):
                return response

            # Annotate the successful response with how the context was used.
            context_usage = self._analyze_context_usage(response["text"], context_data)
            response["metadata"]["context_used"] = context_usage
            response["metadata"]["context_keys"] = list(context_data.keys())
            response["metadata"]["prompt_length"] = len(formatted_prompt)
            response["metadata"]["context_integration_timestamp"] = datetime.now().isoformat()
            return response
        except Exception as e:
            self.logger.error(
                f"Unexpected error in get_response_with_context: {str(e)}",
                exc_info=True
            )
            return {
                "text": None,
                "metadata": {"attempt": 1},
                "error": f"Error processing context: {str(e)}"
            }

    def _format_prompt_with_context(self, prompt: str, context_data: Dict[str, Any]) -> str:
        """Build a prompt with a "Context Information" preamble.

        Each recognised key of ``context_data`` contributes a text section;
        sections are separated by blank lines and followed by the user query
        and a closing instruction. Unrecognised keys are ignored here.
        """
        context_sections = []

        if "location_info" in context_data:
            loc = context_data["location_info"]
            context_sections.append(
                f"Location: {loc.get('name', 'Unknown')}\n"
                f"Type: {loc.get('type', 'Unknown')}\n"
                f"Area: {loc.get('area_sqkm', 0):.2f} km²"
            )

        if "raw_data_summary" in context_data:
            data = context_data["raw_data_summary"]
            if "overture" in data:
                ov = data["overture"]
                context_sections.append(
                    f"Urban Data:\n"
                    f"- Buildings: {ov.get('total_buildings', 0)}\n"
                    f"- Places: {ov.get('total_places', 0)}\n"
                    f"- Transportation: {ov.get('total_transportation', 0)}"
                )
            if "sentinel" in data:
                sen = data["sentinel"]
                context_sections.append(
                    f"Satellite Data:\n"
                    f"- Scenes: {sen.get('total_scenes', 0)}\n"
                    f"- Coverage: {sen.get('coverage_percentage', 0)}%"
                )

        if "analysis_results" in context_data:
            analysis = context_data["analysis_results"]
            if "urban_metrics" in analysis:
                um = analysis["urban_metrics"]
                context_sections.append(
                    f"Urban Analysis:\n"
                    f"- Building Density: {um.get('building_density', 0):.2f}/km²\n"
                    f"- Urbanization Level: {um.get('urbanization_level', 'Unknown')}"
                )
            if "environmental_metrics" in analysis:
                em = analysis["environmental_metrics"]
                context_sections.append(
                    f"Environmental Analysis:\n"
                    f"- Vegetation Index: {em.get('vegetation_index', 0)}\n"
                    f"- Environmental Health: {em.get('environmental_health', 'Unknown')}"
                )

        if "scenario_projections" in context_data:
            scenarios = context_data["scenario_projections"]
            context_sections.append("Future Scenarios:")
            for scenario_type, details in scenarios.items():
                context_sections.append(
                    f"{scenario_type.title()} Scenario (Probability: {details.get('probability', 0)*100:.0f}%):\n"
                    f"- Changes: {', '.join(str(change) for change in details.get('changes', []))}\n"
                    f"- Impact Factors: {', '.join(str(factor) for factor in details.get('impact_factors', []))}"
                )

        if "historical_trends" in context_data:
            trends = context_data["historical_trends"]
            # str() each factor, as the scenario section does, so non-string
            # entries do not make str.join raise TypeError.
            seasonal = ', '.join(str(factor) for factor in (trends.get('seasonal_factors') or []))
            context_sections.append(
                f"Historical Trends:\n"
                f"- Growth Rate: {trends.get('growth_rate', 0):.2f}\n"
                f"- Trend Direction: {trends.get('trend_direction', 'stable')}\n"
                f"- Seasonal Factors: {seasonal}"
            )

        context_text = "\n\n".join(context_sections)
        # Built with explicit newlines (not a triple-quoted literal) so source
        # indentation can never leak into the prompt text.
        return (
            "Context Information:\n"
            f"{context_text}\n"
            f"User Query: {prompt}\n"
            "Please provide a detailed response incorporating the above context."
        )

    def _analyze_context_usage(self, response: str, context_data: Dict[str, Any]) -> Dict[str, float]:
        """Score how much of each context entry shows up in ``response``.

        For each top-level key of ``context_data`` a set of lower-cased terms is
        collected:
          * dict value: its keys, its str/int/float values, the items of its
            list/tuple values, and the keys plus str/int/float values of
            nested dicts;
          * list/tuple value: its items;
          * anything else: the whitespace-separated words of its string form.
        The score is the fraction of non-empty terms that occur as substrings of
        the lower-cased response, or 0.0 when there are no terms.
        """
        response_lower = "" if response is None else str(response).lower()
        context_usage = {}
        for context_type, data in context_data.items():
            if isinstance(data, dict):
                key_terms = set()
                for key, value in data.items():
                    key_terms.add(str(key).lower())
                    if isinstance(value, (str, int, float)):
                        key_terms.add(str(value).lower())
                    elif isinstance(value, (list, tuple)):
                        key_terms.update(str(item).lower() for item in value)
                    elif isinstance(value, dict):
                        key_terms.update(str(sub_key).lower() for sub_key in value)
                        key_terms.update(
                            str(sub_value).lower()
                            for sub_value in value.values()
                            if isinstance(sub_value, (str, int, float))
                        )
            elif isinstance(data, (list, tuple)):
                key_terms = {str(item).lower() for item in data}
            else:
                key_terms = set(str(data).lower().split())

            # An empty term is a substring of every response and would inflate
            # the score without reflecting any real usage.
            key_terms.discard("")
            terms_found = sum(1 for term in key_terms if term in response_lower)
            context_usage[context_type] = terms_found / len(key_terms) if key_terms else 0.0
        return context_usage

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Union[str, Dict[str, Any]] = "auto",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a chat completion using either the local model or the API.

        Args:
            messages: Non-empty list of message dicts with 'role' and 'content'
                keys. Roles can be 'user', 'assistant', 'system', or 'function'.
            tools: Optional tool/function definitions the model can use. Each
                tool should have a 'type' and a 'function' with 'name',
                'description' and 'parameters'.
            tool_choice: How to handle tool selection:
                - "auto": let the model decide whether to call a function
                - "none": don't call any functions
                - a dict naming a specific function to call
            **kwargs: Forwarded to the backend (temperature, max_tokens, top_p,
                frequency_penalty, presence_penalty, ...). ``max_retries``
                (default 3) is consumed here; ``timeout`` (default 30) is
                forwarded explicitly.

        Returns:
            Dict[str, Any] with keys:
                message: The assistant's message. If the backend returns a
                    non-dict value it is used as the message as-is.
                tool_calls: List of tool calls, if any.
                metadata: Generation metadata.
                error: Error message, or ``None`` on success.
        """
        if not messages or not isinstance(messages, list):
            return {
                "error": "Invalid messages - must be non-empty list",
                "message": None,
                "tool_calls": None,
                "metadata": None
            }
        try:
            self.logger.info(f"Generating chat completion for {len(messages)} messages")
            self.logger.debug(f"Messages: {messages}")
            self.logger.debug(f"Tools available: {len(tools) if tools else 0}")

            max_retries = kwargs.pop('max_retries', 3)
            timeout = kwargs.pop('timeout', 30)
            metadata = {
                "attempt": 0,
                "total_tokens": 0,
                "generation_time": 0,
                "timestamp": datetime.now().isoformat()
            }

            def attempt_chat():
                if self.deployment_type == "local":
                    self.logger.info("Using base model for chat completion")
                    backend = self.base_model
                else:
                    self.logger.info(f"Using {self.model_provider} API connector")
                    backend = self.api_connector
                return backend.chat_completion(
                    messages=messages,
                    tools=tools,
                    tool_choice=tool_choice,
                    timeout=timeout,
                    **kwargs
                )

            response, error = self._generate_with_retries(attempt_chat, max_retries, metadata)

            if not response:
                error_msg = error or "Failed to generate chat completion after retries"
                self.logger.error(error_msg)
                return {
                    "message": None,
                    "tool_calls": None,
                    "metadata": metadata,
                    "error": error_msg
                }

            if isinstance(response, dict):
                metadata.update(response.get('metadata') or {})
                message = response.get('message', {})
                tool_calls = response.get('tool_calls', [])
            else:
                # Non-dict result (e.g. a bare string): treat it as the message.
                message = response
                tool_calls = []
            self.logger.info("Chat completion generated successfully")
            return {
                "message": message,
                "tool_calls": tool_calls,
                "metadata": metadata,
                "error": None
            }
        except Exception as e:
            self.logger.error(
                f"Unexpected error in chat_completion: {str(e)}",
                exc_info=True
            )
            return {
                "message": None,
                "tool_calls": None,
                "metadata": {"attempt": 1},
                "error": f"Unexpected error: {str(e)}"
            }