DeepFabric 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71):
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1206 @@
1
+ import asyncio
2
+ import os
3
+
4
+ from functools import lru_cache
5
+ from typing import Any
6
+
7
+ import anthropic
8
+ import openai
9
+ import outlines
10
+
11
+ from google import genai
12
+ from google.genai import types as genai_types
13
+ from pydantic import BaseModel
14
+
15
+ from ..exceptions import DataSetGeneratorError
16
+ from .errors import handle_provider_error
17
+ from .rate_limit_config import (
18
+ RateLimitConfig,
19
+ create_rate_limit_config,
20
+ get_default_rate_limit_config,
21
+ )
22
+ from .retry_handler import RetryHandler, retry_with_backoff, retry_with_backoff_async
23
+
24
# JSON Schema union type keys that need recursive processing
# (walked by the schema-transformation helpers further down this module).
_UNION_KEYS: tuple[str, ...] = ("anyOf", "oneOf", "allOf")

# Provider-specific parameter mappings
# Maps generic parameter names to provider-specific equivalents
# Format: {provider: {generic_name: provider_specific_name}}
# Consumed by LLMClient._convert_generation_params before each request.
_PROVIDER_PARAM_MAPPINGS: dict[str, dict[str, str]] = {
    "gemini": {
        # Gemini's SDK names the output-token cap differently.
        "max_tokens": "max_output_tokens",
    },
    # Add other provider mappings as needed
    # "anthropic": {"some_param": "anthropic_param"},
}
37
+
38
+
39
class LLMClient:
    """Wrapper for Outlines models with retry logic and error handling.

    Synchronous calls go through the Outlines model created at construction
    time; asynchronous calls use a lazily created async model (native client
    classes for Anthropic/Gemini). Retries and rate limiting are driven by
    the RateLimitConfig resolved in ``__init__`` via the retry_with_backoff
    decorators.
    """

    # Providers that only support async generation (native API, not Outlines)
    _ASYNC_ONLY_PROVIDERS = frozenset({"anthropic", "gemini"})

    def __init__(
        self,
        provider: str,
        model_name: str,
        rate_limit_config: RateLimitConfig | dict | None = None,
        **kwargs,
    ):
        """Initialize LLM client.

        Args:
            provider: Provider name
            model_name: Model identifier
            rate_limit_config: Rate limiting configuration (None uses provider defaults)
            **kwargs: Additional client configuration

        Raises:
            DataSetGeneratorError: If the synchronous model cannot be created.
        """
        self.provider = provider
        self.model_name = model_name
        self._client_kwargs = kwargs  # Stored for lazy async model initialization

        # Normalize the three accepted rate-limit configuration forms.
        if isinstance(rate_limit_config, dict):
            self.rate_limit_config = create_rate_limit_config(provider, rate_limit_config)
        elif rate_limit_config is None:
            self.rate_limit_config = get_default_rate_limit_config(provider)
        else:
            self.rate_limit_config = rate_limit_config

        self.retry_handler = RetryHandler(self.rate_limit_config, provider)

        self.model: Any = make_outlines_model(provider, model_name, **kwargs)
        # Lazy-initialize async_model only when needed
        self._async_model: Any | None = None
        self._async_model_initialized: bool = False
        if self.model is None:
            msg = f"Failed to create model for {provider}/{model_name}"
            raise DataSetGeneratorError(msg)

    @property
    def async_model(self) -> Any | None:
        """Lazily initialize and return the async model.

        The async model is only created when first accessed, reducing memory
        and connection overhead when only sync operations are used.
        """
        if not self._async_model_initialized:
            self._async_model = make_async_outlines_model(
                self.provider, self.model_name, **self._client_kwargs
            )
            self._async_model_initialized = True
        return self._async_model

    def generate(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs) -> Any:  # noqa: ARG002
        """Generate structured output using the provided schema.

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Returns:
            Generated output matching the schema

        Raises:
            DataSetGeneratorError: If generation fails after all retries

        Note:
            The max_retries parameter is deprecated. Use rate_limit_config in __init__ instead.
            If provided, it will be ignored in favor of the configured retry handler.
        """
        return self._generate_with_retry(prompt, schema, **kwargs)

    @retry_with_backoff
    def _generate_with_retry(self, prompt: str, schema: Any, **kwargs) -> Any:
        """Internal method that performs actual generation with retry wrapper.

        Decorated with retry_with_backoff to handle rate limits and transient
        errors automatically.
        """
        # Anthropic/Gemini are implemented as native async clients only.
        if self.provider in self._ASYNC_ONLY_PROVIDERS:
            raise DataSetGeneratorError(
                f"Synchronous generation is not supported for {self.provider}. "
                f"Use generate_async() instead.",
                context={"provider": self.provider},
            )

        # Convert provider-specific parameters (e.g. max_tokens naming).
        kwargs = self._convert_generation_params(**kwargs)

        # For OpenAI and OpenRouter, ensure all properties are in required array
        generation_schema = schema
        if (
            self.provider in ("openai", "openrouter")
            and isinstance(schema, type)
            and issubclass(schema, BaseModel)
        ):
            generation_schema = _create_openai_compatible_schema(schema)

        # Generate JSON string using the model
        json_output = self.model(prompt, generation_schema, **kwargs)

        # Parse and validate the JSON response with the ORIGINAL schema
        # so callers always receive a fully validated model instance.
        try:
            return schema.model_validate_json(json_output)
        except Exception as e:
            raise DataSetGeneratorError(
                f"Generation validation failed: {e}",
                context={"raw_content": json_output},
            ) from e

    async def generate_async(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs) -> Any:  # noqa: ARG002
        """Asynchronously generate structured output using provider async clients.

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Returns:
            Generated output matching the schema

        Raises:
            DataSetGeneratorError: If generation fails after all retries

        Note:
            The max_retries parameter is deprecated. Use rate_limit_config in __init__ instead.
            If provided, it will be ignored in favor of the configured retry handler.
        """
        if self.async_model is None:
            # Fallback: run the synchronous path in a worker thread.
            return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

        return await self._generate_async_with_retry(prompt, schema, **kwargs)

    @retry_with_backoff_async
    async def _generate_async_with_retry(self, prompt: str, schema: Any, **kwargs) -> Any:
        """Internal async method that performs actual generation with retry wrapper.

        Decorated with retry_with_backoff_async to handle rate limits and
        transient errors automatically.
        """
        kwargs = self._convert_generation_params(**kwargs)

        # Native providers (Anthropic, Gemini) handle schema transformation
        # internally; OpenAI-compatible strict mode needs every property in
        # the required array.
        generation_schema = schema
        if self.provider in self._ASYNC_ONLY_PROVIDERS:
            pass  # native model classes transform the schema themselves
        elif (
            self.provider in ("openai", "openrouter")
            and isinstance(schema, type)
            and issubclass(schema, BaseModel)
        ):
            generation_schema = _create_openai_compatible_schema(schema)

        # Ensure we have an async model; if not, fall back to the sync path.
        async_model = self.async_model
        if async_model is None:
            # Note: This will raise an error for async-only providers
            return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

        # Call the async model (guaranteed non-None by the check above).
        json_output = await async_model(prompt, generation_schema, **kwargs)

        # Validate with the ORIGINAL schema to ensure proper validation.
        try:
            return schema.model_validate_json(json_output)
        except Exception as e:
            raise DataSetGeneratorError(
                f"Async generation validation failed: {e}",
                context={"raw_content": json_output},
            ) from e

    async def generate_async_stream(self, prompt: str, schema: Any, max_retries: int = 3, **kwargs):  # noqa: ARG002
        """Asynchronously generate structured output with streaming text chunks.

        Streams the LLM's output text as it is produced, then yields the final
        parsed Pydantic model. Yields tuples of (chunk, result):
        - during streaming: (text_chunk, None)
        - when complete: (None, final_pydantic_model)

        Args:
            prompt: Input prompt
            schema: Pydantic model or other schema type
            max_retries: Maximum number of retry attempts (deprecated, use rate_limit_config)
            **kwargs: Additional generation parameters

        Yields:
            tuple[str | None, Any | None]:
                - (chunk, None) during streaming
                - (None, model) when generation is complete

        Raises:
            DataSetGeneratorError: If generation fails

        Note:
            The retry decorators do not compose with generators, so this path
            performs a single attempt. The max_retries parameter is deprecated.

        Example:
            >>> async for chunk, result in client.generate_async_stream(prompt, MyModel):
            ...     if chunk:
            ...         print(chunk, end='', flush=True)  # Display streaming text
            ...     if result:
            ...         return result  # Final parsed model
        """
        kwargs = self._convert_generation_params(**kwargs)

        # Native providers (Anthropic, Gemini) handle schema transformation
        # internally; OpenAI-compatible strict mode needs all-required schemas.
        generation_schema = schema
        if self.provider in self._ASYNC_ONLY_PROVIDERS:
            pass  # native model classes transform the schema themselves
        elif (
            self.provider in ("openai", "openrouter")
            and isinstance(schema, type)
            and issubclass(schema, BaseModel)
        ):
            generation_schema = _create_openai_compatible_schema(schema)

        # If the model cannot stream, fall back to a single-shot async call.
        async_model = self.async_model or self.model
        if not hasattr(async_model, "generate_stream"):
            result = await self.generate_async(prompt, schema, **kwargs)
            yield (None, result)
            return

        accumulated_text: list[str] = []
        try:
            if self.async_model is None:
                # Sync model: bridge its blocking generator into this event
                # loop through a queue fed by a worker thread, so chunks are
                # yielded as they arrive instead of after the full response.
                queue: asyncio.Queue[str | None] = asyncio.Queue()
                stream_error: list[Exception] = []
                # BUGFIX: capture the running loop HERE and close over it.
                # Calling asyncio.get_event_loop() from inside the worker
                # thread raises RuntimeError, because worker threads have no
                # event loop set.
                loop = asyncio.get_running_loop()

                def _produce_chunks():
                    """Run the sync generator and feed chunks into the queue."""
                    try:
                        for piece in self.model.generate_stream(
                            prompt, generation_schema, **kwargs
                        ):
                            asyncio.run_coroutine_threadsafe(queue.put(piece), loop)
                    except Exception as e:  # surfaced to the consumer below
                        stream_error.append(e)
                    finally:
                        # Sentinel: signal completion to the consumer.
                        asyncio.run_coroutine_threadsafe(queue.put(None), loop)

                # Start producer in a background thread.
                loop.run_in_executor(None, _produce_chunks)

                # Consume chunks as they arrive until the sentinel shows up.
                while True:
                    chunk = await queue.get()
                    if chunk is None:
                        break
                    accumulated_text.append(chunk)
                    yield (chunk, None)

                # Re-raise any error from the producer thread.
                if stream_error:
                    raise DataSetGeneratorError(
                        f"Streaming generation failed in producer: {stream_error[0]}",
                        context={
                            "raw_content": "".join(accumulated_text) if accumulated_text else None
                        },
                    ) from stream_error[0]
            else:
                # True async streaming
                stream = self.async_model.generate_stream(prompt, generation_schema, **kwargs)
                async for chunk in stream:
                    accumulated_text.append(chunk)
                    yield (chunk, None)

            # Parse accumulated JSON with the ORIGINAL schema for validation.
            full_text = "".join(accumulated_text)
            result = schema.model_validate_json(full_text)
            yield (None, result)

        except Exception as e:
            # Wrap and raise error with raw content for debugging
            raw_content = "".join(accumulated_text) if accumulated_text else None
            raise DataSetGeneratorError(
                f"Streaming generation failed: {e}",
                context={"raw_content": raw_content},
            ) from e

    def _convert_generation_params(self, **kwargs) -> dict:
        """Convert generic parameters to provider-specific ones.

        Uses the _PROVIDER_PARAM_MAPPINGS static dictionary for extensible
        parameter conversion across different providers.
        """
        mappings = _PROVIDER_PARAM_MAPPINGS.get(self.provider, {})
        for generic_name, provider_name in mappings.items():
            if generic_name in kwargs:
                kwargs[provider_name] = kwargs.pop(generic_name)

        return kwargs

    def __repr__(self) -> str:
        return f"LLMClient(provider={self.provider}, model={self.model_name})"
