rakam-systems-agent 0.1.1rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,295 @@
+ """OpenAI LLM Gateway implementation with structured output support."""
+ from __future__ import annotations
+ import os
+ from typing import Any, Dict, Iterator, Optional, Type, TypeVar
+
+ import tiktoken
+ from openai import OpenAI
+ from pydantic import BaseModel
+
+ from rakam_systems_core.ai_utils import logging
+ from rakam_systems_core.ai_core.interfaces.llm_gateway import LLMGateway, LLMRequest, LLMResponse
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class OpenAIGateway(LLMGateway):
+     """OpenAI LLM Gateway with support for structured outputs.
+
+     Features:
+     - Text generation
+     - Structured output using response_format
+     - Streaming support
+     - Token counting with tiktoken
+     - Support for all OpenAI chat models
+
+     Example:
+         >>> gateway = OpenAIGateway(model="gpt-4o", api_key="...")
+         >>> request = LLMRequest(
+         ...     system_prompt="You are a helpful assistant",
+         ...     user_prompt="What is AI?",
+         ...     temperature=0.7
+         ... )
+         >>> response = gateway.generate(request)
+         >>> print(response.content)
+     """
+
+     def __init__(
+         self,
+         name: str = "openai_gateway",
+         config: Optional[Dict[str, Any]] = None,
+         model: str = "gpt-4o",
+         default_temperature: float = 0.7,
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         organization: Optional[str] = None,
+     ):
+         """Initialize OpenAI Gateway.
+
+         Args:
+             name: Gateway name
+             config: Configuration dictionary
+             model: OpenAI model name (e.g., "gpt-4o", "gpt-4-turbo")
+             default_temperature: Default temperature for generation
+             api_key: OpenAI API key (falls back to OPENAI_API_KEY env var)
+             base_url: Optional base URL for API
+             organization: Optional organization ID
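+
+         Example (illustrative sketch; assumes OPENAI_API_KEY is set in the environment):
+             >>> gateway = OpenAIGateway(model="gpt-4o-mini")
+             >>> gateway.model
+             'gpt-4o-mini'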
+         """
+         super().__init__(
+             name=name,
+             config=config,
+             provider="openai",
+             model=model,
+             default_temperature=default_temperature,
+             api_key=api_key or os.getenv("OPENAI_API_KEY"),
+         )
+
+         if not self.api_key:
+             raise ValueError(
+                 "OpenAI API key must be provided via api_key parameter or OPENAI_API_KEY environment variable"
+             )
+
+         # Initialize OpenAI client
+         self.client = OpenAI(
+             api_key=self.api_key,
+             base_url=base_url,
+             organization=organization,
+         )
+
+         logger.info(
+             f"Initialized OpenAI Gateway with model={self.model}, temperature={self.default_temperature}"
+         )
+
+     def _build_messages(self, request: LLMRequest) -> list[dict]:
+         """Build messages array from request."""
+         messages = []
+
+         if request.system_prompt:
+             messages.append({
+                 "role": "system",
+                 "content": request.system_prompt
+             })
+
+         messages.append({
+             "role": "user",
+             "content": request.user_prompt
+         })
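+
+         # Resulting shape (illustrative), as expected by the OpenAI
+         # chat completions API:
+         # [
+         #     {"role": "system", "content": "You are a helpful assistant"},
+         #     {"role": "user", "content": "What is AI?"},
+         # ]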
+         return messages
+
+     def generate(self, request: LLMRequest) -> LLMResponse:
+         """Generate a response from OpenAI.
+
+         Args:
+             request: Standardized LLM request
+
+         Returns:
+             Standardized LLM response
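+
+         Example (illustrative; mirrors the class-level example):
+             >>> request = LLMRequest(
+             ...     system_prompt="You are a helpful assistant",
+             ...     user_prompt="What is AI?",
+             ... )
+             >>> response = gateway.generate(request)
+             >>> print(response.content)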
+         """
+         messages = self._build_messages(request)
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters
+         params.update(request.extra_params)
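+         # extra_params is passed through verbatim, e.g.
+         # {"top_p": 0.9, "stop": ["\n\n"]} (illustrative values)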
125
+
126
+ logger.debug(
127
+ f"Calling OpenAI API with model={self.model}, temperature={params['temperature']}")
128
+
129
+ try:
130
+ completion = self.client.chat.completions.create(**params)
131
+
132
+ # Extract response
133
+ content = completion.choices[0].message.content
134
+
135
+ # Build usage information
136
+ usage = None
137
+ if completion.usage:
138
+ usage = {
139
+ "prompt_tokens": completion.usage.prompt_tokens,
140
+ "completion_tokens": completion.usage.completion_tokens,
141
+ "total_tokens": completion.usage.total_tokens,
142
+ }
143
+
144
+ response = LLMResponse(
145
+ content=content,
146
+ usage=usage,
147
+ model=completion.model,
148
+ finish_reason=completion.choices[0].finish_reason,
149
+ metadata={
150
+ "id": completion.id,
151
+ "created": completion.created,
152
+ }
153
+ )
154
+
+             # usage may be None when the API returns no usage block
+             logger.info(
+                 f"OpenAI response received: {(usage or {}).get('total_tokens', 'unknown')} tokens, "
+                 f"finish_reason={response.finish_reason}"
+             )
+
+             return response
+
+         except Exception as e:
+             logger.error(f"OpenAI API error: {str(e)}")
+             raise
+
+     def generate_structured(
+         self,
+         request: LLMRequest,
+         schema: Type[T],
+     ) -> T:
+         """Generate structured output conforming to a Pydantic schema.
+
+         Uses OpenAI's structured output feature to ensure response matches schema.
+
+         Args:
+             request: Standardized LLM request
+             schema: Pydantic model class to parse response into
+
+         Returns:
+             Instance of the schema class
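+
+         Example (illustrative sketch; Answer is a hypothetical schema):
+             >>> class Answer(BaseModel):
+             ...     summary: str
+             ...     confidence: float
+             >>> request = LLMRequest(user_prompt="Summarize AI in one sentence")
+             >>> answer = gateway.generate_structured(request, Answer)
+             >>> answer.summary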
+         """
+         messages = self._build_messages(request)
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters
+         params.update(request.extra_params)
+
+         logger.debug(
+             f"Calling OpenAI API for structured output with model={self.model}, schema={schema.__name__}"
+         )
+
+         try:
+             # Use beta parse feature for structured outputs
+             completion = self.client.beta.chat.completions.parse(
+                 **params,
+                 response_format=schema,
+             )
+
+             parsed_result = completion.choices[0].message.parsed
+             if parsed_result is None:
+                 # The SDK leaves parsed content empty when the model refuses
+                 # or returns nothing parseable, so fail loudly instead of
+                 # returning None from a method annotated to return T.
+                 raise ValueError(
+                     f"OpenAI returned no parsed content for schema {schema.__name__}"
+                 )
+
+             logger.info(
+                 f"OpenAI structured response received: schema={schema.__name__}"
+             )
+
+             return parsed_result
+
+         except Exception as e:
+             logger.error(f"OpenAI structured output error: {str(e)}")
+             raise
+
+     def stream(self, request: LLMRequest) -> Iterator[str]:
+         """Stream responses from OpenAI.
+
+         Args:
+             request: Standardized LLM request
+
+         Yields:
+             String chunks from the LLM
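+
+         Example (illustrative):
+             >>> for chunk in gateway.stream(request):
+             ...     print(chunk, end="", flush=True)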
+         """
+         messages = self._build_messages(request)
+
+         # Prepare API call parameters
+         params = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": request.temperature if request.temperature is not None else self.default_temperature,
+             "stream": True,
+         }
+
+         if request.max_tokens:
+             params["max_tokens"] = request.max_tokens
+
+         # Add extra parameters (excluding stream since we set it)
+         extra = {k: v for k, v in request.extra_params.items() if k != "stream"}
+         params.update(extra)
+
+         logger.debug(f"Streaming from OpenAI with model={self.model}")
+
+         try:
+             stream = self.client.chat.completions.create(**params)
+
+             for chunk in stream:
+                 if chunk.choices[0].delta.content is not None:
+                     yield chunk.choices[0].delta.content
+
+         except Exception as e:
+             logger.error(f"OpenAI streaming error: {str(e)}")
+             raise
+
+     def count_tokens(self, text: str, model: Optional[str] = None) -> int:
+         """Count tokens in text using tiktoken.
+
+         Args:
+             text: Text to count tokens for
+             model: Model name to determine encoding (uses instance model if None)
+
+         Returns:
+             Number of tokens in the text
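+
+         Example (illustrative; exact counts depend on the model's encoding):
+             >>> isinstance(gateway.count_tokens("Hello, world!"), int)
+             True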
+         """
+         try:
+             model_name = model or self.model
+
+             # Try to get encoding for the specific model
+             try:
+                 encoding = tiktoken.encoding_for_model(model_name)
+             except KeyError:
+                 # Fall back to cl100k_base for unknown models
+                 logger.warning(
+                     f"Unknown model {model_name}, using cl100k_base encoding"
+                 )
+                 encoding = tiktoken.get_encoding("cl100k_base")
+
+             token_count = len(encoding.encode(text))
+
+             logger.debug(
+                 f"Counted {token_count} tokens for text of length {len(text)} characters"
+             )
+
+             return token_count
+
+         except Exception as e:
+             logger.warning(
+                 f"Error counting tokens: {e}. Using character approximation."
+             )
+             # Fallback to character-based approximation (rough estimate: 4 chars = 1 token)
+             return len(text) // 4