model-forge-llm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modelforge/registry.py ADDED
@@ -0,0 +1,271 @@
+ # Standard library imports
+ import os
+ from typing import Any
+
+ # Third-party imports
+ from langchain_community.chat_models import ChatOllama
+ from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_openai import ChatOpenAI
+
+ # Try to import ChatGitHubCopilot, but make it optional
+ try:
+     from langchain_github_copilot import ChatGitHubCopilot
+
+     GITHUB_COPILOT_AVAILABLE = True
+ except ImportError:
+     GITHUB_COPILOT_AVAILABLE = False
+
+ # Local imports
+ from . import auth, config
+ from .exceptions import ConfigurationError, ModelNotFoundError, ProviderError
+ from .logging_config import get_logger
+
+
+ def _raise_provider_error(message: str) -> None:
+     """Raise a ProviderError with the given message."""
+     raise ProviderError(message)
+
+
+ logger = get_logger(__name__)
+
+
+ class ModelForgeRegistry:
+     """
+     Main registry for managing and instantiating LLM models.
+     This class serves as the primary entry point for accessing configured
+     language models across different providers.
+     """
+
+     def __init__(self, verbose: bool = False) -> None:
+         """
+         Initialize the ModelForgeRegistry.
+         Args:
+             verbose: Enable verbose debug logging.
+         """
+         self.verbose = verbose
+         self._config, _ = config.get_config()
+         logger.debug("ModelForgeRegistry initialized with verbose=%s", verbose)
+
+     def _get_model_config(
+         self, provider_name: str | None, model_alias: str | None
+     ) -> tuple[str, str, dict[str, Any], dict[str, Any]]:
+         """
+         Retrieve and validate the provider and model configuration.
+         If the provider or model is not specified, fall back to the current selection.
+         """
+         logger.debug(
+             "Attempting to get model config for provider: '%s' and model: '%s'",
+             provider_name,
+             model_alias,
+         )
+
+         if not provider_name or not model_alias:
+             logger.debug(
+                 "Provider or model not specified, falling back to current model."
+             )
+             current_model = config.get_current_model()
+             if not current_model:
+                 msg = (
+                     "No model selected. Use 'modelforge config use' "
+                     "or provide provider and model."
+                 )
+                 raise ConfigurationError(msg)
+             provider_name = current_model.get("provider")
+             model_alias = current_model.get("model")
+             logger.debug(
+                 "Using current model: provider='%s', model='%s'",
+                 provider_name,
+                 model_alias,
+             )
+
+         if not provider_name or not model_alias:
+             msg = "Could not determine provider and model to use."
+             raise ConfigurationError(msg)
+
+         provider_data = self._config.get("providers", {}).get(provider_name)
+         if not provider_data:
+             msg = f"Provider '{provider_name}' not found in configuration"
+             raise ProviderError(msg)
+
+         model_data = provider_data.get("models", {}).get(model_alias)
+         if model_data is None:
+             msg = f"Model '{model_alias}' not found for provider '{provider_name}'"
+             raise ModelNotFoundError(msg)
+
+         logger.debug(
+             "Successfully retrieved config for provider='%s' and model='%s'",
+             provider_name,
+             model_alias,
+         )
+         return provider_name, model_alias, provider_data, model_data
+
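For orientation, _get_model_config implies a configuration shape along the lines of the sketch below. This is a hedged reconstruction from the lookups above, not a schema shipped with the package; the provider name, model alias, and URL are placeholders:

    # Hypothetical config shape inferred from _get_model_config's lookups.
    config = {
        "providers": {
            "my-provider": {                      # config["providers"][provider_name]
                "llm_type": "openai_compatible",  # selects a factory in get_llm()
                "base_url": "https://llm.example.com/v1",
                "models": {
                    # alias -> per-model settings; api_model_name maps the
                    # alias to the provider's real model identifier
                    "fast": {"api_model_name": "gpt-4o-mini"},
                },
            }
        }
    }

Where the "current" provider/model selection is stored is not visible here; config.get_current_model() resolves it.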
+     def get_llm(
+         self, provider_name: str | None = None, model_alias: str | None = None
+     ) -> BaseChatModel:
+         """
+         Get a fully authenticated and configured LLM instance.
+         Args:
+             provider_name: The provider name. If None, uses the current selection.
+             model_alias: The model alias. If None, uses the current selection.
+         Returns:
+             A LangChain-compatible LLM instance ready for use.
+         Raises:
+             ConfigurationError: If no model is selected or configuration is invalid.
+             ProviderError: If the provider is not supported or credentials are missing.
+             ModelNotFoundError: If the specified model is not found.
+         """
+         resolved_provider = provider_name
+         resolved_model = model_alias
+         try:
+             (
+                 resolved_provider,
+                 resolved_model,
+                 provider_data,
+                 model_data,
+             ) = self._get_model_config(resolved_provider, resolved_model)
+
+             llm_type = provider_data.get("llm_type")
+             if not llm_type:
+                 _raise_provider_error(
+                     f"Provider '{resolved_provider}' has no 'llm_type' configured."
+                 )
+
+             logger.info(
+                 "Creating LLM instance for provider: %s, model: %s",
+                 resolved_provider,
+                 resolved_model,
+             )
+
+             # Factory mapping for LLM creation
+             creator_map = {
+                 "ollama": self._create_ollama_llm,
+                 "google_genai": self._create_google_genai_llm,
+                 "openai_compatible": self._create_openai_compatible_llm,
+                 "github_copilot": self._create_github_copilot_llm,
+             }
+
+             creator = creator_map.get(str(llm_type))
+             if not creator:
+                 _raise_provider_error(
+                     f"Unsupported llm_type '{llm_type}' for provider "
+                     f"'{resolved_provider}'"
+                 )
+
+             return creator(resolved_provider, resolved_model, provider_data, model_data)  # type: ignore[misc]
+
+         except (ConfigurationError, ProviderError, ModelNotFoundError):
+             logger.exception("Failed to create LLM")
+             raise
+         except Exception as e:
+             logger.exception(
+                 "An unexpected error occurred while creating LLM instance for %s/%s",
+                 resolved_provider,
+                 resolved_model,
+             )
+             msg = "An unexpected error occurred during LLM creation."
+             raise ProviderError(msg) from e
+
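A minimal caller sketch, assuming the package is installed and a model has already been selected with 'modelforge config use'; the class and exception names come from the module above, nothing else is assumed:

    from modelforge.registry import ModelForgeRegistry
    from modelforge.exceptions import ConfigurationError, ProviderError

    registry = ModelForgeRegistry(verbose=True)
    try:
        # With no arguments, get_llm() falls back to the current selection.
        llm = registry.get_llm()
    except (ConfigurationError, ProviderError) as exc:
        print(f"Could not create LLM: {exc}")

The string-keyed creator_map keeps provider dispatch in one place: adding a provider type means adding one factory method and one map entry, with no change to the error handling.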
+     def _create_openai_compatible_llm(
+         self,
+         provider_name: str,
+         model_alias: str,
+         provider_data: dict[str, Any],
+         model_data: dict[str, Any],
+     ) -> ChatOpenAI:
+         """
+         Create a ChatOpenAI instance for OpenAI-compatible providers.
+         """
+         credentials = auth.get_credentials(
+             provider_name, model_alias, provider_data, verbose=self.verbose
+         )
+         if not credentials:
+             msg = f"Could not retrieve credentials for provider: {provider_name}"
+             raise ProviderError(msg)
+
+         api_key = credentials.get("access_token") or credentials.get("api_key")
+         if not api_key:
+             msg = f"Could not find token or key for provider: {provider_name}"
+             raise ProviderError(msg)
+
+         actual_model_name = model_data.get("api_model_name", model_alias)
+         base_url = provider_data.get("base_url")
+
+         if self.verbose:
+             logger.debug("Creating ChatOpenAI instance with:")
+             logger.debug("  Provider: %s", provider_name)
+             logger.debug("  Model alias: %s", model_alias)
+             logger.debug("  Actual model name: %s", actual_model_name)
+             logger.debug("  Base URL: %s", base_url)
+
+         return ChatOpenAI(model=actual_model_name, api_key=api_key, base_url=base_url)
+
+     def _create_ollama_llm(
+         self,
+         provider_name: str,  # noqa: ARG002
+         model_alias: str,
+         provider_data: dict[str, Any],
+         model_data: dict[str, Any],  # noqa: ARG002
+     ) -> ChatOllama:
+         """Create a ChatOllama instance."""
+         base_url = provider_data.get("base_url", os.getenv("OLLAMA_HOST"))
+         if not base_url:
+             msg = (
+                 "Ollama 'base_url' not set in config and "
+                 "OLLAMA_HOST env var is not set."
+             )
+             raise ConfigurationError(msg)
+         return ChatOllama(model=model_alias, base_url=base_url)
+
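Because _create_ollama_llm falls back to the OLLAMA_HOST environment variable, a local setup can omit base_url from the config entirely. The URL below is Ollama's conventional local default, assumed here rather than set anywhere in this module:

    import os

    # Assumed default local Ollama endpoint; point at a remote host if needed.
    os.environ.setdefault("OLLAMA_HOST", "http://localhost:11434")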
+     def _create_github_copilot_llm(
+         self,
+         provider_name: str,
+         model_alias: str,
+         provider_data: dict[str, Any],
+         model_data: dict[str, Any],
+     ) -> BaseChatModel:
+         """Create a ChatGitHubCopilot instance."""
+         if not GITHUB_COPILOT_AVAILABLE:
+             msg = (
+                 "GitHub Copilot libraries not installed. "
+                 "Please run 'poetry install --extras github-copilot'"
+             )
+             raise ProviderError(msg)
+
+         credentials = auth.get_credentials(
+             provider_name, model_alias, provider_data, verbose=self.verbose
+         )
+         if not credentials or "access_token" not in credentials:
+             msg = f"Could not get valid credentials for {provider_name}"
+             raise ProviderError(msg)
+
+         copilot_token = credentials["access_token"]
+         actual_model_name = model_data.get("api_model_name", model_alias)
+
+         if self.verbose:
+             logger.debug("Creating ChatGitHubCopilot instance with:")
+             logger.debug("  Provider: %s", provider_name)
+             logger.debug("  Model alias: %s", model_alias)
+             logger.debug("  Actual model name: %s", actual_model_name)
+
+         return ChatGitHubCopilot(api_key=copilot_token, model=actual_model_name)
+
+     def _create_google_genai_llm(
+         self,
+         provider_name: str,
+         model_alias: str,
+         provider_data: dict[str, Any],
+         model_data: dict[str, Any],
+     ) -> ChatGoogleGenerativeAI:
+         """Create a ChatGoogleGenerativeAI instance."""
+         credentials = auth.get_credentials(
+             provider_name, model_alias, provider_data, verbose=self.verbose
+         )
+         if not credentials or "api_key" not in credentials:
+             msg = f"API key not found for {provider_name}"
+             raise ProviderError(msg)
+
+         api_key = credentials["api_key"]
+         actual_model_name = model_data.get("api_model_name", model_alias)
+
+         return ChatGoogleGenerativeAI(model=actual_model_name, google_api_key=api_key)
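Putting the pieces together, an end-to-end call might look like the sketch below. invoke() is LangChain's standard BaseChatModel entry point; the provider/model pair is a placeholder that must exist in your configuration:

    from modelforge.registry import ModelForgeRegistry

    registry = ModelForgeRegistry()
    # Placeholder names: any provider/model pair defined in the config works.
    llm = registry.get_llm(provider_name="ollama", model_alias="llama3")
    response = llm.invoke("Say hello in one sentence.")
    print(response.content)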