358
+
359
+
360
class GeminiModel:
    """Gemini API client using native structured outputs.

    Talks to Gemini's response_json_schema facility directly, which is more
    dependable than routing structured output through the Outlines wrapper.
    """

    def __init__(self, client: genai.Client, model_name: str):
        self.client = client
        self.model_name = model_name

    def _prepare_json_schema(self, schema: type[BaseModel]) -> dict:
        """Strip Gemini-incompatible fields from the model's JSON schema.

        Only additionalProperties and incompatible object/array shapes are
        removed here; $ref inlining is deliberately skipped because inlining
        can exceed Gemini's nesting-depth limits on deeply nested models.
        """
        return _strip_additional_properties(schema.model_json_schema())

    async def __call__(
        self,
        prompt: str,
        schema: type[BaseModel],
        **kwargs,
    ) -> str:
        """Asynchronously generate structured output via Gemini's JSON schema support.

        Args:
            prompt: The input prompt
            schema: Pydantic model class for structured output
            **kwargs: Additional generation parameters (temperature, max_output_tokens, etc.)

        Returns:
            JSON string matching the schema
        """
        prepared_schema = self._prepare_json_schema(schema)

        try:
            config = genai_types.GenerateContentConfig(
                response_mime_type="application/json",
                response_json_schema=prepared_schema,
                **kwargs,
            )
        except Exception as e:
            raise DataSetGeneratorError(f"Failed to create Gemini config: {e}") from e

        # Issue the request through the SDK's async surface.
        response = await self.client.aio.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=config,
        )

        # Defend against empty or blocked responses; some SDK versions
        # return None candidates or candidates without content parts.
        candidates = response.candidates
        if not candidates:
            raise DataSetGeneratorError(
                "Gemini returned empty response",
                context={"finish_reason": None},
            )

        candidate = candidates[0]
        if candidate.content is None or not candidate.content.parts:
            raise DataSetGeneratorError(
                "Gemini returned empty response",
                context={"finish_reason": getattr(candidate, "finish_reason", None)},
            )

        text = response.text
        if text is None:
            raise DataSetGeneratorError("Gemini returned empty response")

        return text
