DeepFabric 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/llm/client.py
ADDED
|
@@ -0,0 +1,1206 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import anthropic
|
|
8
|
+
import openai
|
|
9
|
+
import outlines
|
|
10
|
+
|
|
11
|
+
from google import genai
|
|
12
|
+
from google.genai import types as genai_types
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
from ..exceptions import DataSetGeneratorError
|
|
16
|
+
from .errors import handle_provider_error
|
|
17
|
+
from .rate_limit_config import (
|
|
18
|
+
RateLimitConfig,
|
|
19
|
+
create_rate_limit_config,
|
|
20
|
+
get_default_rate_limit_config,
|
|
21
|
+
)
|
|
22
|
+
from .retry_handler import RetryHandler, retry_with_backoff, retry_with_backoff_async
|
|
23
|
+
|
|
24
|
+
# JSON Schema union type keys that need recursive processing
|
|
25
|
+
_UNION_KEYS = ("anyOf", "oneOf", "allOf")
|
|
26
|
+
|
|
27
|
+
# Provider-specific parameter mappings
|
|
28
|
+
# Maps generic parameter names to provider-specific equivalents
|
|
29
|
+
# Format: {provider: {generic_name: provider_specific_name}}
|
|
30
|
+
_PROVIDER_PARAM_MAPPINGS: dict[str, dict[str, str]] = {
|
|
31
|
+
"gemini": {
|
|
32
|
+
"max_tokens": "max_output_tokens",
|
|
33
|
+
},
|
|
34
|
+
# Add other provider mappings as needed
|
|
35
|
+
# "anthropic": {"some_param": "anthropic_param"},
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LLMClient:
    """Wrapper for Outlines models with retry logic and error handling.

    ``self.model`` (sync) is built eagerly in ``__init__``; ``self.async_model``
    is built lazily on first access. Non-streaming generation is wrapped with
    ``retry_with_backoff`` / ``retry_with_backoff_async``; streaming generation
    is not retried.
    """

    # Providers that only support async generation (native API, not Outlines)
    _ASYNC_ONLY_PROVIDERS = frozenset({"anthropic", "gemini"})

    def __init__(
        self,
        provider: str,
        model_name: str,
        rate_limit_config: RateLimitConfig | dict | None = None,
        **kwargs,
    ):
        """Initialize LLM client.

        Args:
            provider: Provider name
            model_name: Model identifier
            rate_limit_config: Rate limiting configuration. A dict is converted via
                ``create_rate_limit_config``; None uses provider defaults.
            **kwargs: Additional client configuration, forwarded to both the sync
                model factory and (lazily) the async model factory.

        Raises:
            DataSetGeneratorError: If the sync model factory returns None.
        """
        self.provider = provider
        self.model_name = model_name
        self._client_kwargs = kwargs  # Store for lazy async model initialization

        # Initialize rate limiting
        if isinstance(rate_limit_config, dict):
            self.rate_limit_config = create_rate_limit_config(provider, rate_limit_config)
        elif rate_limit_config is None:
            self.rate_limit_config = get_default_rate_limit_config(provider)
        else:
            self.rate_limit_config = rate_limit_config

        self.retry_handler = RetryHandler(self.rate_limit_config, provider)

        self.model: Any = make_outlines_model(provider, model_name, **kwargs)
        # Lazy-initialize async_model only when needed
        self._async_model: Any | None = None
        self._async_model_initialized: bool = False
        if self.model is None:
            msg = f"Failed to create model for {provider}/{model_name}"
            raise DataSetGeneratorError(msg)

    @property
    def async_model(self) -> Any | None:
        """Lazily initialize and return the async model.

        The async model is only created when first accessed, reducing memory
        and connection overhead when only sync operations are used. Once the
        factory has returned, its result (possibly None) is cached; callers
        treat None as "no async model" and fall back to the sync path.
        """
        if not self._async_model_initialized:
            self._async_model = make_async_outlines_model(
                self.provider, self.model_name, **self._client_kwargs
            )
            self._async_model_initialized = True
        return self._async_model

    def _resolve_generation_schema(self, schema: Any) -> Any:
        """Return the schema to hand to the underlying model.

        For OpenAI and OpenRouter, Pydantic model classes are converted with
        ``_create_openai_compatible_schema`` so all properties are in the
        required array. Every other provider/schema combination, including the
        native Anthropic and Gemini models (which transform schemas
        internally), receives ``schema`` unchanged. Model output is always
        validated against the original ``schema``, not the returned one.
        """
        if (
            self.provider in ("openai", "openrouter")
            and isinstance(schema, type)
            and issubclass(schema, BaseModel)
        ):
            return _create_openai_compatible_schema(schema)
        return schema

    def generate(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs) -> Any:  # noqa: ARG002
        """Generate structured output using the provided schema.

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Returns:
            Generated output matching the schema

        Raises:
            DataSetGeneratorError: If generation fails after all retries

        Note:
            The max_retries parameter is deprecated. Use rate_limit_config in __init__ instead.
            If provided, it will be ignored in favor of the configured retry handler.
        """
        return self._generate_with_retry(prompt, schema, **kwargs)

    @retry_with_backoff
    def _generate_with_retry(self, prompt: str, schema: Any, **kwargs) -> Any:
        """Internal method that performs actual generation with retry wrapper.

        This method is decorated with retry_with_backoff to handle rate limits
        and transient errors automatically.

        Raises:
            DataSetGeneratorError: If the provider is async-only, or if the model
                output fails validation against ``schema``.
        """
        # Async-only providers have no synchronous model path
        if self.provider in self._ASYNC_ONLY_PROVIDERS:
            raise DataSetGeneratorError(
                f"Synchronous generation is not supported for {self.provider}. "
                "Use generate_async() instead.",
                context={"provider": self.provider},
            )

        # Convert provider-specific parameters
        kwargs = self._convert_generation_params(**kwargs)
        generation_schema = self._resolve_generation_schema(schema)

        # Generate JSON string using the model
        json_output = self.model(prompt, generation_schema, **kwargs)

        # Parse and validate the JSON response with the ORIGINAL schema
        # This ensures we still get proper validation
        try:
            return schema.model_validate_json(json_output)
        except Exception as e:
            raise DataSetGeneratorError(
                f"Generation validation failed: {e}",
                context={"raw_content": json_output},
            ) from e

    async def generate_async(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs) -> Any:  # noqa: ARG002
        """Asynchronously generate structured output using provider async clients.

        When no async model is available, the synchronous :meth:`generate` is run
        in a worker thread instead.

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Returns:
            Generated output matching the schema

        Raises:
            DataSetGeneratorError: If generation fails after all retries

        Note:
            The max_retries parameter is deprecated. Use rate_limit_config in __init__ instead.
            If provided, it will be ignored in favor of the configured retry handler.
        """
        if self.async_model is None:
            # Fallback to running the synchronous path in a worker thread
            return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

        return await self._generate_async_with_retry(prompt, schema, **kwargs)

    @retry_with_backoff_async
    async def _generate_async_with_retry(self, prompt: str, schema: Any, **kwargs) -> Any:
        """Internal async method that performs actual generation with retry wrapper.

        This method is decorated with retry_with_backoff_async to handle rate limits
        and transient errors automatically.

        Raises:
            DataSetGeneratorError: If the model output fails validation against ``schema``.
        """
        kwargs = self._convert_generation_params(**kwargs)
        generation_schema = self._resolve_generation_schema(schema)

        # Ensure we have an async model; if not, fall back to running the sync path
        async_model = self.async_model
        if async_model is None:
            # Note: This will raise an error for async-only providers
            return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

        json_output = await async_model(prompt, generation_schema, **kwargs)

        # Validate with original schema to ensure proper validation
        try:
            return schema.model_validate_json(json_output)
        except Exception as e:
            raise DataSetGeneratorError(
                f"Async generation validation failed: {e}",
                context={"raw_content": json_output},
            ) from e

    async def generate_async_stream(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs):  # noqa: ARG002
        """Asynchronously generate structured output with streaming text chunks.

        This method streams the LLM's output text as it's generated, then returns
        the final parsed Pydantic model. It yields tuples of (chunk, result) where:
        - During streaming: (text_chunk, None)
        - When complete: (None, final_pydantic_model)

        If the selected model has no ``generate_stream`` method, the whole result
        is produced via :meth:`generate_async` and yielded once as (None, model).

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Yields:
            tuple[str | None, Any | None]:
                - (chunk, None) during streaming
                - (None, model) when generation is complete

        Raises:
            DataSetGeneratorError: If streaming or final validation fails. The
                error context carries the text accumulated so far as ``raw_content``.

        Note:
            The max_retries parameter is deprecated. Use rate_limit_config in __init__ instead.
            Streaming is not retried (retry decorators do not apply to generators).

        Example:
            >>> async for chunk, result in client.generate_async_stream(prompt, MyModel):
            ...     if chunk:
            ...         print(chunk, end='', flush=True)  # Display streaming text
            ...     if result:
            ...         return result  # Final parsed model
        """
        kwargs = self._convert_generation_params(**kwargs)
        generation_schema = self._resolve_generation_schema(schema)

        # Prefer the async model; fall back to the sync model when there is none
        async_model = self.async_model
        stream_model = async_model if async_model is not None else self.model
        if not hasattr(stream_model, "generate_stream"):
            # Fallback: no streaming support, yield entire result at once
            result = await self.generate_async(prompt, schema, **kwargs)
            yield (None, result)
            return

        accumulated_text: list[str] = []
        try:
            if async_model is None:
                # Sync model: run its blocking generator in a worker thread and
                # forward chunks through a queue so they are yielded as they arrive.
                # The worker thread has no event loop of its own, so capture the
                # running loop here and schedule queue writes onto it.
                loop = asyncio.get_running_loop()
                queue: asyncio.Queue[str | None] = asyncio.Queue()
                stream_error: list[Exception] = []

                def _produce_chunks() -> None:
                    """Run the sync generator, handing each chunk to the consumer's loop."""
                    try:
                        for chunk in self.model.generate_stream(
                            prompt, generation_schema, **kwargs
                        ):
                            # asyncio.Queue is not thread-safe: mutate it only
                            # from the loop thread via call_soon_threadsafe.
                            loop.call_soon_threadsafe(queue.put_nowait, chunk)
                    except Exception as e:
                        stream_error.append(e)
                    finally:
                        # Always signal completion so the consumer cannot hang
                        loop.call_soon_threadsafe(queue.put_nowait, None)

                producer = loop.run_in_executor(None, _produce_chunks)

                # Consume chunks as they arrive; None marks end of stream
                while (chunk := await queue.get()) is not None:
                    accumulated_text.append(chunk)
                    yield (chunk, None)
                await producer

                # Re-raise any error from the producer thread
                if stream_error:
                    raise DataSetGeneratorError(
                        f"Streaming generation failed in producer: {stream_error[0]}",
                        context={
                            "raw_content": "".join(accumulated_text) if accumulated_text else None
                        },
                    ) from stream_error[0]
            else:
                # True async streaming
                stream = async_model.generate_stream(prompt, generation_schema, **kwargs)
                async for chunk in stream:
                    accumulated_text.append(chunk)
                    yield (chunk, None)

            # Parse accumulated JSON with original schema
            full_text = "".join(accumulated_text)
            result = schema.model_validate_json(full_text)
            yield (None, result)

        except DataSetGeneratorError:
            # Already descriptive and carries raw_content; don't wrap it twice
            raise
        except Exception as e:
            # Wrap and raise error with raw content for debugging
            raw_content = "".join(accumulated_text) if accumulated_text else None
            raise DataSetGeneratorError(
                f"Streaming generation failed: {e}",
                context={"raw_content": raw_content},
            ) from e

    def _convert_generation_params(self, **kwargs) -> dict:
        """Convert generic parameters to provider-specific ones.

        Uses the _PROVIDER_PARAM_MAPPINGS static dictionary for extensible
        parameter conversion across different providers. Providers without an
        entry get their parameters back unchanged. Applying it twice is harmless,
        because renamed keys are no longer present under their generic names.
        """
        mappings = _PROVIDER_PARAM_MAPPINGS.get(self.provider, {})
        for generic_name, provider_name in mappings.items():
            if generic_name in kwargs:
                kwargs[provider_name] = kwargs.pop(generic_name)

        return kwargs

    def __repr__(self) -> str:
        return f"LLMClient(provider={self.provider}, model={self.model_name})"
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class GeminiModel:
    """Async structured-output client backed by Gemini's native JSON schema support.

    Requests pass ``response_json_schema`` directly to the Gemini API instead of
    going through Outlines' wrapper, which is more reliable for structured output.
    """

    def __init__(self, client: genai.Client, model_name: str):
        self.client = client
        self.model_name = model_name

    def _prepare_json_schema(self, schema: type[BaseModel]) -> dict:
        """Build the Gemini-ready JSON schema for ``schema``.

        Only ``additionalProperties`` and other incompatible fields are removed.
        ``$ref`` entries are deliberately left un-inlined: inlining deeply nested
        models can push the schema past Gemini's nesting limits.
        """
        return _strip_additional_properties(schema.model_json_schema())

    async def __call__(
        self,
        prompt: str,
        schema: type[BaseModel],
        **kwargs,
    ) -> str:
        """Generate a JSON string conforming to ``schema`` via Gemini.

        Args:
            prompt: The input prompt
            schema: Pydantic model class for structured output
            **kwargs: Extra ``GenerateContentConfig`` fields (temperature,
                max_output_tokens, etc.)

        Returns:
            JSON string matching the schema

        Raises:
            DataSetGeneratorError: If the config cannot be built, or Gemini
                returns no candidates, no content parts, or no text.
        """
        gemini_schema = self._prepare_json_schema(schema)

        try:
            generation_config = genai_types.GenerateContentConfig(
                response_mime_type="application/json",
                response_json_schema=gemini_schema,
                **kwargs,
            )
        except Exception as exc:
            raise DataSetGeneratorError(f"Failed to create Gemini config: {exc}") from exc

        # Async variant of the Gemini content API
        reply = await self.client.aio.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=generation_config,
        )

        # Guard against empty or blocked responses before touching .text
        if not reply.candidates:
            raise DataSetGeneratorError(
                "Gemini returned empty response",
                context={"finish_reason": None},
            )

        candidate = reply.candidates[0]
        has_parts = candidate.content is not None and bool(candidate.content.parts)
        if not has_parts:
            raise DataSetGeneratorError(
                "Gemini returned empty response",
                context={"finish_reason": getattr(candidate, "finish_reason", None)},
            )

        if reply.text is None:
            raise DataSetGeneratorError("Gemini returned empty response")

        return reply.text
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class AnthropicModel:
    """Anthropic API client using native structured outputs.

    Uses Anthropic's beta structured outputs feature (structured-outputs-2025-11-13)
    which provides guaranteed JSON schema compliance through constrained decoding.
    """

    # Beta header for structured outputs feature
    STRUCTURED_OUTPUTS_BETA = "structured-outputs-2025-11-13"

    def __init__(self, client: anthropic.AsyncAnthropic, model_name: str):
        self.client = client
        self.model_name = model_name

    def _prepare_json_schema(self, schema: type[BaseModel]) -> dict:
        """Prepare JSON schema for Anthropic structured outputs.

        Anthropic's structured outputs require:
        - additionalProperties: false on all objects
        - All properties in required array
        """
        json_schema = schema.model_json_schema()
        return _ensure_anthropic_strict_mode_compliance(json_schema)

    async def __call__(
        self,
        prompt: str,
        schema: type[BaseModel],
        **kwargs,
    ) -> str:
        """Asynchronously generate structured output using Anthropic's native JSON schema support.

        Args:
            prompt: The input prompt
            schema: Pydantic model class for structured output
            **kwargs: Additional generation parameters (temperature, max_tokens, etc.).
                ``max_tokens`` defaults to 16384 when omitted.

        Returns:
            JSON string matching the schema (text of the first non-empty text block)

        Raises:
            DataSetGeneratorError: If the request is rejected with BadRequestError,
                the model refuses, output is truncated at max_tokens, or the
                response contains no block with non-empty text.
        """
        prepared_schema = self._prepare_json_schema(schema)

        # Extract max_tokens with a sensible default (16384 for structured outputs)
        max_tokens = kwargs.pop("max_tokens", 16384)

        try:
            response = await self.client.beta.messages.create(
                model=self.model_name,
                max_tokens=max_tokens,
                betas=[self.STRUCTURED_OUTPUTS_BETA],
                messages=[{"role": "user", "content": prompt}],
                output_format={
                    "type": "json_schema",
                    "schema": prepared_schema,
                },
                **kwargs,
            )
        except anthropic.BadRequestError as e:
            raise DataSetGeneratorError(
                f"Anthropic structured output request failed: {e}",
                context={"schema": prepared_schema},
            ) from e

        # Check for refusals
        if response.stop_reason == "refusal":
            raise DataSetGeneratorError(
                "Anthropic refused the request for safety reasons",
                context={"stop_reason": response.stop_reason},
            )

        # Check for max_tokens truncation
        if response.stop_reason == "max_tokens":
            raise DataSetGeneratorError(
                "Anthropic response truncated due to max_tokens limit",
                context={"stop_reason": response.stop_reason},
            )

        if not response.content:
            raise DataSetGeneratorError("Anthropic returned empty response")

        # Take the first block carrying non-empty text rather than blindly the
        # first block: non-text blocks (e.g. thinking blocks, which precede the
        # answer when extended thinking is requested through kwargs) have no
        # ``text`` and must not be mistaken for an empty answer.
        text = next(
            (block.text for block in response.content if getattr(block, "text", None)),
            None,
        )
        if text is None:
            raise DataSetGeneratorError("Anthropic returned empty text response")

        return text
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _raise_api_key_error(env_var: str) -> None:
    """Raise ``DataSetGeneratorError`` reporting that *env_var* is not set.

    Never returns normally.
    """
    raise DataSetGeneratorError(f"{env_var} environment variable not set")
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
def _get_gemini_api_key() -> str:
|
|
532
|
+
"""Retrieve Gemini API key from environment variables.
|
|
533
|
+
|
|
534
|
+
Returns:
|
|
535
|
+
The API key string
|
|
536
|
+
|
|
537
|
+
Raises:
|
|
538
|
+
DataSetGeneratorError: If no API key is found
|
|
539
|
+
"""
|
|
540
|
+
for name in ("GOOGLE_API_KEY", "GEMINI_API_KEY"):
|
|
541
|
+
if api_key := os.getenv(name):
|
|
542
|
+
return api_key
|
|
543
|
+
_raise_api_key_error("GOOGLE_API_KEY or GEMINI_API_KEY")
|
|
544
|
+
# This return is never reached but satisfies type checker
|
|
545
|
+
raise AssertionError("unreachable")
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
# Environment variables that may hold each provider's API key, in lookup
# order. An empty list means the provider needs no API key.
PROVIDER_API_KEY_MAP: dict[str, list[str]] = {
    "openai": ["OPENAI_API_KEY"],
    "anthropic": ["ANTHROPIC_API_KEY"],
    # Either variable is accepted for Gemini
    "gemini": ["GOOGLE_API_KEY", "GEMINI_API_KEY"],
    "openrouter": ["OPENROUTER_API_KEY"],
    "ollama": [],
    # Providers used only by unit tests
    "test": [],
    "override": [],
}
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def validate_provider_api_key(provider: str) -> tuple[bool, str | None]:
    """Check whether the environment supplies an API key for *provider*.

    Args:
        provider: Provider name (openai, anthropic, gemini, ollama, openrouter)

    Returns:
        ``(True, None)`` when the provider needs no key or any of its key
        variables is set to a non-empty value; otherwise ``(False, message)``
        describing the unknown provider or the missing variable(s).
    """
    candidates = PROVIDER_API_KEY_MAP.get(provider)
    if candidates is None:
        return False, f"Unknown provider: {provider}"

    # An empty list (e.g. ollama) means no credentials are needed at all
    if not candidates or any(os.getenv(name) for name in candidates):
        return True, None

    if len(candidates) == 1:
        return False, f"Missing API key: {candidates[0]} environment variable is not set"
    joined = " or ".join(candidates)
    return False, f"Missing API key: Set {joined} environment variable"
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def get_required_api_key_env_var(provider: str) -> str | None:
    """Describe which environment variable(s) *provider* reads its key from.

    Args:
        provider: Provider name

    Returns:
        The variable name, or several names joined with ``" or "``; None when
        the provider is unknown or needs no key.
    """
    env_vars = PROVIDER_API_KEY_MAP.get(provider)
    # Joining a single-element list yields that element, so one path covers both cases
    return " or ".join(env_vars) if env_vars else None
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def _raise_unsupported_provider_error(provider: str) -> None:
    """Raise ``DataSetGeneratorError`` naming the unsupported *provider*.

    Never returns normally.
    """
    raise DataSetGeneratorError(f"Unsupported provider: {provider}")
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def _get_openai_client_config(
|
|
615
|
+
api_key_env_var: str | None,
|
|
616
|
+
default_base_url: str | None,
|
|
617
|
+
dummy_key: str | None = None,
|
|
618
|
+
**kwargs,
|
|
619
|
+
) -> tuple[str | None, dict[str, Any]]:
|
|
620
|
+
"""Extract common configuration for OpenAI-compatible clients.
|
|
621
|
+
|
|
622
|
+
Args:
|
|
623
|
+
api_key_env_var: Environment variable name for API key (None to skip check)
|
|
624
|
+
default_base_url: Default base URL if not provided in kwargs (None for OpenAI default)
|
|
625
|
+
dummy_key: Dummy API key to use if api_key_env_var is None (e.g., for Ollama)
|
|
626
|
+
**kwargs: Additional client configuration (may include base_url override)
|
|
627
|
+
|
|
628
|
+
Returns:
|
|
629
|
+
Tuple of (api_key, client_kwargs) for client initialization
|
|
630
|
+
|
|
631
|
+
Raises:
|
|
632
|
+
DataSetGeneratorError: If required API key is missing
|
|
633
|
+
"""
|
|
634
|
+
# Get API key from environment or use dummy key
|
|
635
|
+
if api_key_env_var:
|
|
636
|
+
api_key = os.getenv(api_key_env_var)
|
|
637
|
+
if not api_key:
|
|
638
|
+
_raise_api_key_error(api_key_env_var)
|
|
639
|
+
else:
|
|
640
|
+
api_key = dummy_key
|
|
641
|
+
|
|
642
|
+
# Set up base_url if provided
|
|
643
|
+
client_kwargs = {k: v for k, v in kwargs.items() if k != "base_url"}
|
|
644
|
+
if default_base_url:
|
|
645
|
+
client_kwargs["base_url"] = kwargs.get("base_url", default_base_url)
|
|
646
|
+
|
|
647
|
+
return api_key, client_kwargs
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def _create_openai_compatible_client(
    api_key_env_var: str | None,
    default_base_url: str | None,
    dummy_key: str | None = None,
    **kwargs,
) -> openai.OpenAI:
    """Build a synchronous ``openai.OpenAI`` client for an OpenAI-format provider.

    API key lookup and base URL handling are delegated to
    ``_get_openai_client_config``.

    Args:
        api_key_env_var: Environment variable name for API key (None to skip check)
        default_base_url: Default base URL if not provided in kwargs (None for OpenAI default)
        dummy_key: Dummy API key to use if api_key_env_var is None (e.g., for Ollama)
        **kwargs: Additional client configuration (may include base_url override)

    Returns:
        Configured OpenAI client instance

    Raises:
        DataSetGeneratorError: If required API key is missing
    """
    resolved_key, options = _get_openai_client_config(
        api_key_env_var, default_base_url, dummy_key, **kwargs
    )
    client = openai.OpenAI(api_key=resolved_key, **options)
    return client
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def _create_async_openai_compatible_client(
    api_key_env_var: str | None,
    default_base_url: str | None,
    dummy_key: str | None = None,
    **kwargs,
) -> openai.AsyncOpenAI:
    """Instantiate an ``openai.AsyncOpenAI`` client for an OpenAI-format provider.

    Async counterpart of ``_create_openai_compatible_client``. The API key and
    client options come from ``_get_openai_client_config``.

    Args:
        api_key_env_var: Name of the environment variable holding the API key,
            or None when the provider needs no real key.
        default_base_url: Provider base URL. When truthy, a ``base_url`` present
            in ``kwargs`` takes precedence over it. NOTE: when None,
            ``_get_openai_client_config`` drops any ``base_url`` from ``kwargs``,
            so the SDK's default endpoint is used.
        dummy_key: Placeholder key used when ``api_key_env_var`` is None
            (e.g. "ollama").
        **kwargs: Extra client options, possibly including ``base_url``.

    Returns:
        A configured ``openai.AsyncOpenAI`` instance.

    Raises:
        DataSetGeneratorError: If the required API key is missing.
    """
    resolved_key, client_options = _get_openai_client_config(
        api_key_env_var, default_base_url, dummy_key, **kwargs
    )
    async_client = openai.AsyncOpenAI(api_key=resolved_key, **client_options)
    return async_client
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _is_incompatible_object(schema: dict) -> bool:
|
|
703
|
+
"""Check if a schema represents an object incompatible with Gemini.
|
|
704
|
+
|
|
705
|
+
Gemini rejects objects with no properties defined (like dict[str, Any]).
|
|
706
|
+
|
|
707
|
+
Args:
|
|
708
|
+
schema: JSON schema dictionary
|
|
709
|
+
|
|
710
|
+
Returns:
|
|
711
|
+
True if the schema is an incompatible object type
|
|
712
|
+
"""
|
|
713
|
+
return schema.get("type") == "object" and "properties" not in schema and "$ref" not in schema
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
def _is_incompatible_array(schema: dict) -> bool:
    """Tell whether ``schema`` is an array whose item schema Gemini will reject.

    Arrays whose items are property-less objects (e.g. ``list[dict[str, Any]]``)
    are incompatible with Gemini. Non-array schemas, arrays without ``items``,
    and arrays whose ``items`` is not a dict are all reported as compatible.

    Args:
        schema: JSON schema dictionary.

    Returns:
        True if the array's ``items`` is an object as flagged by
        ``_is_incompatible_object``.
    """
    item_schema = schema.get("items") if schema.get("type") == "array" else None
    if not isinstance(item_schema, dict):
        return False
    return _is_incompatible_object(item_schema)
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
def _inline_refs(schema_dict: dict, defs: dict | None = None) -> dict:
    """
    Recursively inline $ref references in a JSON schema.

    Gemini's structured output can be unreliable with $ref. This function
    resolves references of the form ``#/$defs/<Name>`` by inlining the
    definitions directly, and drops ``$defs`` from the result.

    Recursive models (a definition that, directly or indirectly, refers to
    itself) cannot be fully inlined. The ``$ref`` is left in place at the point
    of recursion instead of recursing forever. When called on a root schema
    (``defs`` is None), ``$defs`` is kept in the result whenever any ``$ref``
    survives, so the remaining references still resolve.

    Args:
        schema_dict: JSON schema dictionary. Non-dict values are returned as-is.
        defs: The $defs dictionary from the root schema. When omitted, it is
            read from ``schema_dict["$defs"]``.

    Returns:
        A new schema dict with resolvable $ref entries inlined. The input is not
        modified.
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    if defs is not None:
        # Called on a sub-schema with explicit defs: the caller owns $defs.
        return _inline_refs_walk(schema_dict, defs, frozenset())

    root_defs: dict = schema_dict.get("$defs", {})
    result = _inline_refs_walk(schema_dict, root_defs, frozenset())

    # A $ref left behind (recursive model or unresolvable ref) would dangle
    # once $defs is dropped, so re-attach the original definitions.
    if root_defs and _contains_ref(result):
        result = {**result, "$defs": root_defs}
    return result


def _inline_refs_walk(node: Any, defs: dict, expanding: frozenset[str]) -> Any:
    """Inline ``#/$defs/`` refs in ``node``, skipping defs already being expanded on this path."""
    if not isinstance(node, dict):
        return node

    if "$ref" in node:
        ref_path = node["$ref"]
        prefix = "#/$defs/"
        if isinstance(ref_path, str) and ref_path.startswith(prefix):
            def_name = ref_path[len(prefix) :]
            # A def already on the expansion path means the model is recursive;
            # expanding it again would never terminate, so keep the $ref there.
            if def_name in defs and def_name not in expanding:
                inlined = _inline_refs_walk(dict(defs[def_name]), defs, expanding | {def_name})
                # Preserve other keywords from the referencing node (like description)
                for key, value in node.items():
                    if key not in ("$ref", "$defs"):
                        inlined[key] = value
                return inlined
        # If we can't (or must not) resolve, return as-is
        return node

    result: dict[str, Any] = {}
    for key, value in node.items():
        if key == "$defs":
            # Skip $defs since we're inlining them
            continue
        if key == "properties" and isinstance(value, dict):
            result[key] = {
                prop_name: _inline_refs_walk(prop_schema, defs, expanding)
                for prop_name, prop_schema in value.items()
            }
        elif key in ("items", "additionalProperties") and isinstance(value, dict):
            # additionalProperties can hold a $ref (e.g. dict[str, Model]) that
            # would otherwise dangle after $defs is dropped.
            result[key] = _inline_refs_walk(value, defs, expanding)
        elif key in _UNION_KEYS and isinstance(value, list):
            result[key] = [_inline_refs_walk(variant, defs, expanding) for variant in value]
        else:
            result[key] = value
    return result


