rakam-systems-agent 0.1.1rc7__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
- rakam_systems_agent/__init__.py +35 -0
- rakam_systems_agent/components/__init__.py +26 -0
- rakam_systems_agent/components/base_agent.py +358 -0
- rakam_systems_agent/components/chat_history/__init__.py +10 -0
- rakam_systems_agent/components/chat_history/json_chat_history.py +372 -0
- rakam_systems_agent/components/chat_history/postgres_chat_history.py +668 -0
- rakam_systems_agent/components/chat_history/sql_chat_history.py +446 -0
- rakam_systems_agent/components/llm_gateway/README.md +505 -0
- rakam_systems_agent/components/llm_gateway/__init__.py +16 -0
- rakam_systems_agent/components/llm_gateway/gateway_factory.py +313 -0
- rakam_systems_agent/components/llm_gateway/mistral_gateway.py +287 -0
- rakam_systems_agent/components/llm_gateway/openai_gateway.py +295 -0
- rakam_systems_agent/components/tools/LLM_GATEWAY_TOOLS_README.md +533 -0
- rakam_systems_agent/components/tools/__init__.py +46 -0
- rakam_systems_agent/components/tools/example_tools.py +431 -0
- rakam_systems_agent/components/tools/llm_gateway_tools.py +605 -0
- rakam_systems_agent/components/tools/search_tool.py +14 -0
- rakam_systems_agent/server/README.md +375 -0
- rakam_systems_agent/server/__init__.py +12 -0
- rakam_systems_agent/server/mcp_server_agent.py +127 -0
- rakam_systems_agent-0.1.1rc7.dist-info/METADATA +367 -0
- rakam_systems_agent-0.1.1rc7.dist-info/RECORD +23 -0
- rakam_systems_agent-0.1.1rc7.dist-info/WHEEL +4 -0
rakam_systems_agent/components/llm_gateway/gateway_factory.py
@@ -0,0 +1,313 @@
"""LLM Gateway Factory for provider routing and configuration-driven model selection."""
from __future__ import annotations
import os
from typing import Any, Dict, Optional

from rakam_systems_core.ai_utils import logging
from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMGateway
from .openai_gateway import OpenAIGateway
from .mistral_gateway import MistralGateway

logger = logging.getLogger(__name__)


class LLMGatewayFactory:
    """Factory for creating LLM gateways based on provider and configuration.

    This factory enables:
    - Configuration-driven provider selection
    - Automatic routing to the correct gateway
    - Model string parsing (e.g., "openai:gpt-4o")
    - Fallback to environment-based configuration

    Example:
        >>> # Using model string with provider prefix
        >>> gateway = LLMGatewayFactory.create_gateway("openai:gpt-4o")
        >>>
        >>> # Using explicit provider and model
        >>> gateway = LLMGatewayFactory.create_gateway_from_config({
        ...     "provider": "mistral",
        ...     "model": "mistral-large-latest",
        ...     "temperature": 0.7
        ... })
    """

    # Registry of available providers
    _PROVIDERS = {
        "openai": OpenAIGateway,
        "mistral": MistralGateway,
    }

    # Default models for each provider
    _DEFAULT_MODELS = {
        "openai": "gpt-4o",
        "mistral": "mistral-large-latest",
    }

    @classmethod
    def parse_model_string(cls, model_string: str) -> tuple[str, str]:
        """Parse a model string into provider and model name.

        Supports formats:
        - "openai:gpt-4o" -> ("openai", "gpt-4o")
        - "mistral:mistral-large-latest" -> ("mistral", "mistral-large-latest")
        - "gpt-4o" -> ("openai", "gpt-4o")  # assumes OpenAI if no prefix

        Args:
            model_string: Model string to parse

        Returns:
            Tuple of (provider, model_name)

        Raises:
            ValueError: If provider is unknown
        """
        if ":" in model_string:
            provider, model = model_string.split(":", 1)
        else:
            # Default to OpenAI if no provider specified
            provider = "openai"
            model = model_string

        if provider not in cls._PROVIDERS:
            raise ValueError(
                f"Unknown provider '{provider}'. Supported providers: {list(cls._PROVIDERS.keys())}"
            )

        return provider, model

    @classmethod
    def create_gateway(
        cls,
        model_string: Optional[str] = None,
        temperature: Optional[float] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> LLMGateway:
        """Create an LLM gateway from a model string.

        Args:
            model_string: Model string (e.g., "openai:gpt-4o", "mistral:mistral-large-latest")
                Falls back to DEFAULT_LLM_MODEL env var
            temperature: Temperature for generation
            api_key: API key for the provider (provider-specific env vars used as fallback)
            **kwargs: Additional parameters passed to gateway constructor

        Returns:
            Configured LLM gateway instance

        Raises:
            ValueError: If provider is unknown or configuration is invalid
        """
        # Get model string from parameter or environment
        model_string = model_string or os.getenv(
            "DEFAULT_LLM_MODEL", "openai:gpt-4o")

        # Parse provider and model
        provider, model = cls.parse_model_string(model_string)

        # Get gateway class
        gateway_class = cls._PROVIDERS[provider]

        # Build gateway parameters
        gateway_params = {
            "model": model,
            "default_temperature": temperature or float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7")),
        }

        # Add API key if provided
        if api_key:
            gateway_params["api_key"] = api_key

        # Add any additional parameters
        gateway_params.update(kwargs)

        logger.info(
            f"Creating {provider} gateway with model={model}, temperature={gateway_params['default_temperature']}"
        )

        return gateway_class(**gateway_params)

    @classmethod
    def create_gateway_from_config(
        cls,
        config: Dict[str, Any],
    ) -> LLMGateway:
        """Create an LLM gateway from a configuration dictionary.

        Expected config format:
            {
                "provider": "openai",  # or "mistral"
                "model": "gpt-4o",
                "temperature": 0.7,
                "api_key": "...",  # optional
                ...  # additional provider-specific params
            }

        Args:
            config: Configuration dictionary

        Returns:
            Configured LLM gateway instance

        Raises:
            ValueError: If provider is unknown or configuration is invalid
        """
        provider = config.get("provider")
        model = config.get("model")

        if not provider:
            raise ValueError("Configuration must specify 'provider'")

        if provider not in cls._PROVIDERS:
            raise ValueError(
                f"Unknown provider '{provider}'. Supported providers: {list(cls._PROVIDERS.keys())}"
            )

        # Use default model if not specified
        if not model:
            model = cls._DEFAULT_MODELS.get(provider)
            logger.warning(
                f"No model specified for {provider}, using default: {model}"
            )

        # Get gateway class
        gateway_class = cls._PROVIDERS[provider]

        # Extract common parameters
        gateway_params = {
            "model": model,
            "default_temperature": config.get("temperature", 0.7),
        }

        # Add API key if provided
        if "api_key" in config:
            gateway_params["api_key"] = config["api_key"]

        # Add provider-specific parameters
        provider_specific_keys = {
            "openai": ["base_url", "organization"],
            "mistral": [],
        }

        for key in provider_specific_keys.get(provider, []):
            if key in config:
                gateway_params[key] = config[key]

        logger.info(
            f"Creating {provider} gateway from config with model={model}"
        )

        return gateway_class(**gateway_params)

    @classmethod
    def get_default_gateway(cls) -> LLMGateway:
        """Get a default gateway based on environment configuration.

        Checks environment variables:
        - DEFAULT_LLM_PROVIDER: Provider name (default: "openai")
        - DEFAULT_LLM_MODEL: Model name (default: "gpt-4o" for OpenAI)
        - DEFAULT_LLM_TEMPERATURE: Temperature (default: 0.7)

        Returns:
            Configured default LLM gateway
        """
        provider = os.getenv("DEFAULT_LLM_PROVIDER", "openai")
        model = os.getenv("DEFAULT_LLM_MODEL",
                          cls._DEFAULT_MODELS.get(provider, "gpt-4o"))
        temperature = float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7"))

        # If model doesn't have provider prefix, add it
        if ":" not in model:
            model_string = f"{provider}:{model}"
        else:
            model_string = model

        return cls.create_gateway(
            model_string=model_string,
            temperature=temperature,
        )

    @classmethod
    def register_provider(
        cls,
        provider_name: str,
        gateway_class: type[LLMGateway],
        default_model: Optional[str] = None,
    ) -> None:
        """Register a new provider gateway.

        This allows extending the factory with custom providers.

        Args:
            provider_name: Name of the provider (e.g., "custom")
            gateway_class: Gateway class implementing LLMGateway
            default_model: Optional default model for this provider
        """
        cls._PROVIDERS[provider_name] = gateway_class
        if default_model:
            cls._DEFAULT_MODELS[provider_name] = default_model

        logger.info(
            f"Registered provider '{provider_name}' with gateway class {gateway_class.__name__}"
        )

    @classmethod
    def list_providers(cls) -> list[str]:
        """List all registered providers.

        Returns:
            List of provider names
        """
        return list(cls._PROVIDERS.keys())

    @classmethod
    def get_default_model(cls, provider: str) -> Optional[str]:
        """Get the default model for a provider.

        Args:
            provider: Provider name

        Returns:
            Default model name or None if provider unknown
        """
        return cls._DEFAULT_MODELS.get(provider)


# Convenience function for creating gateways
def get_llm_gateway(
    model: Optional[str] = None,
    provider: Optional[str] = None,
    temperature: Optional[float] = None,
    **kwargs: Any,
) -> LLMGateway:
    """Convenience function to create an LLM gateway.

    Args:
        model: Model name or full model string (e.g., "gpt-4o" or "openai:gpt-4o")
        provider: Provider name (optional if model string includes provider)
        temperature: Temperature for generation
        **kwargs: Additional parameters

    Returns:
        Configured LLM gateway

    Example:
        >>> gateway = get_llm_gateway(model="gpt-4o", provider="openai")
        >>> gateway = get_llm_gateway(model="openai:gpt-4o")
    """
    if provider and model:
        # Build model string from separate provider and model
        model_string = f"{provider}:{model}"
    elif model:
        # Use model as-is (may already include provider)
        model_string = model
    else:
        # Use default
        model_string = None

    return LLMGatewayFactory.create_gateway(
        model_string=model_string,
        temperature=temperature,
        **kwargs,
    )
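For reviewers of this release, a minimal usage sketch of the factory above follows; it is not part of the wheel. It assumes the package is installed alongside rakam_systems_core, that OPENAI_API_KEY and MISTRAL_API_KEY are set, that LLMRequest can be constructed from just system_prompt and user_prompt with its remaining fields defaulted, and that OpenAIGateway.generate follows the same LLMGateway interface shown for MistralGateway below.

# Usage sketch (not shipped in the wheel): exercising LLMGatewayFactory.
# Assumes OPENAI_API_KEY / MISTRAL_API_KEY are set in the environment.
from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMRequest
from rakam_systems_agent.components.llm_gateway.gateway_factory import (
    LLMGatewayFactory,
    get_llm_gateway,
)

# 1) Provider-prefixed model string, routed to OpenAIGateway.
gateway = LLMGatewayFactory.create_gateway("openai:gpt-4o", temperature=0.2)

# 2) Configuration dictionary; unknown provider-specific keys are filtered out.
mistral_gateway = LLMGatewayFactory.create_gateway_from_config({
    "provider": "mistral",
    "model": "mistral-large-latest",
    "temperature": 0.7,
})

# 3) Convenience wrapper; with no arguments it falls back to DEFAULT_LLM_MODEL.
default_gateway = get_llm_gateway(model="gpt-4o", provider="openai")

# Generate via the common LLMGateway interface (fields as used by the gateways
# in this diff; other LLMRequest fields are assumed to have defaults).
response = default_gateway.generate(
    LLMRequest(
        system_prompt="You are a helpful assistant.",
        user_prompt="Summarize what an LLM gateway does in one sentence.",
    )
)
print(response.content)
print(LLMGatewayFactory.list_providers())  # ['openai', 'mistral']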
rakam_systems_agent/components/llm_gateway/mistral_gateway.py
@@ -0,0 +1,287 @@
"""Mistral LLM Gateway implementation with structured output support."""
from __future__ import annotations
import os
from typing import Any, Dict, Iterator, Optional, Type, TypeVar

from mistralai import Mistral
from pydantic import BaseModel

from rakam_systems_core.ai_utils import logging
from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMGateway, LLMRequest, LLMResponse

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class MistralGateway(LLMGateway):
    """Mistral LLM Gateway with support for structured outputs.

    Features:
    - Text generation
    - Structured output using JSON mode
    - Streaming support
    - Token counting (approximate)
    - Support for all Mistral models

    Example:
        >>> gateway = MistralGateway(model="mistral-large-latest", api_key="...")
        >>> request = LLMRequest(
        ...     system_prompt="You are a helpful assistant",
        ...     user_prompt="What is AI?",
        ...     temperature=0.7
        ... )
        >>> response = gateway.generate(request)
        >>> print(response.content)
    """

    def __init__(
        self,
        name: str = "mistral_gateway",
        config: Optional[Dict[str, Any]] = None,
        model: str = "mistral-large-latest",
        default_temperature: float = 0.7,
        api_key: Optional[str] = None,
    ):
        """Initialize Mistral Gateway.

        Args:
            name: Gateway name
            config: Configuration dictionary
            model: Mistral model name (e.g., "mistral-large-latest", "mistral-small-latest")
            default_temperature: Default temperature for generation
            api_key: Mistral API key (falls back to MISTRAL_API_KEY env var)
        """
        super().__init__(
            name=name,
            config=config,
            provider="mistral",
            model=model,
            default_temperature=default_temperature,
            api_key=api_key or os.getenv("MISTRAL_API_KEY"),
        )

        if not self.api_key:
            raise ValueError(
                "Mistral API key must be provided via api_key parameter or MISTRAL_API_KEY environment variable"
            )

        # Initialize Mistral client
        self.client = Mistral(api_key=self.api_key)

        logger.info(
            f"Initialized Mistral Gateway with model={self.model}, temperature={self.default_temperature}"
        )

    def _build_messages(self, request: LLMRequest) -> list[dict]:
        """Build messages array from request."""
        messages = []

        if request.system_prompt:
            messages.append({
                "role": "system",
                "content": request.system_prompt
            })

        messages.append({
            "role": "user",
            "content": request.user_prompt
        })

        return messages

    def generate(self, request: LLMRequest) -> LLMResponse:
        """Generate a response from Mistral.

        Args:
            request: Standardized LLM request

        Returns:
            Standardized LLM response
        """
        messages = self._build_messages(request)

        # Prepare API call parameters
        params = {
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.default_temperature,
        }

        if request.max_tokens:
            params["max_tokens"] = request.max_tokens

        # Add extra parameters
        params.update(request.extra_params)

        logger.debug(
            f"Calling Mistral API with model={self.model}, temperature={params['temperature']}")

        try:
            response = self.client.chat.complete(**params)

            # Extract response
            content = response.choices[0].message.content

            # Build usage information
            usage = None
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                }

            llm_response = LLMResponse(
                content=content,
                usage=usage,
                model=response.model,
                finish_reason=response.choices[0].finish_reason,
                metadata={
                    "id": response.id,
                    "created": response.created,
                }
            )

            logger.info(
                f"Mistral response received: {usage.get('total_tokens', 'unknown') if usage else 'unknown'} tokens, "
                f"finish_reason={llm_response.finish_reason}"
            )

            return llm_response

        except Exception as e:
            logger.error(f"Mistral API error: {str(e)}")
            raise

    def generate_structured(
        self,
        request: LLMRequest,
        schema: Type[T],
    ) -> T:
        """Generate structured output conforming to a Pydantic schema.

        Uses Mistral's JSON mode and parses the response into the schema.

        Args:
            request: Standardized LLM request
            schema: Pydantic model class to parse response into

        Returns:
            Instance of the schema class
        """
        import json

        messages = self._build_messages(request)

        # Add schema information to the system prompt
        schema_json = schema.model_json_schema()

        # Enhance system prompt with schema information
        enhanced_system = request.system_prompt or ""
        enhanced_system += f"\n\nYou must respond with valid JSON that matches this schema:\n{json.dumps(schema_json, indent=2)}"

        # Update messages with enhanced system prompt
        messages = []
        messages.append({
            "role": "system",
            "content": enhanced_system
        })
        messages.append({
            "role": "user",
            "content": request.user_prompt
        })

        # Prepare API call parameters
        params = {
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.default_temperature,
            "response_format": {"type": "json_object"},
        }

        if request.max_tokens:
            params["max_tokens"] = request.max_tokens

        # Add extra parameters
        params.update(request.extra_params)

        logger.debug(
            f"Calling Mistral API for structured output with model={self.model}, schema={schema.__name__}"
        )

        try:
            response = self.client.chat.complete(**params)

            # Extract and parse JSON response
            content = response.choices[0].message.content
            parsed_result = schema.model_validate_json(content)

            logger.info(
                f"Mistral structured response received: schema={schema.__name__}"
            )

            return parsed_result

        except Exception as e:
            logger.error(f"Mistral structured output error: {str(e)}")
            raise

    def stream(self, request: LLMRequest) -> Iterator[str]:
        """Stream responses from Mistral.

        Args:
            request: Standardized LLM request

        Yields:
            String chunks from the LLM
        """
        messages = self._build_messages(request)

        # Prepare API call parameters
        params = {
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.default_temperature,
        }

        if request.max_tokens:
            params["max_tokens"] = request.max_tokens

        # Add extra parameters
        params.update(request.extra_params)

        logger.debug(f"Streaming from Mistral with model={self.model}")

        try:
            stream = self.client.chat.stream(**params)

            for chunk in stream:
                if chunk.data.choices[0].delta.content is not None:
                    yield chunk.data.choices[0].delta.content

        except Exception as e:
            logger.error(f"Mistral streaming error: {str(e)}")
            raise

    def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count tokens in text.

        Mistral doesn't provide a native tokenization library, so we use approximation.

        Args:
            text: Text to count tokens for
            model: Model name (unused for Mistral)

        Returns:
            Approximate number of tokens in the text
        """
        # Approximation: average of 4 characters per token
        # This is less accurate than tiktoken but reasonable for most use cases
        token_count = len(text) // 4

        logger.debug(
            f"Counted ~{token_count} tokens (approximation) for text of length {len(text)} characters"
        )

        return token_count
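A companion sketch for MistralGateway (also not shipped in the wheel): it streams a plain-text answer and then requests structured output validated against a Pydantic schema. It assumes MISTRAL_API_KEY is set and that LLMRequest behaves as described above; the UseCases model is purely illustrative.

# Usage sketch (not shipped in the wheel): streaming and structured output
# with MistralGateway. Assumes MISTRAL_API_KEY is set in the environment.
from pydantic import BaseModel

from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMRequest
from rakam_systems_agent.components.llm_gateway.mistral_gateway import MistralGateway

gateway = MistralGateway(model="mistral-small-latest", default_temperature=0.3)

request = LLMRequest(
    system_prompt="You are a concise assistant.",
    user_prompt="List two use cases for an LLM gateway.",
)

# Streaming: chunks are yielded as plain strings.
for chunk in gateway.stream(request):
    print(chunk, end="", flush=True)
print()


# Structured output: the gateway injects the JSON schema into the system
# prompt, enables Mistral's JSON mode, and validates the reply with Pydantic.
class UseCases(BaseModel):  # illustrative schema, not part of the package
    use_cases: list[str]


result = gateway.generate_structured(request, schema=UseCases)
print(result.use_cases)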