435
+
436
+
437
class AnthropicModel:
    """Anthropic API client using native structured outputs.

    Relies on Anthropic's beta structured outputs feature
    (structured-outputs-2025-11-13), which guarantees JSON schema compliance
    through constrained decoding.
    """

    # Beta header for structured outputs feature
    STRUCTURED_OUTPUTS_BETA = "structured-outputs-2025-11-13"

    def __init__(self, client: anthropic.AsyncAnthropic, model_name: str):
        self.client = client
        self.model_name = model_name

    def _prepare_json_schema(self, schema: type[BaseModel]) -> dict:
        """Build a strict-mode-compliant JSON schema for Anthropic.

        Anthropic's structured outputs demand additionalProperties: false on
        every object and all properties listed in the required array.
        """
        return _ensure_anthropic_strict_mode_compliance(schema.model_json_schema())

    async def __call__(
        self,
        prompt: str,
        schema: type[BaseModel],
        **kwargs,
    ) -> str:
        """Asynchronously generate structured output via Anthropic's JSON schema support.

        Args:
            prompt: The input prompt
            schema: Pydantic model class for structured output
            **kwargs: Additional generation parameters (temperature, max_tokens, etc.)

        Returns:
            JSON string matching the schema
        """
        strict_schema = self._prepare_json_schema(schema)

        # Generous default token budget for structured outputs.
        token_limit = kwargs.pop("max_tokens", 16384)

        try:
            response = await self.client.beta.messages.create(
                model=self.model_name,
                max_tokens=token_limit,
                betas=[self.STRUCTURED_OUTPUTS_BETA],
                messages=[{"role": "user", "content": prompt}],
                output_format={
                    "type": "json_schema",
                    "schema": strict_schema,
                },
                **kwargs,
            )
        except anthropic.BadRequestError as e:
            raise DataSetGeneratorError(
                f"Anthropic structured output request failed: {e}",
                context={"schema": strict_schema},
            ) from e

        stop_reason = response.stop_reason
        if stop_reason == "refusal":
            # The model declined to answer for safety reasons.
            raise DataSetGeneratorError(
                "Anthropic refused the request for safety reasons",
                context={"stop_reason": stop_reason},
            )
        if stop_reason == "max_tokens":
            # Truncated output would be unparseable JSON.
            raise DataSetGeneratorError(
                "Anthropic response truncated due to max_tokens limit",
                context={"stop_reason": stop_reason},
            )

        if not response.content:
            raise DataSetGeneratorError("Anthropic returned empty response")

        first_block = response.content[0]
        text = getattr(first_block, "text", None)
        if not text:
            raise DataSetGeneratorError("Anthropic returned empty text response")

        return text