def _contains_ref(node: Any) -> bool:
    """Return True if a ``$ref`` key appears anywhere inside ``node``."""
    if isinstance(node, dict):
        return "$ref" in node or any(_contains_ref(value) for value in node.values())
    if isinstance(node, list):
        return any(_contains_ref(item) for item in node)
    return False
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def _strip_additional_properties(schema_dict: dict) -> dict:
    """
    Prune a JSON schema, in place, down to constructs Gemini accepts.

    At every level visited:

    * A property is deleted (and its name dropped from ``required``) when its
      schema has ``additionalProperties: true`` (e.g. ``dict[str, Any]``), is an
      object with no properties, is an array of such objects
      (e.g. ``list[dict[str, Any]]``), or has an ``anyOf`` containing either of
      those last two shapes.
    * The ``additionalProperties`` keyword itself is removed.

    The walk descends into ``$defs``, ``properties``, ``items`` and the union
    keywords in ``_UNION_KEYS``.

    Note: This function preserves $defs and does NOT inline refs. Use _inline_refs
    separately if ref inlining is needed.

    Args:
        schema_dict: JSON schema dictionary. Non-dict values are returned as-is.

    Returns:
        The same ``schema_dict`` object, after pruning.
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    def _rejected_by_gemini(prop: dict) -> bool:
        # Open dicts, property-less objects, and arrays of property-less objects
        # are rejected directly; an anyOf is rejected if any variant is one of
        # the latter two shapes.
        if prop.get("additionalProperties") is True:
            return True
        if _is_incompatible_object(prop) or _is_incompatible_array(prop):
            return True
        if "anyOf" not in prop:
            return False
        return any(
            isinstance(option, dict)
            and (_is_incompatible_object(option) or _is_incompatible_array(option))
            for option in prop["anyOf"]
        )

    if "properties" in schema_dict:
        props = schema_dict["properties"]
        dropped = [
            name
            for name, prop in props.items()
            if isinstance(prop, dict) and _rejected_by_gemini(prop)
        ]
        for name in dropped:
            del props[name]
        # Removed properties can no longer be required.
        if "required" in schema_dict:
            schema_dict["required"] = [
                field for field in schema_dict["required"] if field not in dropped
            ]

    # Gemini does not accept the keyword at all, so strip it at this level.
    schema_dict.pop("additionalProperties", None)

    if "$defs" in schema_dict:
        definitions = schema_dict["$defs"]
        for def_name, def_schema in definitions.items():
            definitions[def_name] = _strip_additional_properties(def_schema)

    if "properties" in schema_dict:
        props = schema_dict["properties"]
        for name, prop in props.items():
            props[name] = _strip_additional_properties(prop)

    if "items" in schema_dict:
        schema_dict["items"] = _strip_additional_properties(schema_dict["items"])

    for union_key in _UNION_KEYS:
        if union_key in schema_dict:
            schema_dict[union_key] = list(
                map(_strip_additional_properties, schema_dict[union_key])
            )

    return schema_dict
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
def _ensure_anthropic_strict_mode_compliance(schema_dict: dict) -> dict:
    """Adapt a JSON schema to Anthropic's structured-outputs requirements.

    Anthropic expects, like OpenAI strict mode:
    1. ``additionalProperties: false`` explicitly set on objects,
    2. every property listed in ``required`` (no optional fields),
    3. no fields with ``additionalProperties: true``.

    Because the rules match, this simply applies
    ``_ensure_openai_strict_mode_compliance``.

    Args:
        schema_dict: JSON schema dictionary

    Returns:
        Modified schema dict meeting Anthropic structured outputs requirements
    """
    compliant_schema = _ensure_openai_strict_mode_compliance(schema_dict)
    return compliant_schema
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def _strip_ref_sibling_keywords(schema_dict: dict) -> dict:
    """
    Reduce every ``$ref`` node in a schema to a bare ``{"$ref": ...}``.

    OpenAI's strict mode rejects keywords placed next to ``$ref``. Pydantic emits
    such siblings (e.g. ``description``) for nested models declared with
    ``Field(description=...)``.

    Nodes without ``$ref`` are edited in place while descending through
    ``properties``, ``$defs``, ``items`` and the keys in ``_UNION_KEYS``. A node
    that has ``$ref`` is replaced by a new single-key dict.

    Args:
        schema_dict: JSON schema dictionary. Non-dict values are returned as-is.

    Returns:
        The schema with $ref siblings removed (a new dict if the root itself was
        a $ref node, otherwise the same object).
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    if "$ref" in schema_dict:
        return {"$ref": schema_dict["$ref"]}

    # Both mappings hold name -> sub-schema entries; properties go first.
    for mapping_key in ("properties", "$defs"):
        if mapping_key in schema_dict:
            children = schema_dict[mapping_key]
            for child_key, child_schema in children.items():
                children[child_key] = _strip_ref_sibling_keywords(child_schema)

    if "items" in schema_dict:
        schema_dict["items"] = _strip_ref_sibling_keywords(schema_dict["items"])

    for union_key in (key for key in _UNION_KEYS if key in schema_dict):
        schema_dict[union_key] = [
            _strip_ref_sibling_keywords(option) for option in schema_dict[union_key]
        ]

    return schema_dict
