openaivec 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/_responses.py +77 -25
- openaivec/_schema.py +413 -0
- openaivec/pandas_ext.py +242 -140
- openaivec/spark.py +21 -1
- {openaivec-0.14.2.dist-info → openaivec-0.14.3.dist-info}/METADATA +1 -1
- {openaivec-0.14.2.dist-info → openaivec-0.14.3.dist-info}/RECORD +8 -7
- {openaivec-0.14.2.dist-info → openaivec-0.14.3.dist-info}/WHEEL +0 -0
- {openaivec-0.14.2.dist-info → openaivec-0.14.3.dist-info}/licenses/LICENSE +0 -0
openaivec/_responses.py
CHANGED
@@ -1,7 +1,7 @@
 import warnings
 from dataclasses import dataclass, field
 from logging import Logger, getLogger
-from typing import Generic, List, Type, cast
+from typing import Any, Generic, List, Type, cast
 
 from openai import AsyncOpenAI, BadRequestError, InternalServerError, OpenAI, RateLimitError
 from openai.types.responses import ParsedResponse
@@ -163,7 +163,7 @@ class BatchResponses(Generic[ResponseFormat]):
     client: OpenAI
     model_name: str  # For Azure: deployment name, for OpenAI: model name
     system_message: str
-    temperature: float | None =
+    temperature: float | None = None
     top_p: float = 1.0
     response_format: Type[ResponseFormat] = str  # type: ignore[assignment]
     cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
@@ -241,7 +241,9 @@ class BatchResponses(Generic[ResponseFormat]):
 
     @observe(_LOGGER)
     @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
-    def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
+    def _request_llm(
+        self, user_messages: List[Message[str]], **extra_api_params: Any
+    ) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single call to the OpenAI JSON‑mode endpoint.
 
         Args:
@@ -265,16 +267,29 @@ class BatchResponses(Generic[ResponseFormat]):
         class ResponseT(BaseModel):
             assistant_messages: List[MessageT]
 
-        #
-        api_params = {
+        # Build base API parameters (cannot be overridden by caller)
+        api_params: dict[str, Any] = {
             "model": self.model_name,
             "instructions": self._vectorized_system_message,
             "input": Request(user_messages=user_messages).model_dump_json(),
-            "top_p": self.top_p,
             "text_format": ResponseT,
         }
-
-
+
+        # Resolve nucleus sampling (caller can override)
+        top_p = extra_api_params.pop("top_p", self.top_p)
+        if top_p is not None:
+            api_params["top_p"] = top_p
+
+        # Resolve temperature (caller can override). If None, omit entirely for reasoning models.
+        temperature = extra_api_params.pop("temperature", self.temperature)
+        if temperature is not None:
+            api_params["temperature"] = temperature
+
+        # Merge remaining user supplied params, excluding protected keys
+        for k, v in extra_api_params.items():
+            if k in {"model", "instructions", "input", "text_format"}:
+                continue  # ignore attempts to override core batching contract
+            api_params[k] = v
 
         try:
             completion: ParsedResponse[ResponseT] = self.client.responses.parse(**api_params)
@@ -285,7 +300,7 @@ class BatchResponses(Generic[ResponseFormat]):
         return cast(ParsedResponse[Response[ResponseFormat]], completion)
 
     @observe(_LOGGER)
-    def _predict_chunk(self, user_messages: List[str]) -> List[ResponseFormat | None]:
+    def _predict_chunk(self, user_messages: List[str], **api_kwargs: Any) -> List[ResponseFormat | None]:
         """Helper executed for every unique minibatch.
 
         This method:
@@ -297,7 +312,7 @@ class BatchResponses(Generic[ResponseFormat]):
         only on its arguments – which allows safe reuse.
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-        responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages)
+        responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages, **api_kwargs)
         if not responses.output_parsed:
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
@@ -305,17 +320,28 @@ class BatchResponses(Generic[ResponseFormat]):
         return sorted_responses
 
     @observe(_LOGGER)
-    def parse(self, inputs: List[str]) -> List[ResponseFormat | None]:
+    def parse(self, inputs: List[str], **api_kwargs: Any) -> List[ResponseFormat | None]:
         """Batched predict.
 
+        Accepts arbitrary keyword arguments that are forwarded to the underlying
+        ``OpenAI.responses.parse`` call for future‑proofing (e.g., ``max_output_tokens``,
+        penalties, etc.). ``top_p`` and ``temperature`` default to the instance's
+        configured values but can be overridden explicitly.
+
         Args:
             inputs (List[str]): Prompts that require responses. Duplicates are de‑duplicated.
+            **api_kwargs: Extra keyword args forwarded to the OpenAI Responses API.
 
         Returns:
             List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
-
-
+        if not api_kwargs:
+            return self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]
+
+        def _predict_with(xs: List[str]) -> List[ResponseFormat | None]:
+            return self._predict_chunk(xs, **api_kwargs)
+
+        return self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
 
 
 @dataclass(frozen=True)
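A minimal usage sketch of the new per-call override surface (the construction below assumes the dataclass fields shown in the hunks above; the prompts and the max_output_tokens value are illustrative, not from the package docs):

    from openai import OpenAI
    from openaivec._responses import BatchResponses

    batch = BatchResponses(
        client=OpenAI(),
        model_name="gpt-4.1-mini",
        system_message="Answer concisely.",
    )

    # Instance defaults apply: temperature=None is omitted from the request
    # entirely, top_p=1.0 is sent.
    answers = batch.parse(["What is 2 + 2?", "Capital of France?"])

    # Per-call overrides: temperature/top_p replace the instance values, and any
    # other kwarg (e.g. max_output_tokens) is forwarded to OpenAI.responses.parse.
    # Protected keys (model, instructions, input, text_format) are ignored.
    tuned = batch.parse(
        ["Summarize this ticket: refund not received"],
        temperature=0.2,
        max_output_tokens=128,
    )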
@@ -382,7 +408,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         client: AsyncOpenAI,
         model_name: str,
         system_message: str,
-        temperature: float | None =
+        temperature: float | None = None,
         top_p: float = 1.0,
         response_format: Type[ResponseFormat] = str,
         batch_size: int | None = None,
@@ -455,7 +481,9 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
 
     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
     @observe(_LOGGER)
-    async def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
+    async def _request_llm(
+        self, user_messages: List[Message[str]], **extra_api_params: Any
+    ) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single async call to the OpenAI JSON‑mode endpoint.
 
         Args:
@@ -476,16 +504,29 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         class ResponseT(BaseModel):
             assistant_messages: List[MessageT]
 
-        #
-        api_params = {
+        # Build base API parameters (cannot be overridden by caller)
+        api_params: dict[str, Any] = {
             "model": self.model_name,
             "instructions": self._vectorized_system_message,
             "input": Request(user_messages=user_messages).model_dump_json(),
-            "top_p": self.top_p,
             "text_format": ResponseT,
         }
-
-
+
+        # Resolve nucleus sampling (caller can override)
+        top_p = extra_api_params.pop("top_p", self.top_p)
+        if top_p is not None:
+            api_params["top_p"] = top_p
+
+        # Resolve temperature (caller can override). If None, omit entirely for reasoning models.
+        temperature = extra_api_params.pop("temperature", self.temperature)
+        if temperature is not None:
+            api_params["temperature"] = temperature
+
+        # Merge remaining user supplied params, excluding protected keys
+        for k, v in extra_api_params.items():
+            if k in {"model", "instructions", "input", "text_format"}:
+                continue
+            api_params[k] = v
 
         try:
             completion: ParsedResponse[ResponseT] = await self.client.responses.parse(**api_params)
@@ -496,7 +537,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         return cast(ParsedResponse[Response[ResponseFormat]], completion)
 
     @observe(_LOGGER)
-    async def _predict_chunk(self, user_messages: List[str]) -> List[ResponseFormat | None]:
+    async def _predict_chunk(self, user_messages: List[str], **api_kwargs: Any) -> List[ResponseFormat | None]:
         """Async helper executed for every unique minibatch.
 
         This method:
@@ -507,7 +548,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         The function is pure – it has no side‑effects and the result depends only on its arguments.
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)  # type: ignore[call-issue]
+        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages, **api_kwargs)  # type: ignore[call-issue]
         if not responses.output_parsed:
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
@@ -516,14 +557,25 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         return sorted_responses
 
     @observe(_LOGGER)
-    async def parse(self, inputs: List[str]) -> List[ResponseFormat | None]:
+    async def parse(self, inputs: List[str], **api_kwargs: Any) -> List[ResponseFormat | None]:
         """Batched predict (async).
 
+        Accepts arbitrary keyword arguments forwarded to ``AsyncOpenAI.responses.parse``.
+        ``top_p`` and ``temperature`` default to instance configuration but can be
+        overridden per call. This prepares for future API parameters without
+        changing the public surface again.
+
         Args:
             inputs (List[str]): Prompts that require responses. Duplicates are de‑duplicated.
+            **api_kwargs: Extra keyword args for the OpenAI Responses API.
 
         Returns:
             List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
-
-
+        if not api_kwargs:
+            return await self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]
+
+        async def _predict_with(xs: List[str]) -> List[ResponseFormat | None]:
+            return await self._predict_chunk(xs, **api_kwargs)
+
+        return await self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
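The async variant behaves identically; a short sketch, where ``batch`` is assumed to be an already configured AsyncBatchResponses instance (its construction is outside this hunk):

    async def classify(batch) -> None:
        # temperature=None (the new default) omits the parameter from the
        # request entirely, which reasoning models require.
        plain = await batch.parse(["Classify: 'refund not received'"])

        # Extra kwargs flow through to AsyncOpenAI.responses.parse unchanged.
        tuned = await batch.parse(
            ["Classify: 'refund not received'"],
            top_p=0.9,
            max_output_tokens=64,
        )
        print(plain, tuned)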
openaivec/_schema.py
ADDED
@@ -0,0 +1,413 @@
+"""Internal schema inference & dynamic model materialization utilities.
+
+This (non-public) module converts a small *representative* sample of free‑text
+examples plus a *purpose* statement into:
+
+1. A vetted, flat list of scalar field specifications (``FieldSpec``) that can
+   be *reliably* extracted across similar future inputs.
+2. A reusable, self‑contained extraction prompt (``inference_prompt``) that
+   freezes the agreed schema contract (no additions / renames / omissions).
+3. A dynamically generated Pydantic model whose fields mirror the inferred
+   schema, enabling immediate typed parsing with the OpenAI Responses API.
+4. A ``PreparedTask`` wrapper (``InferredSchema.task``) for downstream batched
+   responses/structured extraction flows in pandas or Spark.
+
+Core goals:
+* Minimize manual, subjective schema design iterations.
+* Enforce objective naming / typing / enum rules early (guard rails rather than
+  after‑the‑fact cleaning).
+* Provide deterministic reusability: the same prompt + model yield stable
+  column ordering & types for analytics or feature engineering.
+* Avoid outcome / target label leakage in predictive (feature engineering)
+  contexts by explicitly excluding direct target restatements.
+
+This module is intentionally **internal** (``__all__ = []``). Public users
+should interact through higher‑level batch APIs once a schema has been inferred.
+
+Design constraints:
+* Flat schema only (no nesting / arrays) to simplify Spark & pandas alignment.
+* Primitive types limited to {string, integer, float, boolean}.
+* Optional enumerations for *closed*, *observed* categorical sets only.
+* Validation retries ensure a structurally coherent suggestion before returning.
+
+Example (conceptual):
+    from openai import OpenAI
+    client = OpenAI()
+    inferer = SchemaInferer(client=client, model_name="gpt-4.1-mini")
+    schema = inferer.infer_schema(
+        SchemaInferenceInput(
+            examples=["Order #123 delayed due to weather", "Order #456 delivered"],
+            purpose="Extract operational status signals for logistics analytics",
+        )
+    )
+    Model = schema.model  # dynamic Pydantic model
+    task = schema.task  # PreparedTask for batch extraction
+
+The implementation purposefully does *not* emit or depend on JSON Schema; the
+authoritative contract is the ordered list of ``FieldSpec`` instances.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Literal, Optional, Type
+
+from openai import OpenAI
+from openai.types.responses import ParsedResponse
+from pydantic import BaseModel, Field, create_model
+
+from openaivec._model import PreparedTask
+
+# Internal module: explicitly not part of public API
+__all__: list[str] = []
+
+
+class FieldSpec(BaseModel):
+    """Specification for a single candidate output field.
+
+    Each ``FieldSpec`` encodes a *flat*, scalar, semantically atomic unit the
+    model should extract. These become columns in downstream DataFrames.
+
+    Validation focuses on: objective naming, primitive typing, and *optional*
+    closed categorical vocabularies. Enumerations are intentionally conservative
+    (must derive from clear evidence) to reduce over‑fitted schemas.
+
+    Attributes:
+        name: Lower snake_case unique identifier (regex ^[a-z][a-z0-9_]*$). Avoid
+            subjective modifiers ("best", "great", "high_quality").
+        type: One of ``string|integer|float|boolean``. ``integer`` only if all
+            observed numeric values are whole numbers; ``float`` if any decimal
+            or ratio appears. ``boolean`` strictly for explicit binary forms.
+        description: Concise, objective extraction rule (what qualifies / what
+            to ignore). Disambiguate from overlapping fields if needed.
+        enum_values: Optional stable closed set of lowercase string labels
+            (2–24). Only for *string* type when the vocabulary is clearly
+            evidenced; never hallucinate or extrapolate.
+    """
+
+    name: str = Field(
+        description=(
+            "Lower snake_case identifier (regex: ^[a-z][a-z0-9_]*$). Must be unique across all fields and "
+            "express the semantic meaning succinctly (no adjectives like 'best', 'great')."
+        )
+    )
+    type: Literal["string", "integer", "float", "boolean"] = Field(
+        description=(
+            "Primitive type. Use 'integer' only if all observed numeric values are whole numbers. "
+            "Use 'float' if any value can contain a decimal or represents a ratio/score. Use 'boolean' only for "
+            "explicit binary states (yes/no, true/false, present/absent) consistently encoded. Use 'string' otherwise. "
+            "Never output arrays, objects, or composite encodings; flatten to the most specific scalar value."
+        )
+    )
+    description: str = Field(
+        description=(
+            "Concise, objective definition plus extraction rule (what qualifies / what to ignore). Avoid subjective, "
+            "speculative, or promotional language. If ambiguity exists with another field, clarify the distinction."
+        )
+    )
+    enum_values: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "Optional finite categorical label set (classification) for a string field. Provide ONLY when a closed, "
+            "stable vocabulary (2–24 lowercase tokens) is clearly evidenced or strongly implied by examples. "
+            "Do NOT invent labels. Omit if open-ended or ambiguous. Order must be stable and semantically natural."
+        ),
+    )
+
+
+class InferredSchema(BaseModel):
+    """Result of a schema inference round.
+
+    Contains the normalized *purpose*, an objective *examples_summary*, the
+    ordered ``fields`` contract, and the canonical reusable ``inference_prompt``.
+
+    The prompt is constrained to be fully derivable from the other components;
+    adding novel unstated facts is disallowed to preserve traceability.
+
+    Attributes:
+        purpose: Unambiguous restatement of the user's objective (noise &
+            redundancy removed).
+        examples_summary: Neutral description of structural / semantic patterns
+            observed in the examples (domain, recurring signals, constraints).
+        fields: Ordered list of ``FieldSpec`` objects comprising the schema's
+            sole authoritative contract.
+        inference_prompt: Self-contained extraction instructions enforcing an
+            exact field set (names, order, primitive types) with prohibition on
+            alterations or subjective flourishes.
+    """
+
+    purpose: str = Field(
+        description=(
+            "Normalized, unambiguous restatement of the user objective with redundant, vague, or "
+            "conflicting phrasing removed."
+        )
+    )
+    examples_summary: str = Field(
+        description=(
+            "Objective characterization of the provided examples: content domain, structure, recurring "
+            "patterns, and notable constraints."
+        )
+    )
+    fields: List[FieldSpec] = Field(
+        description=(
+            "Ordered list of proposed fields derived strictly from observable, repeatable signals in the "
+            "examples and aligned with the purpose."
+        )
+    )
+    inference_prompt: str = Field(
+        description=(
+            "Canonical, reusable extraction prompt for structuring future inputs with this schema. "
+            "Must be fully derivable from 'purpose', 'examples_summary', and 'fields' (no new unstated facts or "
+            "speculation). It MUST: (1) instruct the model to output only the listed fields with the exact names "
+            "and primitive types; (2) forbid adding, removing, or renaming fields; (3) avoid subjective or "
+            "marketing language; (4) be self-contained (no TODOs, no external references, no unresolved "
+            "placeholders). Intended for direct reuse as the prompt for deterministic alignment with 'fields'."
+        )
+    )
+
+    @classmethod
+    def load(cls, path: str) -> "InferredSchema":
+        """Load an inferred schema from a JSON file.
+
+        Args:
+            path (str): Path to a UTF‑8 JSON document previously produced via ``save``.
+
+        Returns:
+            InferredSchema: Reconstructed instance.
+        """
+        with open(path, "r", encoding="utf-8") as f:
+            return cls.model_validate_json(f.read())
+
+    @property
+    def model(self) -> Type[BaseModel]:
+        """Dynamically materialized Pydantic model for the inferred schema.
+
+        Equivalent to calling :meth:`build_model` each access (not cached).
+
+        Returns:
+            Type[BaseModel]: Fresh model type reflecting ``fields`` ordering.
+        """
+        return self.build_model()
+
+    @property
+    def task(self) -> PreparedTask:
+        """PreparedTask integrating the schema's extraction prompt & model.
+
+        Returns:
+            PreparedTask: Ready for batched structured extraction calls.
+        """
+        return PreparedTask(
+            instructions=self.inference_prompt, response_format=self.model, top_p=None, temperature=None
+        )
+
+    def build_model(self) -> Type[BaseModel]:
+        """Create a new dynamic ``BaseModel`` class adhering to this schema.
+
+        Implementation details:
+        * Maps primitive types: string→``str``, integer→``int``, float→``float``, boolean→``bool``.
+        * For enumerated string fields, constructs an ad‑hoc ``Enum`` subclass with
+          stable member names (collision‑safe, normalized to ``UPPER_SNAKE``).
+        * All fields are required (ellipsis ``...``). Optionality can be
+          introduced later by modifying this logic if needed.
+
+        Returns:
+            Type[BaseModel]: New (not cached) model type; order matches ``fields``.
+        """
+        type_map: dict[str, type] = {"string": str, "integer": int, "float": float, "boolean": bool}
+        fields: dict[str, tuple[type, object]] = {}
+
+        for spec in self.fields:
+            py_type: type
+            if spec.enum_values:
+                enum_class_name = "Enum_" + "".join(part.capitalize() for part in spec.name.split("_"))
+                members: dict[str, str] = {}
+                for raw in spec.enum_values:
+                    sanitized = raw.upper().replace("-", "_").replace(" ", "_")
+                    if not sanitized or sanitized[0].isdigit():
+                        sanitized = f"V_{sanitized}"
+                    base = sanitized
+                    i = 2
+                    while sanitized in members:
+                        sanitized = f"{base}_{i}"
+                        i += 1
+                    members[sanitized] = raw
+                enum_cls = Enum(enum_class_name, members)  # type: ignore[arg-type]
+                py_type = enum_cls
+            else:
+                py_type = type_map[spec.type]
+            fields[spec.name] = (py_type, ...)
+
+        model = create_model("InferredSchema", **fields)  # type: ignore[call-arg]
+        return model
+
+    def save(self, path: str) -> None:
+        """Persist this inferred schema as pretty‑printed JSON.
+
+        Args:
+            path (str): Destination filesystem path.
+        """
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(self.model_dump_json(indent=2))
+
+
+class SchemaInferenceInput(BaseModel):
+    """Input payload for schema inference.
+
+    Attributes:
+        examples: Representative sample texts restricted to the in‑scope
+            distribution (exclude outliers / noise). Size should be *minimal*
+            yet sufficient to surface recurring patterns.
+        purpose: Plain language description of downstream usage (analytics,
+            filtering, enrichment, feature engineering, etc.). Guides field
+            relevance & exclusion of outcome labels.
+    """
+
+    examples: List[str] = Field(
+        description=(
+            "Representative sample texts (strings). Provide only data the schema should generalize over; "
+            "exclude outliers not in scope."
+        )
+    )
+    purpose: str = Field(
+        description=(
+            "Plain language statement describing the downstream use of the extracted structured data (e.g. "
+            "analytics, filtering, enrichment)."
+        )
+    )
+
+
+_INFER_INSTRUCTIONS = """
+You are a schema inference engine.
+
+Task:
+1. Normalize the user's purpose (eliminate ambiguity, redundancy, contradictions).
+2. Objectively summarize observable patterns in the example texts.
+3. Propose a minimal flat set of scalar fields (no nesting / arrays) that are reliably extractable.
+4. Skip fields likely missing in a large share (>~20%) of realistic inputs.
+5. Provide enum_values ONLY when a small stable closed categorical set (2–24 lowercase tokens)
+   is clearly evidenced; never invent.
+6. If the purpose indicates prediction (predict / probability / likelihood), output only
+   explanatory features (no target restatement).
+
+Rules:
+- Names: lower snake_case, unique, regex ^[a-z][a-z0-9_]*$, no subjective adjectives.
+- Types: string | integer | float | boolean
+  * integer = all whole numbers
+  * float = any decimals / ratios
+  * boolean = explicit binary
+  * else use string
+- No arrays, objects, composite encodings, or merged multi-concept fields.
+- Descriptions: concise, objective extraction rules (no marketing/emotion/speculation).
+- enum_values only for string fields with stable closed vocab; omit otherwise.
+- Exclude direct outcome labels (e.g. attrition_probability, will_buy, purchase_likelihood)
+  in predictive / feature engineering contexts.
+
+Output contract:
+Return exactly an InferredSchema object with JSON keys:
+- purpose (string)
+- examples_summary (string)
+- fields (array of FieldSpec objects: name, type, description, enum_values?)
+- inference_prompt (string)
+""".strip()
+
+
+@dataclass(frozen=True)
+class SchemaInferer:
+    """High-level orchestrator for schema inference against the Responses API.
+
+    Responsibilities:
+    * Issue a structured parsing request with strict instructions.
+    * Retry (up to ``max_retries``) when the produced field list violates
+      baseline structural rules (duplicate names, unsupported types, etc.).
+    * Return a fully validated ``InferredSchema`` ready for dynamic model
+      generation & downstream batch extraction.
+
+    The inferred schema intentionally avoids JSON Schema intermediates; the
+    authoritative contract is the ordered ``FieldSpec`` list.
+
+    Attributes:
+        client: OpenAI client for calling ``responses.parse``.
+        model_name: Model / deployment identifier.
+    """
+
+    client: OpenAI
+    model_name: str
+
+    def infer_schema(self, data: "SchemaInferenceInput", *args, max_retries: int = 3, **kwargs) -> "InferredSchema":
+        """Infer a validated schema from representative examples.
+
+        Workflow:
+        1. Submit ``SchemaInferenceInput`` (JSON) + instructions via
+           ``responses.parse`` requesting an ``InferredSchema`` object.
+        2. Validate the returned field list with ``_basic_field_list_validation``.
+        3. Retry (up to ``max_retries``) if validation fails.
+
+        Args:
+            data (SchemaInferenceInput): Representative examples + purpose.
+            *args: Positional passthrough to ``client.responses.parse``.
+            max_retries (int, optional): Attempts before surfacing the last validation error
+                (must be >= 1). Defaults to 3.
+            **kwargs: Keyword passthrough to ``client.responses.parse``.
+
+        Returns:
+            InferredSchema: Fully validated schema (purpose, examples summary,
+                ordered fields, extraction prompt).
+
+        Raises:
+            ValueError: Validation still fails after exhausting retries.
+        """
+        if max_retries < 1:
+            raise ValueError("max_retries must be >= 1")
+
+        last_err: Exception | None = None
+        for attempt in range(max_retries):
+            response: ParsedResponse[InferredSchema] = self.client.responses.parse(
+                model=self.model_name,
+                instructions=_INFER_INSTRUCTIONS,
+                input=data.model_dump_json(),
+                text_format=InferredSchema,
+                *args,
+                **kwargs,
+            )
+            parsed = response.output_parsed
+            try:
+                _basic_field_list_validation(parsed)
+            except ValueError as e:
+                last_err = e
+                if attempt == max_retries - 1:
+                    raise
+                continue
+            return parsed
+        if last_err:  # pragma: no cover
+            raise last_err
+        raise RuntimeError("unreachable retry loop state")  # pragma: no cover
+
+
+def _basic_field_list_validation(parsed: InferredSchema) -> None:
+    """Lightweight structural validation of an inferred field list.
+
+    Checks:
+    * Non-empty field set.
+    * No duplicate names.
+    * All types in the allowed primitive set.
+    * ``enum_values`` only on string fields and size within bounds (2–24).
+
+    Args:
+        parsed (InferredSchema): Candidate ``InferredSchema`` instance.
+
+    Raises:
+        ValueError: Any invariant is violated.
+    """
+    names = [f.name for f in parsed.fields]
+    if not names:
+        raise ValueError("no fields suggested")
+    if len(names) != len(set(names)):
+        raise ValueError("duplicate field names detected")
+    allowed = {"string", "integer", "float", "boolean"}
+    for f in parsed.fields:
+        if f.type not in allowed:
+            raise ValueError(f"unsupported field type: {f.type}")
+        if f.enum_values is not None:
+            if f.type != "string":
+                raise ValueError(f"enum_values only allowed for string field: {f.name}")
+            if not (2 <= len(f.enum_values) <= 24):
+                raise ValueError(f"enum_values length out of bounds for field {f.name}")