523
+
524
+
525
def _raise_api_key_error(env_var: str) -> None:
    """Raise a DataSetGeneratorError naming the unset API-key variable."""
    raise DataSetGeneratorError(f"{env_var} environment variable not set")
529
+
530
+
531
def _get_gemini_api_key() -> str:
    """Retrieve Gemini API key from environment variables.

    GOOGLE_API_KEY is consulted first, then GEMINI_API_KEY.

    Returns:
        The API key string

    Raises:
        DataSetGeneratorError: If no API key is found
    """
    key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    if key:
        return key
    _raise_api_key_error("GOOGLE_API_KEY or GEMINI_API_KEY")
    # Unreachable: the helper above always raises; keeps the type checker happy.
    raise AssertionError("unreachable")
546
+
547
+
548
+ # Provider to environment variable mapping
549
+ PROVIDER_API_KEY_MAP: dict[str, list[str]] = {
550
+ "openai": ["OPENAI_API_KEY"],
551
+ "anthropic": ["ANTHROPIC_API_KEY"],
552
+ "gemini": ["GOOGLE_API_KEY", "GEMINI_API_KEY"],
553
+ "openrouter": ["OPENROUTER_API_KEY"],
554
+ "ollama": [], # No API key required
555
+ # Test providers for unit tests, no API key required
556
+ "test": [],
557
+ "override": [],
558
+ }
559
+
560
+
561
+ def validate_provider_api_key(provider: str) -> tuple[bool, str | None]:
562
+ """Validate that the required API key exists for a provider.
563
+
564
+ Args:
565
+ provider: Provider name (openai, anthropic, gemini, ollama, openrouter)
566
+
567
+ Returns:
568
+ Tuple of (is_valid, error_message). If valid, error_message is None.
569
+ """
570
+ env_vars = PROVIDER_API_KEY_MAP.get(provider)
571
+
572
+ if env_vars is None:
573
+ return False, f"Unknown provider: {provider}"
574
+
575
+ # Ollama doesn't need an API key
576
+ if not env_vars:
577
+ return True, None
578
+
579
+ # Check if any of the required env vars are set
580
+ for env_var in env_vars:
581
+ if os.getenv(env_var):
582
+ return True, None
583
+
584
+ # Build helpful error message
585
+ if len(env_vars) == 1:
586
+ return False, f"Missing API key: {env_vars[0]} environment variable is not set"
587
+ var_list = " or ".join(env_vars)
588
+ return False, f"Missing API key: Set {var_list} environment variable"
589
+
590
+
591
def get_required_api_key_env_var(provider: str) -> str | None:
    """Get the environment variable name(s) required for a provider.

    Args:
        provider: Provider name

    Returns:
        Human-readable string of required env var(s), or None if no key required
    """
    env_vars = PROVIDER_API_KEY_MAP.get(provider)
    # Unknown providers and key-less providers both map to None here.
    if not env_vars:
        return None
    return env_vars[0] if len(env_vars) == 1 else " or ".join(env_vars)
606
+
607
+
608
def _raise_unsupported_provider_error(provider: str) -> None:
    """Raise a DataSetGeneratorError for a provider this module cannot handle."""
    raise DataSetGeneratorError(f"Unsupported provider: {provider}")