|
|
944
|
+
|
|
945
|
+
|
|
946
|
+
def _ensure_openai_strict_mode_compliance(schema_dict: dict) -> dict:
    """
    Ensure schema complies with OpenAI's strict mode requirements.

    OpenAI's strict mode requires:
    1. For objects, 'additionalProperties' must be explicitly set to false
    2. ALL properties must be in the 'required' array (no optional fields allowed)
    3. No fields with additionalProperties: true (incompatible with strict mode)
    4. $ref cannot have sibling keywords like 'description'

    Fields like dict[str, Any] have additionalProperties: true and cannot be
    represented in strict mode, so they are excluded from the schema entirely.

    $ref siblings are stripped once for the whole tree up front. The recursive
    pass that follows only enforces the object rules, instead of re-stripping
    every subtree again at each level.

    The input is modified in place where possible. Always use the return value,
    because a root-level ``$ref`` node is replaced by a new dict.

    Args:
        schema_dict: JSON schema dictionary. Non-dict values are returned as-is.

    Returns:
        Modified schema dict meeting OpenAI strict mode requirements
    """
    if not isinstance(schema_dict, dict):
        return schema_dict
    return _apply_openai_strict_object_rules(_strip_ref_sibling_keywords(schema_dict))


def _is_open_dict_property(prop_schema: Any) -> bool:
    """Return True if a property schema, or any of its anyOf variants, has additionalProperties: true."""
    if not isinstance(prop_schema, dict):
        return False
    if prop_schema.get("additionalProperties") is True:
        return True
    if "anyOf" not in prop_schema:
        return False
    return any(
        isinstance(variant, dict) and variant.get("additionalProperties") is True
        for variant in prop_schema["anyOf"]
    )


