Source code for memories.models.base_model

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
import gc
from diffusers import StableDiffusionPipeline

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Global pipe variable for Stable Diffusion
pipe = None

[docs]class BaseModel: """Base model class that can be shared across modules""" _instance = None _initialized = False
[docs] def __init__(self): """Initialize the base model.""" if not BaseModel._initialized: self.model = None self.tokenizer = None self.config = self._load_config() self.hf_token = os.getenv("HF_TOKEN") BaseModel._initialized = True
[docs] @classmethod def get_instance(cls): """Get singleton instance of BaseModel.""" if cls._instance is None: cls._instance = cls() return cls._instance
def _load_config(self) -> Dict[str, Any]: """Load model configuration from JSON file.""" try: config_path = Path(__file__).parent / "config" / "model_config.json" with open(config_path, 'r') as f: return json.load(f) except Exception as e: logger.error(f"Error loading model config: {str(e)}") return {}
[docs] def get_model_config(self, model_name: str) -> Dict[str, Any]: """Get configuration for a specific model.""" return self.config.get("models", {}).get(model_name, {})
[docs] def initialize_model(self, model: str, use_gpu: bool = True, device: str = None) -> bool: """Initialize a model with the specified configuration. Args: model: Model identifier from config use_gpu: Whether to use GPU if available device: Specific GPU device to use (e.g., "cuda:0", "cuda:1") Returns: bool: True if initialization successful, False otherwise """ try: # Clean up any existing model self.cleanup() # Get model configuration model_config = self.get_model_config(model) if not model_config: logger.error(f"No configuration found for model: {model}") return False # Determine device if use_gpu and torch.cuda.is_available(): if device: if not device.startswith("cuda:"): logger.error("Device must be in format 'cuda:N' where N is the GPU index") return False device_idx = int(device.split(":")[-1]) if device_idx >= torch.cuda.device_count(): logger.error(f"Device {device} not available. Maximum device index is {torch.cuda.device_count()-1}") return False else: device = "cuda:0" # Default to first GPU else: device = "cpu" if not self.config["global_config"]["fallback_to_cpu"]: logger.error("GPU requested but not available and fallback_to_cpu is False") return False # Get model parameters config = model_config["config"] model_name = model_config["name"] # Set up model parameters model_kwargs = { "use_auth_token": self.hf_token if self.config["global_config"]["use_auth_token"] else None, "torch_dtype": getattr(torch, config.get("torch_dtype", "float32")), "device_map": "auto" if device.startswith("cuda") else None, "trust_remote_code": config.get("trust_remote_code", True) } # Load tokenizer and model self.tokenizer = AutoTokenizer.from_pretrained( model_name, use_auth_token=model_kwargs["use_auth_token"], trust_remote_code=model_kwargs["trust_remote_code"] ) self.model = AutoModelForCausalLM.from_pretrained( model_name, **model_kwargs ) # Move model to specified device if device == "cpu" or device.startswith("cuda:"): self.model = self.model.to(device) logger.info(f"Model {model_name} initialized successfully on {device}") return True except Exception as e: logger.error(f"Error initializing model: {str(e)}") self.cleanup() return False
[docs] def generate(self, prompt: str, **kwargs) -> str: """Generate text using the model with configured parameters. Args: prompt: Input prompt **kwargs: Override default generation parameters Returns: str: Generated text """ if self.model is None or self.tokenizer is None: raise ValueError("Model not initialized. Call initialize_model first.") try: # Get current model's config current_model = next( (name for name, cfg in self.config["models"].items() if cfg["name"] == self.model.config._name_or_path), self.config["default_model"] ) model_config = self.get_model_config(current_model)["config"] # Prepare generation config gen_config = { "max_length": model_config.get("max_length", 1000), "temperature": model_config.get("temperature", 0.7), "top_p": model_config.get("top_p", 0.95), "top_k": model_config.get("top_k", 50), "repetition_penalty": model_config.get("repetition_penalty", 1.1), "pad_token_id": self.tokenizer.eos_token_id, "do_sample": True } # Override with any provided kwargs gen_config.update(kwargs) # Generate inputs = self.tokenizer(prompt, return_tensors="pt") inputs = {k: v.to(self.model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model.generate(**inputs, **gen_config) return self.tokenizer.decode(outputs[0], skip_special_tokens=True) except Exception as e: logger.error(f"Error during generation: {str(e)}") raise
[docs] def cleanup(self): """Clean up model resources.""" if hasattr(self, 'model') and self.model is not None: del self.model self.model = None if hasattr(self, 'tokenizer') and self.tokenizer is not None: del self.tokenizer self.tokenizer = None if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() logger.info("Model resources cleaned up")
[docs] @classmethod def get_model_path(cls, provider: str, model_key: str) -> str: """Get the full model path/identifier for a given provider and model key""" instance = cls.get_instance() if model_key not in instance.config["models"]: raise ValueError(f"Unknown model key: {model_key}") model_config = instance.config["models"][model_key] if model_config["provider"] != provider: raise ValueError(f"Model {model_key} not available for provider {provider}") return model_config["name"]
[docs] @classmethod def list_providers(cls) -> List[str]: """List all available providers""" instance = cls.get_instance() return instance.config["supported_providers"]
[docs] @classmethod def list_models(cls, provider: str = None) -> List[str]: """List all available models, optionally filtered by provider""" instance = cls.get_instance() if provider: if provider not in instance.config["supported_providers"]: raise ValueError(f"Unknown provider: {provider}") return [ name for name, cfg in instance.config["models"].items() if cfg["provider"] == provider ] return list(instance.config["models"].keys())
def load_stable_diffusion_model(device: str = None): """ Preloads the Stable Diffusion model into the global `pipe` variable. Args: device: Specific GPU device to use (e.g., "cuda:0", "cuda:1") """ global pipe if pipe is not None: logger.info("Stable Diffusion model already loaded; skipping.") return try: instance = BaseModel.get_instance() model_config = instance.get_model_config("stable-diffusion-2") if not model_config: raise ValueError("No configuration found for Stable Diffusion model") config = model_config["config"] model_name = model_config["name"] pipe = StableDiffusionPipeline.from_pretrained( model_name, **config ) # Determine device if torch.cuda.is_available(): if device: if not device.startswith("cuda:"): raise ValueError("Device must be in format 'cuda:N' where N is the GPU index") device_idx = int(device.split(":")[-1]) if device_idx >= torch.cuda.device_count(): raise ValueError(f"Device {device} not available. Maximum device index is {torch.cuda.device_count()-1}") else: device = "cuda:0" # Default to first GPU pipe = pipe.to(device) else: logger.warning("CUDA not available, using CPU. This will be slow!") pipe = pipe.to("cpu") logger.info(f"Stable Diffusion model loaded successfully on {device if device else 'cuda:0'}") except Exception as e: logger.error(f"Failed to load Stable Diffusion model: {e}") pipe = None raise RuntimeError("Failed to load Stable Diffusion model. Ensure proper environment setup and access.") from e def unload_stable_diffusion_model(): """ Unloads the Stable Diffusion model from memory and clears the GPU cache. """ global pipe if pipe: del pipe pipe = None if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Stable Diffusion model unloaded and GPU cache cleared.")