612
+
613
+
614
+ def _get_openai_client_config(
615
+ api_key_env_var: str | None,
616
+ default_base_url: str | None,
617
+ dummy_key: str | None = None,
618
+ **kwargs,
619
+ ) -> tuple[str | None, dict[str, Any]]:
620
+ """Extract common configuration for OpenAI-compatible clients.
621
+
622
+ Args:
623
+ api_key_env_var: Environment variable name for API key (None to skip check)
624
+ default_base_url: Default base URL if not provided in kwargs (None for OpenAI default)
625
+ dummy_key: Dummy API key to use if api_key_env_var is None (e.g., for Ollama)
626
+ **kwargs: Additional client configuration (may include base_url override)
627
+
628
+ Returns:
629
+ Tuple of (api_key, client_kwargs) for client initialization
630
+
631
+ Raises:
632
+ DataSetGeneratorError: If required API key is missing
633
+ """
634
+ # Get API key from environment or use dummy key
635
+ if api_key_env_var:
636
+ api_key = os.getenv(api_key_env_var)
637
+ if not api_key:
638
+ _raise_api_key_error(api_key_env_var)
639
+ else:
640
+ api_key = dummy_key
641
+
642
+ # Set up base_url if provided
643
+ client_kwargs = {k: v for k, v in kwargs.items() if k != "base_url"}
644
+ if default_base_url:
645
+ client_kwargs["base_url"] = kwargs.get("base_url", default_base_url)
646
+
647
+ return api_key, client_kwargs
648
+
649
+
650
def _create_openai_compatible_client(
    api_key_env_var: str | None,
    default_base_url: str | None,
    dummy_key: str | None = None,
    **kwargs,
) -> openai.OpenAI:
    """Build a sync OpenAI client for any provider speaking OpenAI's API format.

    Args:
        api_key_env_var: Environment variable name for API key (None to skip check)
        default_base_url: Default base URL if not provided in kwargs (None for OpenAI default)
        dummy_key: Dummy API key to use if api_key_env_var is None (e.g., for Ollama)
        **kwargs: Additional client configuration (may include base_url override)

    Returns:
        Configured OpenAI client instance

    Raises:
        DataSetGeneratorError: If required API key is missing
    """
    resolved_key, extra_kwargs = _get_openai_client_config(
        api_key_env_var, default_base_url, dummy_key, **kwargs
    )
    return openai.OpenAI(api_key=resolved_key, **extra_kwargs)
674
+
675
+
676
def _create_async_openai_compatible_client(
    api_key_env_var: str | None,
    default_base_url: str | None,
    dummy_key: str | None = None,
    **kwargs,
) -> openai.AsyncOpenAI:
    """Build an async OpenAI client for any provider speaking OpenAI's API format.

    Args:
        api_key_env_var: Environment variable name for API key (None to skip check)
        default_base_url: Default base URL if not provided in kwargs (None for OpenAI default)
        dummy_key: Dummy API key to use if api_key_env_var is None (e.g., for Ollama)
        **kwargs: Additional client configuration (may include base_url override)

    Returns:
        Configured AsyncOpenAI client instance

    Raises:
        DataSetGeneratorError: If required API key is missing
    """
    resolved_key, extra_kwargs = _get_openai_client_config(
        api_key_env_var, default_base_url, dummy_key, **kwargs
    )
    return openai.AsyncOpenAI(api_key=resolved_key, **extra_kwargs)
