openaivec 0.12.5__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. openaivec/__init__.py +13 -4
  2. openaivec/_cache/__init__.py +12 -0
  3. openaivec/_cache/optimize.py +109 -0
  4. openaivec/_cache/proxy.py +806 -0
  5. openaivec/{di.py → _di.py} +36 -12
  6. openaivec/_embeddings.py +203 -0
  7. openaivec/{log.py → _log.py} +2 -2
  8. openaivec/_model.py +113 -0
  9. openaivec/{prompt.py → _prompt.py} +95 -28
  10. openaivec/_provider.py +207 -0
  11. openaivec/_responses.py +511 -0
  12. openaivec/_schema/__init__.py +9 -0
  13. openaivec/_schema/infer.py +340 -0
  14. openaivec/_schema/spec.py +350 -0
  15. openaivec/_serialize.py +234 -0
  16. openaivec/{util.py → _util.py} +25 -85
  17. openaivec/pandas_ext.py +1496 -318
  18. openaivec/spark.py +485 -183
  19. openaivec/task/__init__.py +9 -7
  20. openaivec/task/customer_support/__init__.py +9 -15
  21. openaivec/task/customer_support/customer_sentiment.py +17 -15
  22. openaivec/task/customer_support/inquiry_classification.py +23 -22
  23. openaivec/task/customer_support/inquiry_summary.py +14 -13
  24. openaivec/task/customer_support/intent_analysis.py +21 -19
  25. openaivec/task/customer_support/response_suggestion.py +16 -16
  26. openaivec/task/customer_support/urgency_analysis.py +24 -25
  27. openaivec/task/nlp/__init__.py +4 -4
  28. openaivec/task/nlp/dependency_parsing.py +10 -12
  29. openaivec/task/nlp/keyword_extraction.py +11 -14
  30. openaivec/task/nlp/morphological_analysis.py +12 -14
  31. openaivec/task/nlp/named_entity_recognition.py +16 -18
  32. openaivec/task/nlp/sentiment_analysis.py +14 -11
  33. openaivec/task/nlp/translation.py +6 -9
  34. openaivec/task/table/__init__.py +2 -2
  35. openaivec/task/table/fillna.py +11 -11
  36. openaivec-1.0.10.dist-info/METADATA +399 -0
  37. openaivec-1.0.10.dist-info/RECORD +39 -0
  38. {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
  39. openaivec/embeddings.py +0 -172
  40. openaivec/model.py +0 -67
  41. openaivec/provider.py +0 -45
  42. openaivec/responses.py +0 -393
  43. openaivec/serialize.py +0 -225
  44. openaivec-0.12.5.dist-info/METADATA +0 -696
  45. openaivec-0.12.5.dist-info/RECORD +0 -33
  46. {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
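The rename pattern in the list above (for example `di.py → _di.py`, `prompt.py → _prompt.py`, `util.py → _util.py`) suggests that 1.0.10 makes these modules private and expects callers to go through the package root or the public `pandas_ext` / `spark` entry points. A minimal, hypothetical compatibility sketch follows; whether `FewShotPromptBuilder` is actually re-exported from `openaivec` in 1.0.10 is an assumption, not something this diff confirms:

```python
# Hypothetical import shim for code written against 0.12.5, assuming the symbols
# that lived in the now-private modules are re-exported from the package root in
# 1.0.10 (not confirmed by this diff).
try:
    from openaivec import FewShotPromptBuilder  # assumed 1.0.10-style import
except ImportError:
    from openaivec.prompt import FewShotPromptBuilder  # 0.12.5 module, now openaivec/_prompt.py
```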
openaivec/provider.py DELETED
@@ -1,45 +0,0 @@
- import os
-
- from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
-
-
- def provide_openai_client() -> OpenAI:
-     """Provide OpenAI client based on environment variables. Prioritizes OpenAI over Azure."""
-     if os.getenv("OPENAI_API_KEY"):
-         return OpenAI()
-
-     if all(
-         os.getenv(name) for name in ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_API_ENDPOINT", "AZURE_OPENAI_API_VERSION"]
-     ):
-         return AzureOpenAI(
-             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-             azure_endpoint=os.getenv("AZURE_OPENAI_API_ENDPOINT"),
-             api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
-         )
-
-     raise ValueError(
-         "No valid OpenAI or Azure OpenAI environment variables found. "
-         "Please set either OPENAI_API_KEY or AZURE_OPENAI_API_KEY, "
-         "AZURE_OPENAI_API_ENDPOINT, and AZURE_OPENAI_API_VERSION."
-     )
-
-
- def provide_async_openai_client() -> AsyncOpenAI:
-     """Provide async OpenAI client based on environment variables. Prioritizes OpenAI over Azure."""
-     if os.getenv("OPENAI_API_KEY"):
-         return AsyncOpenAI()
-
-     if all(
-         os.getenv(name) for name in ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_API_ENDPOINT", "AZURE_OPENAI_API_VERSION"]
-     ):
-         return AsyncAzureOpenAI(
-             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-             azure_endpoint=os.getenv("AZURE_OPENAI_API_ENDPOINT"),
-             api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
-         )
-
-     raise ValueError(
-         "No valid OpenAI or Azure OpenAI environment variables found. "
-         "Please set either OPENAI_API_KEY or AZURE_OPENAI_API_KEY, "
-         "AZURE_OPENAI_API_ENDPOINT, and AZURE_OPENAI_API_VERSION."
-     )
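The deleted helpers above resolve a client purely from environment variables, preferring `OPENAI_API_KEY` over the Azure triple and raising `ValueError` when neither is configured. A short sketch of how they were typically called against 0.12.5 (the key value is a placeholder):

```python
import os

from openaivec.provider import provide_async_openai_client, provide_openai_client

# Placeholder credential; real values come from your environment or secret store.
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

client = provide_openai_client()              # returns openai.OpenAI because OPENAI_API_KEY is set
async_client = provide_async_openai_client()  # returns openai.AsyncOpenAI

# With only AZURE_OPENAI_API_KEY / AZURE_OPENAI_API_ENDPOINT / AZURE_OPENAI_API_VERSION set,
# the same calls return AzureOpenAI / AsyncAzureOpenAI; with no configuration they raise ValueError.
```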
openaivec/responses.py DELETED
@@ -1,393 +0,0 @@
- import asyncio
- from dataclasses import dataclass, field
- from logging import Logger, getLogger
- from typing import Generic, List, Type, cast
-
- from openai import AsyncOpenAI, OpenAI, RateLimitError
- from openai.types.responses import ParsedResponse
- from pydantic import BaseModel
-
- from .log import observe
- from .model import PreparedTask, ResponseFormat
- from .util import backoff, backoff_async, map, map_async
-
- __all__ = [
-     "BatchResponses",
-     "AsyncBatchResponses",
- ]
-
- _LOGGER: Logger = getLogger(__name__)
-
-
- def _vectorize_system_message(system_message: str) -> str:
-     """Return the system prompt that instructs the model to work on a batch.
-
-     The returned XML‐ish prompt explains two things to the LLM:
-
-     1. The *general* system instruction coming from the caller (`system_message`)
-        is preserved verbatim.
-     2. Extra instructions describe how the model should treat the incoming JSON
-        that contains multiple user messages and how it must shape its output.
-
-     Args:
-         system_message (str): A single‑instance system instruction the caller would
-             normally send to the model.
-
-     Returns:
-         A long, composite system prompt with embedded examples that can be
-         supplied to the `instructions=` field of the OpenAI **JSON mode**
-         endpoint.
-     """
-     return f"""
- <SystemMessage>
-     <ElementInstructions>
-         <Instruction>{system_message}</Instruction>
-     </ElementInstructions>
-     <BatchInstructions>
-         <Instruction>
-             You will receive multiple user messages at once.
-             Please provide an appropriate response to each message individually.
-         </Instruction>
-     </BatchInstructions>
-     <Examples>
-         <Example>
-             <Input>
-                 {{
-                     "user_messages": [
-                         {{
-                             "id": 1,
-                             "body": "{{user_message_1}}"
-                         }},
-                         {{
-                             "id": 2,
-                             "body": "{{user_message_2}}"
-                         }}
-                     ]
-                 }}
-             </Input>
-             <Output>
-                 {{
-                     "assistant_messages": [
-                         {{
-                             "id": 1,
-                             "body": "{{assistant_response_1}}"
-                         }},
-                         {{
-                             "id": 2,
-                             "body": "{{assistant_response_2}}"
-                         }}
-                     ]
-                 }}
-             </Output>
-         </Example>
-     </Examples>
- </SystemMessage>
- """
-
-
-
-
- class Message(BaseModel, Generic[ResponseFormat]):
-     id: int
-     body: ResponseFormat
-
-
- class Request(BaseModel):
-     user_messages: List[Message[str]]
-
-
- class Response(BaseModel, Generic[ResponseFormat]):
-     assistant_messages: List[Message[ResponseFormat]]
-
-
- @dataclass(frozen=True)
- class BatchResponses(Generic[ResponseFormat]):
-     """Stateless façade that turns OpenAI's JSON‑mode API into a batched API.
-
-     This wrapper allows you to submit *multiple* user prompts in one JSON‑mode
-     request and receive the answers in the original order.
-
-     Example:
-         ```python
-         vector_llm = BatchResponses(
-             client=openai_client,
-             model_name="gpt‑4o‑mini",
-             system_message="You are a helpful assistant."
-         )
-         answers = vector_llm.parse(questions, batch_size=32)
-         ```
-
-     Attributes:
-         client: Initialised ``openai.OpenAI`` client.
-         model_name: Name of the model (or Azure deployment) to invoke.
-         system_message: System prompt prepended to every request.
-         temperature: Sampling temperature passed to the model.
-         top_p: Nucleus‑sampling parameter.
-         response_format: Expected Pydantic BaseModel subclass or str type for each assistant message
-             (defaults to ``str``).
-
-     Notes:
-         Internally the work is delegated to two helpers:
-
-         * ``_predict_chunk`` – fragments the workload and restores ordering.
-         * ``_request_llm`` – performs a single OpenAI API call.
-     """
-
-     client: OpenAI
-     model_name: str  # it would be the name of deployment for Azure
-     system_message: str
-     temperature: float = 0.0
-     top_p: float = 1.0
-     response_format: Type[ResponseFormat] = str
-     _vectorized_system_message: str = field(init=False)
-     _model_json_schema: dict = field(init=False)
-
-     @classmethod
-     def of_task(cls, client: OpenAI, model_name: str, task: PreparedTask) -> "BatchResponses":
-         """Create a BatchResponses instance from a PreparedTask."""
-         return cls(
-             client=client,
-             model_name=model_name,
-             system_message=task.instructions,
-             temperature=task.temperature,
-             top_p=task.top_p,
-             response_format=task.response_format,
-         )
-
-     def __post_init__(self):
-         object.__setattr__(
-             self,
-             "_vectorized_system_message",
-             _vectorize_system_message(self.system_message),
-         )
-
-     @observe(_LOGGER)
-     @backoff(exception=RateLimitError, scale=15, max_retries=8)
-     def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
-         """Make a single call to the OpenAI *JSON mode* endpoint.
-
-         Args:
-             user_messages (List[Message[str]]): Sequence of `Message[str]` objects representing the
-                 prompts for this minibatch. Each message carries a unique `id`
-                 so we can restore ordering later.
-
-         Returns:
-             ParsedResponse containing `Response[ResponseFormat]` which in turn holds the
-             assistant messages in arbitrary order.
-
-         Raises:
-             openai.RateLimitError: Transparently re‑raised after the
-                 exponential back‑off decorator exhausts all retries.
-         """
-         response_format = self.response_format
-
-         class MessageT(BaseModel):
-             id: int
-             body: response_format  # type: ignore
-
-         class ResponseT(BaseModel):
-             assistant_messages: List[MessageT]
-
-         completion: ParsedResponse[ResponseT] = self.client.responses.parse(
-             model=self.model_name,
-             instructions=self._vectorized_system_message,
-             input=Request(user_messages=user_messages).model_dump_json(),
-             temperature=self.temperature,
-             top_p=self.top_p,
-             text_format=ResponseT,
-         )
-         return cast(ParsedResponse[Response[ResponseFormat]], completion)
-
-     @observe(_LOGGER)
-     def _predict_chunk(self, user_messages: List[str]) -> List[ResponseFormat]:
-         """Helper executed for every unique minibatch.
-
-         This method:
-         1. Converts plain strings into `Message[str]` with stable indices.
-         2. Delegates the request to `_request_llm`.
-         3. Reorders the responses so they match the original indices.
-
-         The function is *pure* – it has no side‑effects and the result depends
-         only on its arguments – which allows it to be used safely in both
-         serial and parallel execution paths.
-         """
-         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-         responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages)
-         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
-         sorted_responses = [response_dict.get(m.id, None) for m in messages]
-         return sorted_responses
-
-     @observe(_LOGGER)
-     def parse(self, inputs: List[str], batch_size: int) -> List[ResponseFormat]:
-         """Public API: batched predict.
-
-         Args:
-             inputs (List[str]): All prompts that require a response. Duplicate
-                 entries are de‑duplicated under the hood to save tokens.
-             batch_size (int): Maximum number of *unique* prompts per LLM call.
-
-         Returns:
-             A list containing the assistant responses in the same order as
-             *inputs*.
-         """
-         return map(inputs, self._predict_chunk, batch_size)
-
-
- @dataclass(frozen=True)
- class AsyncBatchResponses(Generic[ResponseFormat]):
-     """Stateless façade that turns OpenAI's JSON-mode API into a batched API (Async version).
-
-     This wrapper allows you to submit *multiple* user prompts in one JSON-mode
-     request and receive the answers in the original order asynchronously. It also
-     controls the maximum number of concurrent requests to the OpenAI API.
-
-     Example:
-         ```python
-         import asyncio
-         from openai import AsyncOpenAI
-         from openaivec.aio.responses import AsyncBatchResponses
-
-         # Assuming openai_async_client is an initialized AsyncOpenAI client
-         openai_async_client = AsyncOpenAI()  # Replace with your actual client initialization
-
-         vector_llm = AsyncBatchResponses(
-             client=openai_async_client,
-             model_name="gpt-4.1-mini",
-             system_message="You are a helpful assistant.",
-             max_concurrency=5  # Limit concurrent requests
-         )
-         questions = ["What is the capital of France?", "Explain quantum physics simply."]
-         # Asynchronous call
-         async def main():
-             answers = await vector_llm.parse(questions, batch_size=32)
-             print(answers)
-
-         # Run the async function
-         asyncio.run(main())
-         ```
-
-     Attributes:
-         client: Initialised `openai.AsyncOpenAI` client.
-         model_name: Name of the model (or Azure deployment) to invoke.
-         system_message: System prompt prepended to every request.
-         temperature: Sampling temperature passed to the model.
-         top_p: Nucleus-sampling parameter.
-         response_format: Expected Pydantic BaseModel subclass or str type for each assistant message
-             (defaults to `str`).
-         max_concurrency: Maximum number of concurrent requests to the OpenAI API.
-     """
-
-     client: AsyncOpenAI
-     model_name: str  # it would be the name of deployment for Azure
-     system_message: str
-     temperature: float = 0.0
-     top_p: float = 1.0
-     response_format: Type[ResponseFormat] = str
-     max_concurrency: int = 8  # Default concurrency limit
-     _vectorized_system_message: str = field(init=False)
-     _model_json_schema: dict = field(init=False)
-     _semaphore: asyncio.Semaphore = field(init=False, repr=False)
-
-     @classmethod
-     def of_task(
-         cls, client: AsyncOpenAI, model_name: str, task: PreparedTask, max_concurrency: int = 8
-     ) -> "AsyncBatchResponses":
-         """Create an AsyncBatchResponses instance from a PreparedTask."""
-         return cls(
-             client=client,
-             model_name=model_name,
-             system_message=task.instructions,
-             temperature=task.temperature,
-             top_p=task.top_p,
-             response_format=task.response_format,
-             max_concurrency=max_concurrency,
-         )
-
-     def __post_init__(self):
-         object.__setattr__(
-             self,
-             "_vectorized_system_message",
-             _vectorize_system_message(self.system_message),
-         )
-         # Initialize the semaphore after the object is created
-         # Use object.__setattr__ because the dataclass is frozen
-         object.__setattr__(self, "_semaphore", asyncio.Semaphore(self.max_concurrency))
-
-     @observe(_LOGGER)
-     @backoff_async(exception=RateLimitError, scale=15, max_retries=8)
-     async def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
-         """Make a single async call to the OpenAI *JSON mode* endpoint, respecting concurrency limits.
-
-         Args:
-             user_messages (List[Message[str]]): Sequence of `Message[str]` objects representing the
-                 prompts for this minibatch. Each message carries a unique `id`
-                 so we can restore ordering later.
-
-         Returns:
-             ParsedResponse containing `Response[ResponseFormat]` which in turn holds the
-             assistant messages in arbitrary order.
-
-         Raises:
-             openai.RateLimitError: Transparently re-raised after the
-                 exponential back-off decorator exhausts all retries.
-         """
-         response_format = self.response_format
-
-         class MessageT(BaseModel):
-             id: int
-             body: response_format  # type: ignore
-
-         class ResponseT(BaseModel):
-             assistant_messages: List[MessageT]
-
-         # Acquire semaphore before making the API call
-         async with self._semaphore:
-             # Directly await the async call instead of using asyncio.run()
-             completion: ParsedResponse[ResponseT] = await self.client.responses.parse(
-                 model=self.model_name,
-                 instructions=self._vectorized_system_message,
-                 input=Request(user_messages=user_messages).model_dump_json(),
-                 temperature=self.temperature,
-                 top_p=self.top_p,
-                 text_format=ResponseT,
-             )
-             return cast(ParsedResponse[Response[ResponseFormat]], completion)
-
-     @observe(_LOGGER)
-     async def _predict_chunk(self, user_messages: List[str]) -> List[ResponseFormat]:
-         """Helper executed asynchronously for every unique minibatch.
-
-         This method:
-         1. Converts plain strings into `Message[str]` with stable indices.
-         2. Delegates the request to `_request_llm`.
-         3. Reorders the responses so they match the original indices.
-
-         The function is *pure* – it has no side-effects and the result depends
-         only on its arguments.
-         """
-         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-         responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)
-         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
-         # Ensure proper handling for missing IDs - this shouldn't happen in normal operation
-         sorted_responses = [response_dict.get(m.id, None) for m in messages]
-         return sorted_responses
-
-     @observe(_LOGGER)
-     async def parse(self, inputs: List[str], batch_size: int) -> List[ResponseFormat]:
-         """Asynchronous public API: batched predict.
-
-         Args:
-             inputs (List[str]): All prompts that require a response. Duplicate
-                 entries are de-duplicated under the hood to save tokens.
-             batch_size (int): Maximum number of *unique* prompts per LLM call.
-
-         Returns:
-             A list containing the assistant responses in the same order as
-             *inputs*.
-         """
-
-         return await map_async(
-             inputs=inputs,
-             f=self._predict_chunk,
-             batch_size=batch_size,  # Use the batch_size argument passed to the method
-         )
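Beyond the docstring examples embedded above, the core trick of the deleted module is the id-based reordering in `_predict_chunk`: each prompt is wrapped in a `Message` with a stable index, and the parsed assistant messages are mapped back by id so the output order matches the input order. A standalone sketch of that pattern, independent of the OpenAI client; `fake_llm` is a hypothetical stand-in for `_request_llm`, not part of the library:

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class Message(BaseModel):
    id: int
    body: str


def fake_llm(messages: List[Message]) -> List[Message]:
    # Stand-in for _request_llm: answers may come back in any order.
    return sorted(
        [Message(id=m.id, body=m.body.upper()) for m in messages],
        key=lambda m: -m.id,
    )


def predict_chunk(user_messages: List[str]) -> List[Optional[str]]:
    # 1. Wrap plain strings with stable indices.
    messages = [Message(id=i, body=text) for i, text in enumerate(user_messages)]
    # 2. Call the (order-unstable) backend.
    answers = fake_llm(messages)
    # 3. Restore the original order via the id map; missing ids become None.
    by_id: Dict[int, str] = {m.id: m.body for m in answers}
    return [by_id.get(m.id) for m in messages]


print(predict_chunk(["hello", "world"]))  # ['HELLO', 'WORLD']
```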
openaivec/serialize.py DELETED
@@ -1,225 +0,0 @@
- """Serialization utilities for Pydantic BaseModel classes.
-
- This module provides utilities for converting Pydantic BaseModel classes
- to and from JSON schema representations. It supports dynamic model creation
- from JSON schemas with special handling for enum fields, which are converted
- to Literal types for better type safety and compatibility.
-
- Example:
-     Basic serialization and deserialization:
-
-     ```python
-     from pydantic import BaseModel
-     from typing import Literal
-
-     class Status(BaseModel):
-         value: Literal["active", "inactive"]
-         description: str
-
-     # Serialize to JSON schema
-     schema = serialize_base_model(Status)
-
-     # Deserialize back to BaseModel class
-     DynamicStatus = deserialize_base_model(schema)
-     instance = DynamicStatus(value="active", description="User is active")
-     ```
- """
-
- from typing import Any, Dict, List, Type, Literal
-
- from pydantic import BaseModel, Field, create_model
-
- __all__ = ["deserialize_base_model", "serialize_base_model"]
-
-
- def serialize_base_model(obj: Type[BaseModel]) -> Dict[str, Any]:
-     """Serialize a Pydantic BaseModel to JSON schema.
-
-     Args:
-         obj (Type[BaseModel]): The Pydantic BaseModel class to serialize.
-
-     Returns:
-         A dictionary containing the JSON schema representation of the model.
-
-     Example:
-         ```python
-         from pydantic import BaseModel
-
-         class Person(BaseModel):
-             name: str
-             age: int
-
-         schema = serialize_base_model(Person)
-         ```
-     """
-     return obj.model_json_schema()
-
-
- def dereference_json_schema(json_schema: Dict[str, Any]) -> Dict[str, Any]:
-     """Dereference JSON schema by resolving $ref pointers.
-
-     This function resolves all $ref references in a JSON schema by replacing
-     them with the actual referenced definitions from the $defs section.
-
-     Args:
-         json_schema (Dict[str, Any]): The JSON schema containing potential $ref references.
-
-     Returns:
-         A dereferenced JSON schema with all $ref pointers resolved.
-
-     Example:
-         ```python
-         schema = {
-             "properties": {
-                 "user": {"$ref": "#/$defs/User"}
-             },
-             "$defs": {
-                 "User": {"type": "object", "properties": {"name": {"type": "string"}}}
-             }
-         }
-         dereferenced = dereference_json_schema(schema)
-         # user property will contain the actual User definition
-         ```
-     """
-     model_map = json_schema.get("$defs", {})
-
-     def dereference(obj):
-         if isinstance(obj, dict):
-             if "$ref" in obj:
-                 ref = obj["$ref"].split("/")[-1]
-                 return dereference(model_map[ref])
-             else:
-                 return {k: dereference(v) for k, v in obj.items()}
-
-         elif isinstance(obj, list):
-             return [dereference(x) for x in obj]
-         else:
-             return obj
-
-     result = {}
-     for k, v in json_schema.items():
-         if k == "$defs":
-             continue
-
-         result[k] = dereference(v)
-
-     return result
-
-
- def parse_field(v: Dict[str, Any]) -> Any:
-     """Parse a JSON schema field definition to a Python type.
-
-     Converts JSON schema field definitions to corresponding Python types
-     for use in Pydantic model creation.
-
-     Args:
-         v (Dict[str, Any]): A dictionary containing the JSON schema field definition.
-
-     Returns:
-         The corresponding Python type (str, int, float, bool, dict, List, or BaseModel).
-
-     Raises:
-         ValueError: If the field type is not supported.
-
-     Example:
-         ```python
-         field_def = {"type": "string"}
-         python_type = parse_field(field_def)  # Returns str
-
-         array_def = {"type": "array", "items": {"type": "integer"}}
-         python_type = parse_field(array_def)  # Returns List[int]
-         ```
-     """
-     t = v["type"]
-     if t == "string":
-         return str
-     elif t == "integer":
-         return int
-     elif t == "number":
-         return float
-     elif t == "boolean":
-         return bool
-     elif t == "object":
-         # Check if it's a generic object (dict) or a nested model
-         if "properties" in v:
-             return deserialize_base_model(v)
-         else:
-             return dict
-     elif t == "array":
-         inner_type = parse_field(v["items"])
-         return List[inner_type]
-     else:
-         raise ValueError(f"Unsupported type: {t}")
-
-
- def deserialize_base_model(json_schema: Dict[str, Any]) -> Type[BaseModel]:
-     """Deserialize a JSON schema to a Pydantic BaseModel class.
-
-     Creates a dynamic Pydantic BaseModel class from a JSON schema definition.
-     For enum fields, this function uses Literal types instead of Enum classes
-     for better type safety and compatibility with systems like Apache Spark.
-
-     Args:
-         json_schema (Dict[str, Any]): A dictionary containing the JSON schema definition.
-
-     Returns:
-         A dynamically created Pydantic BaseModel class.
-
-     Example:
-         ```python
-         schema = {
-             "title": "Person",
-             "type": "object",
-             "properties": {
-                 "name": {"type": "string", "description": "Person's name"},
-                 "status": {
-                     "type": "string",
-                     "enum": ["active", "inactive"],
-                     "description": "Person's status"
-                 }
-             }
-         }
-
-         PersonModel = deserialize_base_model(schema)
-         person = PersonModel(name="John", status="active")
-         ```
-
-     Note:
-         Enum fields are converted to Literal types for improved compatibility
-         and type safety. This ensures better integration with data processing
-         frameworks like Apache Spark.
-     """
-     fields = {}
-     properties = dereference_json_schema(json_schema).get("properties", {})
-
-     for k, v in properties.items():
-         if "enum" in v:
-             enum_values = v["enum"]
-
-             # Always use Literal instead of Enum for better type safety and Spark compatibility
-             if len(enum_values) == 1:
-                 literal_type = Literal[enum_values[0]]
-             else:
-                 # Create Literal with multiple values
-                 literal_type = Literal[tuple(enum_values)]
-
-             description = v.get("description")
-             default_value = v.get("default")
-
-             if default_value is not None:
-                 field_info = Field(default=default_value, description=description) if description is not None else Field(default=default_value)
-             else:
-                 field_info = Field(description=description) if description is not None else Field()
-
-             fields[k] = (literal_type, field_info)
-         else:
-             description = v.get("description")
-             default_value = v.get("default")
-
-             if default_value is not None:
-                 field_info = Field(default=default_value, description=description) if description is not None else Field(default=default_value)
-             else:
-                 field_info = Field(description=description) if description is not None else Field()
-
-             fields[k] = (parse_field(v), field_info)
-     return create_model(json_schema["title"], **fields)
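The removed module's round trip can be exercised directly against its API: `serialize_base_model` is a thin wrapper over `model_json_schema()`, and `deserialize_base_model` rebuilds a model class, turning enum entries back into Literal-typed fields. A short sketch, assuming openaivec 0.12.5 (where the module still exists) and Pydantic v2 are installed; the `Ticket` model is illustrative only:

```python
from typing import Literal

from pydantic import BaseModel

from openaivec.serialize import deserialize_base_model, serialize_base_model


class Ticket(BaseModel):
    subject: str
    status: Literal["open", "closed"]


schema = serialize_base_model(Ticket)           # equivalent to Ticket.model_json_schema()
DynamicTicket = deserialize_base_model(schema)  # "enum": ["open", "closed"] becomes a Literal field

ticket = DynamicTicket(subject="Printer is offline", status="open")
print(ticket.model_dump())  # {'subject': 'Printer is offline', 'status': 'open'}
```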