rakam-systems-agent 0.1.1rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,313 @@
+ """LLM Gateway Factory for provider routing and configuration-driven model selection."""
+ from __future__ import annotations
+ import os
+ from typing import Any, Dict, Optional
+
+ from rakam_systems_core.ai_utils import logging
+ from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMGateway
+ from .openai_gateway import OpenAIGateway
+ from .mistral_gateway import MistralGateway
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMGatewayFactory:
+     """Factory for creating LLM gateways based on provider and configuration.
+
+     This factory enables:
+     - Configuration-driven provider selection
+     - Automatic routing to the correct gateway
+     - Model string parsing (e.g., "openai:gpt-4o")
+     - Fallback to environment-based configuration
+
+     Example:
+         >>> # Using model string with provider prefix
+         >>> gateway = LLMGatewayFactory.create_gateway("openai:gpt-4o")
+         >>>
+         >>> # Using explicit provider and model
+         >>> gateway = LLMGatewayFactory.create_gateway_from_config({
+         ...     "provider": "mistral",
+         ...     "model": "mistral-large-latest",
+         ...     "temperature": 0.7
+         ... })
+     """
+
+     # Registry of available providers
+     _PROVIDERS = {
+         "openai": OpenAIGateway,
+         "mistral": MistralGateway,
+     }
+
+     # Default models for each provider
+     _DEFAULT_MODELS = {
+         "openai": "gpt-4o",
+         "mistral": "mistral-large-latest",
+     }
+
+     @classmethod
+     def parse_model_string(cls, model_string: str) -> tuple[str, str]:
+         """Parse a model string into provider and model name.
+
+         Supports formats:
+         - "openai:gpt-4o" -> ("openai", "gpt-4o")
+         - "mistral:mistral-large-latest" -> ("mistral", "mistral-large-latest")
+         - "gpt-4o" -> ("openai", "gpt-4o")  # assumes OpenAI if no prefix
+
+         Args:
+             model_string: Model string to parse
+
+         Returns:
+             Tuple of (provider, model_name)
+
+         Raises:
+             ValueError: If provider is unknown
+         """
+         if ":" in model_string:
+             provider, model = model_string.split(":", 1)
+         else:
+             # Default to OpenAI if no provider specified
+             provider = "openai"
+             model = model_string
+
+         if provider not in cls._PROVIDERS:
+             raise ValueError(
+                 f"Unknown provider '{provider}'. Supported providers: {list(cls._PROVIDERS.keys())}"
+             )
+
+         return provider, model
+
+     @classmethod
+     def create_gateway(
+         cls,
+         model_string: Optional[str] = None,
+         temperature: Optional[float] = None,
+         api_key: Optional[str] = None,
+         **kwargs: Any,
+     ) -> LLMGateway:
+         """Create an LLM gateway from a model string.
+
+         Args:
+             model_string: Model string (e.g., "openai:gpt-4o", "mistral:mistral-large-latest").
+                 Falls back to the DEFAULT_LLM_MODEL env var.
+             temperature: Temperature for generation
+             api_key: API key for the provider (provider-specific env vars used as fallback)
+             **kwargs: Additional parameters passed to the gateway constructor
+
+         Returns:
+             Configured LLM gateway instance
+
+         Raises:
+             ValueError: If provider is unknown or configuration is invalid
+         """
+         # Get model string from parameter or environment
+         model_string = model_string or os.getenv(
+             "DEFAULT_LLM_MODEL", "openai:gpt-4o")
+
+         # Parse provider and model
+         provider, model = cls.parse_model_string(model_string)
+
+         # Get gateway class
+         gateway_class = cls._PROVIDERS[provider]
+
+         # Build gateway parameters (an explicit temperature of 0.0 is respected)
+         gateway_params = {
+             "model": model,
+             "default_temperature": temperature if temperature is not None else float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7")),
+         }
+
+         # Add API key if provided
+         if api_key:
+             gateway_params["api_key"] = api_key
+
+         # Add any additional parameters
+         gateway_params.update(kwargs)
+
+         logger.info(
+             f"Creating {provider} gateway with model={model}, temperature={gateway_params['default_temperature']}"
+         )
+
+         return gateway_class(**gateway_params)
+
+     @classmethod
+     def create_gateway_from_config(
+         cls,
+         config: Dict[str, Any],
+     ) -> LLMGateway:
+         """Create an LLM gateway from a configuration dictionary.
+
+         Expected config format:
+             {
+                 "provider": "openai",  # or "mistral"
+                 "model": "gpt-4o",
+                 "temperature": 0.7,
+                 "api_key": "...",  # optional
+                 ...  # additional provider-specific params
+             }
+
+         Args:
+             config: Configuration dictionary
+
+         Returns:
+             Configured LLM gateway instance
+
+         Raises:
+             ValueError: If provider is unknown or configuration is invalid
+         """
+         provider = config.get("provider")
+         model = config.get("model")
+
+         if not provider:
+             raise ValueError("Configuration must specify 'provider'")
+
+         if provider not in cls._PROVIDERS:
+             raise ValueError(
+                 f"Unknown provider '{provider}'. Supported providers: {list(cls._PROVIDERS.keys())}"
+             )
+
+         # Use default model if not specified
+         if not model:
+             model = cls._DEFAULT_MODELS.get(provider)
+             logger.warning(
+                 f"No model specified for {provider}, using default: {model}"
+             )
+
+         # Get gateway class
+         gateway_class = cls._PROVIDERS[provider]
+
+         # Extract common parameters
+         gateway_params = {
+             "model": model,
+             "default_temperature": config.get("temperature", 0.7),
+         }
+
+         # Add API key if provided
+         if "api_key" in config:
+             gateway_params["api_key"] = config["api_key"]
+
+         # Add provider-specific parameters
+         provider_specific_keys = {
+             "openai": ["base_url", "organization"],
+             "mistral": [],
+         }
+
+         for key in provider_specific_keys.get(provider, []):
+             if key in config:
+                 gateway_params[key] = config[key]
+
+         logger.info(
+             f"Creating {provider} gateway from config with model={model}"
+         )
+
+         return gateway_class(**gateway_params)
+
+     @classmethod
+     def get_default_gateway(cls) -> LLMGateway:
+         """Get a default gateway based on environment configuration.
+
+         Checks environment variables:
+         - DEFAULT_LLM_PROVIDER: Provider name (default: "openai")
+         - DEFAULT_LLM_MODEL: Model name (default: "gpt-4o" for OpenAI)
+         - DEFAULT_LLM_TEMPERATURE: Temperature (default: 0.7)
+
+         Returns:
+             Configured default LLM gateway
+         """
+         provider = os.getenv("DEFAULT_LLM_PROVIDER", "openai")
+         model = os.getenv("DEFAULT_LLM_MODEL",
+                           cls._DEFAULT_MODELS.get(provider, "gpt-4o"))
+         temperature = float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7"))
+
+         # If model doesn't have provider prefix, add it
+         if ":" not in model:
+             model_string = f"{provider}:{model}"
+         else:
+             model_string = model
+
+         return cls.create_gateway(
+             model_string=model_string,
+             temperature=temperature,
+         )
+
+     @classmethod
+     def register_provider(
+         cls,
+         provider_name: str,
+         gateway_class: type[LLMGateway],
+         default_model: Optional[str] = None,
+     ) -> None:
+         """Register a new provider gateway.
+
+         This allows extending the factory with custom providers.
+
+         Args:
+             provider_name: Name of the provider (e.g., "custom")
+             gateway_class: Gateway class implementing LLMGateway
+             default_model: Optional default model for this provider
+         """
+         cls._PROVIDERS[provider_name] = gateway_class
+         if default_model:
+             cls._DEFAULT_MODELS[provider_name] = default_model
+
+         logger.info(
+             f"Registered provider '{provider_name}' with gateway class {gateway_class.__name__}"
+         )
+
+     @classmethod
+     def list_providers(cls) -> list[str]:
+         """List all registered providers.
+
+         Returns:
+             List of provider names
+         """
+         return list(cls._PROVIDERS.keys())
+
+     @classmethod
+     def get_default_model(cls, provider: str) -> Optional[str]:
+         """Get the default model for a provider.
+
+         Args:
+             provider: Provider name
+
+         Returns:
+             Default model name or None if provider unknown
+         """
+         return cls._DEFAULT_MODELS.get(provider)
+
+
+ # Convenience function for creating gateways
+ def get_llm_gateway(
+     model: Optional[str] = None,
+     provider: Optional[str] = None,
+     temperature: Optional[float] = None,
+     **kwargs: Any,
+ ) -> LLMGateway:
+     """Convenience function to create an LLM gateway.
+
+     Args:
+         model: Model name or full model string (e.g., "gpt-4o" or "openai:gpt-4o")
+         provider: Provider name (optional if model string includes provider)
+         temperature: Temperature for generation
+         **kwargs: Additional parameters
+
+     Returns:
+         Configured LLM gateway
+
+     Example:
+         >>> gateway = get_llm_gateway(model="gpt-4o", provider="openai")
+         >>> gateway = get_llm_gateway(model="openai:gpt-4o")
+     """
+     if provider and model:
+         # Build model string from separate provider and model
+         model_string = f"{provider}:{model}"
+     elif model:
+         # Use model as-is (may already include provider)
+         model_string = model
+     else:
+         # Use default
+         model_string = None
+
+     return LLMGatewayFactory.create_gateway(
+         model_string=model_string,
+         temperature=temperature,
+         **kwargs,
+     )
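
For reference, here is a minimal usage sketch of the factory above. It is not part of the released package: the import path is illustrative (this diff does not show file locations), and it assumes provider API keys are supplied via environment variables such as OPENAI_API_KEY or MISTRAL_API_KEY.

# Illustrative usage sketch; the module path below is assumed, not taken from this diff.
from rakam_systems_agent.llm_gateway_factory import LLMGatewayFactory, get_llm_gateway

# Route by provider-prefixed model string.
gateway = LLMGatewayFactory.create_gateway("mistral:mistral-large-latest", temperature=0.2)

# Configuration-driven creation, e.g. from a loaded YAML/JSON dict.
gateway = LLMGatewayFactory.create_gateway_from_config(
    {"provider": "openai", "model": "gpt-4o", "temperature": 0.0}
)

# Convenience wrapper; with no arguments it falls back to DEFAULT_LLM_MODEL.
gateway = get_llm_gateway()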
@@ -0,0 +1,287 @@
+ """Mistral LLM Gateway implementation with structured output support."""
+ from __future__ import annotations
+ import os
+ from typing import Any, Dict, Iterator, Optional, Type, TypeVar
+
+ from mistralai import Mistral
+ from pydantic import BaseModel
+
+ from rakam_systems_core.ai_utils import logging
+ from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMGateway, LLMRequest, LLMResponse
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class MistralGateway(LLMGateway):
+     """Mistral LLM Gateway with support for structured outputs.
+
+     Features:
+     - Text generation
+     - Structured output using JSON mode
+     - Streaming support
+     - Token counting (approximate)
+     - Support for all Mistral models
+
+     Example:
+         >>> gateway = MistralGateway(model="mistral-large-latest", api_key="...")
+         >>> request = LLMRequest(
+         ...     system_prompt="You are a helpful assistant",
+         ...     user_prompt="What is AI?",
+         ...     temperature=0.7
+         ... )
+         >>> response = gateway.generate(request)
+         >>> print(response.content)
+     """
+
+     def __init__(
+         self,
+         name: str = "mistral_gateway",
+         config: Optional[Dict[str, Any]] = None,
+         model: str = "mistral-large-latest",
+         default_temperature: float = 0.7,
+         api_key: Optional[str] = None,
+     ):
+         """Initialize Mistral Gateway.
+
+         Args:
+             name: Gateway name
+             config: Configuration dictionary
+             model: Mistral model name (e.g., "mistral-large-latest", "mistral-small-latest")
+             default_temperature: Default temperature for generation
+             api_key: Mistral API key (falls back to MISTRAL_API_KEY env var)
+         """
+         super().__init__(
+             name=name,
+             config=config,
+             provider="mistral",
+             model=model,
+             default_temperature=default_temperature,
+             api_key=api_key or os.getenv("MISTRAL_API_KEY"),
+         )
+
+         if not self.api_key:
+             raise ValueError(
+                 "Mistral API key must be provided via api_key parameter or MISTRAL_API_KEY environment variable"
+             )
+
+         # Initialize Mistral client
+         self.client = Mistral(api_key=self.api_key)
+
+         logger.info(
+             f"Initialized Mistral Gateway with model={self.model}, temperature={self.default_temperature}"
+         )
+
+     def _build_messages(self, request: LLMRequest) -> list[dict]:
+         """Build messages array from request."""
+         messages = []
+
+         if request.system_prompt:
+             messages.append({
+                 "role": "system",
+                 "content": request.system_prompt
+             })
+
+         messages.append({
+             "role": "user",
+             "content": request.user_prompt
+         })
+
+         return messages
+
+     def generate(self, request: LLMRequest) -> LLMResponse:
+         """Generate a response from Mistral.
+
+         Args:
+             request: Standardized LLM request
+
+         Returns:
+             Standardized LLM response
+         """
+         messages = self._build_messages(request)
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters
+         params.update(request.extra_params)
+
+         logger.debug(
+             f"Calling Mistral API with model={self.model}, temperature={params['temperature']}")
+
+         try:
+             response = self.client.chat.complete(**params)
+
+             # Extract response
+             content = response.choices[0].message.content
+
+             # Build usage information
+             usage = None
+             if response.usage:
+                 usage = {
+                     "prompt_tokens": response.usage.prompt_tokens,
+                     "completion_tokens": response.usage.completion_tokens,
+                     "total_tokens": response.usage.total_tokens,
+                 }
+
+             llm_response = LLMResponse(
+                 content=content,
+                 usage=usage,
+                 model=response.model,
+                 finish_reason=response.choices[0].finish_reason,
+                 metadata={
+                     "id": response.id,
+                     "created": response.created,
+                 }
+             )
+
+             logger.info(
+                 f"Mistral response received: {usage.get('total_tokens', 'unknown') if usage else 'unknown'} tokens, "
+                 f"finish_reason={llm_response.finish_reason}"
+             )
+
+             return llm_response
+
+         except Exception as e:
+             logger.error(f"Mistral API error: {str(e)}")
+             raise
+
+     def generate_structured(
+         self,
+         request: LLMRequest,
+         schema: Type[T],
+     ) -> T:
+         """Generate structured output conforming to a Pydantic schema.
+
+         Uses Mistral's JSON mode and parses the response into the schema.
+
+         Args:
+             request: Standardized LLM request
+             schema: Pydantic model class to parse response into
+
+         Returns:
+             Instance of the schema class
+         """
+         import json
+
+         messages = self._build_messages(request)
+
+         # Add schema information to the system prompt
+         schema_json = schema.model_json_schema()
+
+         # Enhance system prompt with schema information
+         enhanced_system = request.system_prompt or ""
+         enhanced_system += f"\n\nYou must respond with valid JSON that matches this schema:\n{json.dumps(schema_json, indent=2)}"
+
+         # Update messages with enhanced system prompt
+         messages = []
+         messages.append({
+             "role": "system",
+             "content": enhanced_system
+         })
+         messages.append({
+             "role": "user",
+             "content": request.user_prompt
+         })
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+             "response_format": {"type": "json_object"},
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters
+         params.update(request.extra_params)
+
+         logger.debug(
+             f"Calling Mistral API for structured output with model={self.model}, schema={schema.__name__}"
+         )
+
+         try:
+             response = self.client.chat.complete(**params)
+
+             # Extract and parse JSON response
+             content = response.choices[0].message.content
+             parsed_result = schema.model_validate_json(content)
+
+             logger.info(
+                 f"Mistral structured response received: schema={schema.__name__}"
+             )
+
+             return parsed_result
+
+         except Exception as e:
+             logger.error(f"Mistral structured output error: {str(e)}")
+             raise
+
+     def stream(self, request: LLMRequest) -> Iterator[str]:
+         """Stream responses from Mistral.
+
+         Args:
+             request: Standardized LLM request
+
+         Yields:
+             String chunks from the LLM
+         """
+         messages = self._build_messages(request)
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters
+         params.update(request.extra_params)
+
+         logger.debug(f"Streaming from Mistral with model={self.model}")
+
+         try:
+             stream = self.client.chat.stream(**params)
+
+             for chunk in stream:
+                 if chunk.data.choices[0].delta.content is not None:
+                     yield chunk.data.choices[0].delta.content
+
+         except Exception as e:
+             logger.error(f"Mistral streaming error: {str(e)}")
+             raise
+
+     def count_tokens(self, text: str, model: Optional[str] = None) -> int:
+         """Count tokens in text.
+
+         Uses a simple character-based approximation rather than an exact tokenizer.
+
+         Args:
+             text: Text to count tokens for
+             model: Model name (unused for Mistral)
+
+         Returns:
+             Approximate number of tokens in the text
+         """
+         # Approximation: average of 4 characters per token
+         # This is less accurate than tiktoken but reasonable for most use cases
+         token_count = len(text) // 4
+
+         logger.debug(
+             f"Counted ~{token_count} tokens (approximation) for text of length {len(text)} characters"
+         )
+
+         return token_count
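
For reference, a minimal sketch of driving the gateway above, not part of the released package. It assumes MISTRAL_API_KEY is set, that the import paths below are correct (this diff does not show file locations), and that LLMRequest can be built from just system_prompt and user_prompt with extra_params defaulting to an empty dict, as generate() relies on.

# Illustrative usage sketch; import paths and LLMRequest defaults are assumptions, not taken from this diff.
from pydantic import BaseModel

from rakam_systems_agent.mistral_gateway import MistralGateway
from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMRequest


class Answer(BaseModel):
    summary: str
    confidence: float


gateway = MistralGateway(model="mistral-small-latest")
request = LLMRequest(
    system_prompt="You are concise.",
    user_prompt="Summarize what an LLM gateway does.",
)

print(gateway.generate(request).content)              # plain text generation
print(gateway.generate_structured(request, Answer))   # parsed into the Answer schema
for chunk in gateway.stream(request):                 # incremental streaming
    print(chunk, end="")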