700
+
701
+
702
+ def _is_incompatible_object(schema: dict) -> bool:
703
+ """Check if a schema represents an object incompatible with Gemini.
704
+
705
+ Gemini rejects objects with no properties defined (like dict[str, Any]).
706
+
707
+ Args:
708
+ schema: JSON schema dictionary
709
+
710
+ Returns:
711
+ True if the schema is an incompatible object type
712
+ """
713
+ return schema.get("type") == "object" and "properties" not in schema and "$ref" not in schema
714
+
715
+
716
+ def _is_incompatible_array(schema: dict) -> bool:
717
+ """Check if a schema represents an array with incompatible items.
718
+
719
+ Arrays with object items that have no properties (like list[dict[str, Any]])
720
+ are incompatible with Gemini.
721
+
722
+ Args:
723
+ schema: JSON schema dictionary
724
+
725
+ Returns:
726
+ True if the schema is an incompatible array type
727
+ """
728
+ if schema.get("type") != "array" or "items" not in schema:
729
+ return False
730
+
731
+ items = schema["items"]
732
+ return isinstance(items, dict) and _is_incompatible_object(items)
733
+
734
+
735
def _inline_refs(schema_dict: dict, defs: dict | None = None) -> dict:
    """Recursively inline $ref references in a JSON schema.

    Gemini's structured output can be unreliable with $ref, so each
    reference is resolved by copying its definition into place.

    Args:
        schema_dict: JSON schema dictionary
        defs: The $defs dictionary from the root schema

    Returns:
        Modified schema dict with all $ref inlined
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    # Use the caller-supplied defs, or pull them off the root schema.
    active_defs: dict = defs if defs is not None else schema_dict.get("$defs", {})

    # A $ref node is replaced by its (recursively inlined) definition.
    if "$ref" in schema_dict:
        ref_path = schema_dict["$ref"]
        prefix = "#/$defs/"
        if ref_path.startswith(prefix):
            def_name = ref_path[len(prefix):]
            if def_name in active_defs:
                expanded = _inline_refs(dict(active_defs[def_name]), active_defs)
                # Keep sibling keys from the referencing node (e.g. description).
                for extra_key, extra_value in schema_dict.items():
                    if extra_key != "$ref":
                        expanded[extra_key] = extra_value
                return expanded
        # Unresolvable reference: leave it untouched.
        return schema_dict

    # Rebuild the node, recursing into the sub-schemas that can carry refs.
    rebuilt: dict[str, Any] = {}
    for key, value in schema_dict.items():
        if key == "$defs":
            continue  # definitions are being inlined; drop the table
        if key == "properties" and isinstance(value, dict):
            rebuilt[key] = {
                prop: _inline_refs(sub_schema, active_defs)
                for prop, sub_schema in value.items()
            }
        elif key == "items" and isinstance(value, dict):
            rebuilt[key] = _inline_refs(value, active_defs)
        elif key in _UNION_KEYS and isinstance(value, list):
            rebuilt[key] = [_inline_refs(variant, active_defs) for variant in value]
        else:
            rebuilt[key] = value

    return rebuilt
792
+
793
+
794
def _strip_additional_properties(schema_dict: dict) -> dict:
    """Recursively sanitize a JSON schema for Gemini structured output.

    Gemini rejects three constructs that Pydantic happily emits:

    * the ``additionalProperties`` keyword itself,
    * object schemas that declare no ``properties`` (e.g. ``dict[str, Any]``),
    * arrays whose ``items`` are such property-less objects.

    Offending properties are dropped from the schema (and from ``required``),
    and every ``additionalProperties`` keyword is removed. ``$defs`` entries
    are preserved and sanitized in place; no ``$ref`` inlining happens here
    (use ``_inline_refs`` separately for that).

    Args:
        schema_dict: JSON schema dictionary; mutated in place.

    Returns:
        The same dict, sanitized for Gemini.
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    def _gemini_rejects(candidate: dict) -> bool:
        # True when Gemini cannot represent this property schema at all.
        if candidate.get("additionalProperties") is True:
            return True  # open-ended mapping such as dict[str, Any]
        if _is_incompatible_object(candidate) or _is_incompatible_array(candidate):
            return True
        # A union is rejected when any variant is an incompatible object/array.
        return any(
            isinstance(variant, dict)
            and (_is_incompatible_object(variant) or _is_incompatible_array(variant))
            for variant in candidate.get("anyOf", [])
        )

    if "properties" in schema_dict:
        doomed = [
            name
            for name, sub_schema in schema_dict["properties"].items()
            if isinstance(sub_schema, dict) and _gemini_rejects(sub_schema)
        ]
        for name in doomed:
            del schema_dict["properties"][name]
        # Keep "required" consistent with the surviving properties.
        if "required" in schema_dict:
            schema_dict["required"] = [
                name for name in schema_dict["required"] if name not in doomed
            ]

    # The keyword itself is never allowed, whatever its value.
    schema_dict.pop("additionalProperties", None)

    # Recurse into every nested schema container.
    for container in ("$defs", "properties"):
        if container in schema_dict:
            children = schema_dict[container]
            for child_name in children:
                children[child_name] = _strip_additional_properties(children[child_name])

    if "items" in schema_dict:
        schema_dict["items"] = _strip_additional_properties(schema_dict["items"])

    for union_key in _UNION_KEYS:
        if union_key in schema_dict:
            schema_dict[union_key] = [
                _strip_additional_properties(variant) for variant in schema_dict[union_key]
            ]

    return schema_dict
879
+
880
+
881
def _ensure_anthropic_strict_mode_compliance(schema_dict: dict) -> dict:
    """Make a JSON schema acceptable to Anthropic's structured outputs.

    Anthropic imposes the same constraints as OpenAI's strict mode:
    every object needs ``additionalProperties: false``, every property must
    appear in ``required``, and open-ended mappings
    (``additionalProperties: true``) cannot appear at all. Rather than
    duplicating that logic, this delegates to the OpenAI strict-mode
    transform.

    Args:
        schema_dict: JSON schema dictionary.

    Returns:
        Schema dict satisfying Anthropic's structured-output constraints.
    """
    return _ensure_openai_strict_mode_compliance(schema_dict)