def _apply_openai_strict_object_rules(schema_dict: Any) -> Any:
    """Recursively apply strict-mode object rules to a schema whose $ref siblings are already stripped."""
    if not isinstance(schema_dict, dict):
        return schema_dict

    if "properties" in schema_dict:
        properties = schema_dict["properties"]
        # dict[str, Any]-style fields cannot be represented in strict mode.
        incompatible = [
            name for name, prop_schema in properties.items() if _is_open_dict_property(prop_schema)
        ]
        for name in incompatible:
            del properties[name]
        # Strict mode has no optional fields: every remaining property is required.
        # This replaces any existing 'required' list, so removed names disappear too.
        schema_dict["required"] = list(properties.keys())
        schema_dict["additionalProperties"] = False

    # For all objects (including those without properties), set additionalProperties to false
    if schema_dict.get("type") == "object":
        schema_dict["additionalProperties"] = False

    if "$defs" in schema_dict:
        defs = schema_dict["$defs"]
        for def_name, def_schema in defs.items():
            defs[def_name] = _apply_openai_strict_object_rules(def_schema)

    if "properties" in schema_dict:
        properties = schema_dict["properties"]
        for prop_name, prop_schema in properties.items():
            properties[prop_name] = _apply_openai_strict_object_rules(prop_schema)

    if "items" in schema_dict:
        schema_dict["items"] = _apply_openai_strict_object_rules(schema_dict["items"])

    # Process union types (anyOf, oneOf, allOf)
    # This must be done to handle nested structures like list[dict[str, Any]] | None
    # where the dict[str, Any] is inside an array variant
    for union_key in _UNION_KEYS:
        if union_key in schema_dict:
            schema_dict[union_key] = [
                _apply_openai_strict_object_rules(variant) for variant in schema_dict[union_key]
            ]

    return schema_dict
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
@lru_cache(maxsize=128)
def _get_cached_openai_schema(schema: type[BaseModel]) -> type[BaseModel]:
    """
    Get or create a cached OpenAI-compatible version of a Pydantic schema.

    The returned class subclasses ``schema`` and overrides ``model_json_schema``
    so its output is passed through ``_ensure_openai_strict_mode_compliance``.
    Because it is a subclass, instances it parses are still instances of
    ``schema``. ``lru_cache`` ensures each model is wrapped only once
    (up to 128 distinct models).

    Args:
        schema: Original Pydantic model

    Returns:
        Cached wrapper model that generates OpenAI-compatible schemas
    """

    # Create a new model class that overrides model_json_schema
    class OpenAICompatModel(schema):  # type: ignore[misc,valid-type]
        @classmethod
        def model_json_schema(cls, *args, **kwargs):
            # Forward positional arguments as well as keywords so the override
            # does not narrow BaseModel.model_json_schema's signature (callers
            # may pass e.g. by_alias positionally).
            original_schema = super().model_json_schema(*args, **kwargs)
            # Ensure OpenAI strict mode compliance
            return _ensure_openai_strict_mode_compliance(original_schema)

    # Set name and docstring
    OpenAICompatModel.__name__ = f"{schema.__name__}OpenAICompat"
    OpenAICompatModel.__doc__ = schema.__doc__

    # Rebuild model to resolve forward references (e.g., PendingToolCall in AgentStep)
    OpenAICompatModel.model_rebuild()

    return OpenAICompatModel
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
def _create_openai_compatible_schema(schema: type[BaseModel]) -> type[BaseModel]:
    """
    Return a wrapper model whose JSON schema satisfies OpenAI's strict mode.

    The wrapper comes from ``_get_cached_openai_schema``. It is a subclass of
    ``schema`` whose ``model_json_schema`` output goes through
    ``_ensure_openai_strict_mode_compliance``, so repeated calls with the same
    model reuse one cached class.

    Note that strict-mode compliance lists every remaining property in
    ``required`` and drops ``dict[str, Any]``-style fields, so the generated
    schema does not keep Pydantic's optional/required distinction.

    Args:
        schema: Original Pydantic model

    Returns:
        Wrapper model that generates OpenAI-compatible schemas
    """
    compat_model = _get_cached_openai_schema(schema)
    return compat_model
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def make_outlines_model(provider: str, model_name: str, **kwargs) -> Any:
    """Create a structured-generation model for ``provider`` and ``model_name``.

    * "openai", "ollama" and "openrouter" use a synchronous OpenAI-compatible
      client wrapped with ``outlines.from_openai``.
    * "anthropic" uses ``AnthropicModel`` around an ``anthropic.AsyncAnthropic``
      client (the native structured outputs API).
    * "gemini" uses ``GeminiModel`` around a ``genai.Client``, bypassing Outlines.

    Args:
        provider: Provider name (openai, ollama, openrouter, anthropic, gemini)
        model_name: Model identifier
        **kwargs: Additional client parameters. Forwarded to the
            OpenAI-compatible and Anthropic clients; not used for Gemini.

    Returns:
        Provider-specific model instance

    Raises:
        DataSetGeneratorError: Re-raised unchanged (e.g. missing API key or an
            unsupported provider). Any other exception is converted by
            ``handle_provider_error``.
    """
    # Per-provider arguments for _create_openai_compatible_client.
    openai_compatible_settings = {
        "openai": {
            "api_key_env_var": "OPENAI_API_KEY",
            "default_base_url": None,  # Use OpenAI's default
        },
        "ollama": {
            "api_key_env_var": None,  # No API key required
            "default_base_url": "http://localhost:11434/v1",
            "dummy_key": "ollama",
        },
        "openrouter": {
            "api_key_env_var": "OPENROUTER_API_KEY",
            "default_base_url": "https://openrouter.ai/api/v1",
        },
    }

    try:
        if provider in openai_compatible_settings:
            compat_client = _create_openai_compatible_client(
                **openai_compatible_settings[provider], **kwargs
            )
            return outlines.from_openai(compat_client, model_name)

        if provider == "anthropic":
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if not anthropic_key:
                _raise_api_key_error("ANTHROPIC_API_KEY")
            anthropic_client = anthropic.AsyncAnthropic(api_key=anthropic_key, **kwargs)
            return AnthropicModel(anthropic_client, model_name)

        if provider == "gemini":
            gemini_client = genai.Client(api_key=_get_gemini_api_key())
            return GeminiModel(gemini_client, model_name)

        _raise_unsupported_provider_error(provider)

    except DataSetGeneratorError:
        # Our own errors (missing keys, unsupported provider) pass through untouched.
        raise
    except Exception as e:
        raise handle_provider_error(e, provider, model_name) from e
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def make_async_outlines_model(provider: str, model_name: str, **kwargs) -> Any | None:
    """Create an async-capable structured-generation model, if the provider has one.

    * "openai", "ollama" and "openrouter" use an ``openai.AsyncOpenAI``-based
      client wrapped with ``outlines.from_openai``.
    * "anthropic" uses ``AnthropicModel`` around ``anthropic.AsyncAnthropic``.
    * "gemini" uses ``GeminiModel`` around a ``genai.Client``.

    Any other provider yields ``None``, so the caller can fall back to
    synchronous execution.

    Raises:
        DataSetGeneratorError: Re-raised unchanged (e.g. missing API key). Any
            other exception is converted by ``handle_provider_error``.
    """
    # Per-provider arguments for _create_async_openai_compatible_client.
    openai_compatible_settings = {
        "openai": {
            "api_key_env_var": "OPENAI_API_KEY",
            "default_base_url": None,  # Use OpenAI's default
        },
        "ollama": {
            "api_key_env_var": None,  # No API key required
            "default_base_url": "http://localhost:11434/v1",
            "dummy_key": "ollama",
        },
        "openrouter": {
            "api_key_env_var": "OPENROUTER_API_KEY",
            "default_base_url": "https://openrouter.ai/api/v1",
        },
    }

    try:
        if provider in openai_compatible_settings:
            compat_client = _create_async_openai_compatible_client(
                **openai_compatible_settings[provider], **kwargs
            )
            return outlines.from_openai(compat_client, model_name)

        if provider == "anthropic":
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if not anthropic_key:
                _raise_api_key_error("ANTHROPIC_API_KEY")
            anthropic_client = anthropic.AsyncAnthropic(api_key=anthropic_key, **kwargs)
            return AnthropicModel(anthropic_client, model_name)

        if provider == "gemini":
            gemini_client = genai.Client(api_key=_get_gemini_api_key())
            return GeminiModel(gemini_client, model_name)

    except DataSetGeneratorError:
        raise
    except Exception as e:
        raise handle_provider_error(e, provider, model_name) from e

    # No async wrapper for the remaining providers; the caller falls back to
    # synchronous execution.
    return None
|