899
+
900
+
901
def _strip_ref_sibling_keywords(schema_dict: dict) -> dict:
    """Drop every keyword that sits next to a ``$ref``.

    OpenAI's strict mode forbids ``$ref`` from carrying sibling keywords such
    as ``description``, which Pydantic adds for nested models declared with
    ``Field(description=...)``. Any schema node containing ``$ref`` is reduced
    to just the reference; all other nodes are walked recursively.

    Args:
        schema_dict: JSON schema dictionary.

    Returns:
        Schema dict whose ``$ref`` nodes carry no sibling keywords.
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    # A reference node keeps nothing but the reference itself.
    if "$ref" in schema_dict:
        return {"$ref": schema_dict["$ref"]}

    # Walk the named-child containers (object properties, shared definitions).
    for container in ("properties", "$defs"):
        if container in schema_dict:
            children = schema_dict[container]
            for child_name in children:
                children[child_name] = _strip_ref_sibling_keywords(children[child_name])

    # Array element schema.
    if "items" in schema_dict:
        schema_dict["items"] = _strip_ref_sibling_keywords(schema_dict["items"])

    # Union variants (anyOf / oneOf / allOf).
    for union_key in _UNION_KEYS:
        if union_key in schema_dict:
            schema_dict[union_key] = [
                _strip_ref_sibling_keywords(variant) for variant in schema_dict[union_key]
            ]

    return schema_dict
944
+
945
+
946
def _ensure_openai_strict_mode_compliance(schema_dict: dict) -> dict:
    """
    Ensure schema complies with OpenAI's strict mode requirements.

    OpenAI's strict mode requires:
    1. For objects, 'additionalProperties' must be explicitly set to false
    2. ALL properties must be in the 'required' array (no optional fields allowed)
    3. No fields with additionalProperties: true (incompatible with strict mode)
    4. $ref cannot have sibling keywords like 'description'

    Fields like dict[str, Any] have additionalProperties: true and cannot be
    represented in strict mode, so they are excluded from the schema entirely.

    Note: mutates schema_dict in place (property deletion, key assignment)
    and also returns it; callers may rely on either form.

    Args:
        schema_dict: JSON schema dictionary

    Returns:
        Modified schema dict meeting OpenAI strict mode requirements
    """
    if not isinstance(schema_dict, dict):
        return schema_dict

    # First, strip sibling keywords from $ref properties
    # (must run before the property scan so pure-$ref nodes are normalized).
    schema_dict = _strip_ref_sibling_keywords(schema_dict)

    # For OpenAI strict mode, identify and remove dict[str, Any] fields
    # These have additionalProperties: true which is incompatible with strict mode
    if "properties" in schema_dict:
        properties_to_remove = []
        for prop_name, prop_schema in schema_dict["properties"].items():
            # Check for direct additionalProperties: true
            if isinstance(prop_schema, dict) and prop_schema.get("additionalProperties") is True:
                # Remove fields with additionalProperties: true (e.g., dict[str, Any])
                properties_to_remove.append(prop_name)
            # Check for anyOf containing object variants with additionalProperties: true
            elif isinstance(prop_schema, dict) and "anyOf" in prop_schema:
                for variant in prop_schema["anyOf"]:
                    if isinstance(variant, dict) and variant.get("additionalProperties") is True:
                        # This anyOf contains an incompatible object variant - remove entire field
                        properties_to_remove.append(prop_name)
                        break

        # Remove incompatible properties from the schema
        for prop_name in properties_to_remove:
            del schema_dict["properties"][prop_name]

        # Update required array to exclude removed properties
        if "required" in schema_dict:
            schema_dict["required"] = [
                r for r in schema_dict["required"] if r not in properties_to_remove
            ]

        # After removing incompatible fields, ensure ALL remaining properties are required
        # OpenAI strict mode doesn't allow optional fields.
        # NOTE(review): this overwrites the filtered "required" list above with
        # every surviving key, intentionally promoting optional fields to
        # required — presumably optionality is expressed via anyOf-with-null
        # variants instead; confirm against the schemas fed to this function.
        property_keys = list(schema_dict["properties"].keys())
        schema_dict["required"] = property_keys
        schema_dict["additionalProperties"] = False

    # For all objects (including those without properties), set additionalProperties to false
    if schema_dict.get("type") == "object":
        schema_dict["additionalProperties"] = False

    # Recursively process nested structures
    # ($defs are kept and transformed, unlike _inline_refs which drops them).
    if "$defs" in schema_dict:
        for def_name, def_schema in schema_dict["$defs"].items():
            schema_dict["$defs"][def_name] = _ensure_openai_strict_mode_compliance(def_schema)

    # Process properties recursively
    if "properties" in schema_dict:
        for prop_name, prop_schema in schema_dict["properties"].items():
            schema_dict["properties"][prop_name] = _ensure_openai_strict_mode_compliance(
                prop_schema
            )

    # Process items (for arrays)
    if "items" in schema_dict:
        schema_dict["items"] = _ensure_openai_strict_mode_compliance(schema_dict["items"])

    # Process union types (anyOf, oneOf, allOf)
    # This must be done to handle nested structures like list[dict[str, Any]] | None
    # where the dict[str, Any] is inside an array variant
    for union_key in _UNION_KEYS:
        if union_key in schema_dict:
            schema_dict[union_key] = [
                _ensure_openai_strict_mode_compliance(variant) for variant in schema_dict[union_key]
            ]

    return schema_dict
1034
+
1035
+
1036
@lru_cache(maxsize=128)
def _get_cached_openai_schema(schema: type[BaseModel]) -> type[BaseModel]:
    """Return a memoized OpenAI-strict-mode wrapper class for *schema*.

    The wrapper subclasses the original model and post-processes the JSON
    schema it emits so it satisfies OpenAI's strict-mode rules. Results are
    cached per model class so repeated generation calls don't redo the
    transformation.

    Args:
        schema: Original Pydantic model class.

    Returns:
        Subclass whose ``model_json_schema`` output is strict-mode compliant.
    """

    class _StrictWrapper(schema):  # type: ignore[misc,valid-type]
        @classmethod
        def model_json_schema(cls, **kwargs):
            # Post-process the parent's schema into strict-mode form.
            return _ensure_openai_strict_mode_compliance(super().model_json_schema(**kwargs))

    # Mirror the wrapped model's identity so error messages stay readable.
    _StrictWrapper.__name__ = f"{schema.__name__}OpenAICompat"
    _StrictWrapper.__doc__ = schema.__doc__

    # Resolve forward references (e.g., PendingToolCall in AgentStep).
    _StrictWrapper.model_rebuild()

    return _StrictWrapper
1068
+
1069
+
1070
def _create_openai_compatible_schema(schema: type[BaseModel]) -> type[BaseModel]:
    """Public entry point for obtaining an OpenAI-strict-mode model wrapper.

    OpenAI's strict mode demands ``additionalProperties: false`` on every
    object and a fully populated ``required`` list, while Pydantic's handling
    of required vs optional fields is otherwise preserved. The transformation
    (and the per-class cache) lives in ``_get_cached_openai_schema``; this
    wrapper keeps a stable, intention-revealing name at the call sites.

    Args:
        schema: Original Pydantic model class.

    Returns:
        Wrapper model emitting OpenAI-compatible JSON schemas.
    """
    return _get_cached_openai_schema(schema)
1087
+
1088
+
1089
def make_outlines_model(provider: str, model_name: str, **kwargs) -> Any:
    """Build a synchronous structured-generation model for the given provider.

    ``openai``, ``ollama`` and ``openrouter`` go through OpenAI-compatible
    clients wrapped by Outlines; ``anthropic`` and ``gemini`` use their native
    clients with dedicated model wrappers.

    Args:
        provider: Provider name (openai, anthropic, gemini, ollama, openrouter)
        model_name: Model identifier
        **kwargs: Additional parameters passed to the client

    Returns:
        Outlines model instance (or a native wrapper for anthropic/gemini)

    Raises:
        DataSetGeneratorError: If provider is unsupported or configuration fails
    """
    # (api key env var, base URL override, dummy key) per OpenAI-style provider.
    openai_style: dict[str, tuple[str | None, str | None, str | None]] = {
        "openai": ("OPENAI_API_KEY", None, None),  # None URL -> OpenAI's default
        "ollama": (None, "http://localhost:11434/v1", "ollama"),  # no real key needed
        "openrouter": ("OPENROUTER_API_KEY", "https://openrouter.ai/api/v1", None),
    }

    try:
        if provider in openai_style:
            env_var, base_url, dummy = openai_style[provider]
            extra = {"dummy_key": dummy} if dummy is not None else {}
            client = _create_openai_compatible_client(
                api_key_env_var=env_var,
                default_base_url=base_url,
                **extra,
                **kwargs,
            )
            return outlines.from_openai(client, model_name)

        if provider == "anthropic":
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if not anthropic_key:
                _raise_api_key_error("ANTHROPIC_API_KEY")
            # Native Anthropic structured outputs API via its async client.
            return AnthropicModel(
                anthropic.AsyncAnthropic(api_key=anthropic_key, **kwargs), model_name
            )

        if provider == "gemini":
            # Direct Gemini API instead of Outlines for better structured
            # output reliability.
            return GeminiModel(genai.Client(api_key=_get_gemini_api_key()), model_name)

        _raise_unsupported_provider_error(provider)

    except DataSetGeneratorError:
        # Our own configuration errors (e.g. missing API keys) pass through.
        raise
    except Exception as e:
        # Normalize any provider/client failure via the organized handler.
        raise handle_provider_error(e, provider, model_name) from e
1151
+
1152
+
1153
def make_async_outlines_model(provider: str, model_name: str, **kwargs) -> Any | None:
    """Build an async structured-generation model when the provider allows it.

    Providers with async-capable clients get a real model; for any other
    provider the function returns ``None`` so the caller can fall back to
    synchronous execution.

    Args:
        provider: Provider name (openai, anthropic, gemini, ollama, openrouter)
        model_name: Model identifier
        **kwargs: Additional parameters passed to the client

    Returns:
        Async-capable model instance, or ``None`` when unsupported

    Raises:
        DataSetGeneratorError: If configuration fails for a supported provider
    """
    # (api key env var, base URL override, dummy key) per OpenAI-style provider.
    openai_style: dict[str, tuple[str | None, str | None, str | None]] = {
        "openai": ("OPENAI_API_KEY", None, None),  # None URL -> OpenAI's default
        "ollama": (None, "http://localhost:11434/v1", "ollama"),  # no real key needed
        "openrouter": ("OPENROUTER_API_KEY", "https://openrouter.ai/api/v1", None),
    }

    try:
        if provider in openai_style:
            env_var, base_url, dummy = openai_style[provider]
            extra = {"dummy_key": dummy} if dummy is not None else {}
            client = _create_async_openai_compatible_client(
                api_key_env_var=env_var,
                default_base_url=base_url,
                **extra,
                **kwargs,
            )
            return outlines.from_openai(client, model_name)

        if provider == "anthropic":
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if not anthropic_key:
                _raise_api_key_error("ANTHROPIC_API_KEY")
            # Native Anthropic structured outputs API via its async client.
            return AnthropicModel(
                anthropic.AsyncAnthropic(api_key=anthropic_key, **kwargs), model_name
            )

        if provider == "gemini":
            # Direct async Gemini API for better structured output reliability.
            return GeminiModel(genai.Client(api_key=_get_gemini_api_key()), model_name)

    except DataSetGeneratorError:
        raise
    except Exception as e:
        raise handle_provider_error(e, provider, model_name) from e

    # Outlines does not currently expose async structured generation wrappers
    # for the remaining providers. Fallback to synchronous execution later.
    return None