openaivec 0.10.0__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- openaivec/__init__.py +13 -4
- openaivec/_cache/__init__.py +12 -0
- openaivec/_cache/optimize.py +109 -0
- openaivec/_cache/proxy.py +806 -0
- openaivec/_di.py +326 -0
- openaivec/_embeddings.py +203 -0
- openaivec/{log.py → _log.py} +2 -2
- openaivec/_model.py +113 -0
- openaivec/{prompt.py → _prompt.py} +95 -28
- openaivec/_provider.py +207 -0
- openaivec/_responses.py +511 -0
- openaivec/_schema/__init__.py +9 -0
- openaivec/_schema/infer.py +340 -0
- openaivec/_schema/spec.py +350 -0
- openaivec/_serialize.py +234 -0
- openaivec/{util.py → _util.py} +25 -85
- openaivec/pandas_ext.py +1635 -425
- openaivec/spark.py +604 -335
- openaivec/task/__init__.py +27 -29
- openaivec/task/customer_support/__init__.py +9 -15
- openaivec/task/customer_support/customer_sentiment.py +51 -41
- openaivec/task/customer_support/inquiry_classification.py +86 -61
- openaivec/task/customer_support/inquiry_summary.py +44 -45
- openaivec/task/customer_support/intent_analysis.py +56 -41
- openaivec/task/customer_support/response_suggestion.py +49 -43
- openaivec/task/customer_support/urgency_analysis.py +76 -71
- openaivec/task/nlp/__init__.py +4 -4
- openaivec/task/nlp/dependency_parsing.py +19 -20
- openaivec/task/nlp/keyword_extraction.py +22 -24
- openaivec/task/nlp/morphological_analysis.py +25 -25
- openaivec/task/nlp/named_entity_recognition.py +26 -28
- openaivec/task/nlp/sentiment_analysis.py +29 -21
- openaivec/task/nlp/translation.py +24 -30
- openaivec/task/table/__init__.py +3 -0
- openaivec/task/table/fillna.py +183 -0
- openaivec-1.0.10.dist-info/METADATA +399 -0
- openaivec-1.0.10.dist-info/RECORD +39 -0
- {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
- openaivec/embeddings.py +0 -172
- openaivec/responses.py +0 -392
- openaivec/serialize.py +0 -225
- openaivec/task/model.py +0 -84
- openaivec-0.10.0.dist-info/METADATA +0 -546
- openaivec-0.10.0.dist-info/RECORD +0 -29
- {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
openaivec/_serialize.py
ADDED
@@ -0,0 +1,234 @@

```python
"""Refactored serialization utilities for Pydantic BaseModel classes.

This module provides utilities for converting Pydantic BaseModel classes
to and from JSON schema representations with simplified, maintainable code.
"""

from typing import Any, Literal

from pydantic import BaseModel, Field, create_model

__all__ = []


def serialize_base_model(obj: type[BaseModel]) -> dict[str, Any]:
    """Serialize a Pydantic BaseModel to JSON schema."""
    return obj.model_json_schema()


def dereference_json_schema(json_schema: dict[str, Any]) -> dict[str, Any]:
    """Dereference JSON schema by resolving $ref pointers with circular reference protection."""
    model_map = json_schema.get("$defs", {})

    def dereference(obj, current_path=None):
        if current_path is None:
            current_path = []

        if isinstance(obj, dict):
            if "$ref" in obj:
                ref = obj["$ref"].split("/")[-1]

                # Check for circular reference
                if ref in current_path:
                    # Return a placeholder to break the cycle
                    return {"type": "object", "description": f"Circular reference to {ref}"}

                if ref in model_map:
                    # Add to path and recurse
                    new_path = current_path + [ref]
                    return dereference(model_map[ref], new_path)
                else:
                    # Invalid reference, return placeholder
                    return {"type": "object", "description": f"Invalid reference to {ref}"}
            else:
                return {k: dereference(v, current_path) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [dereference(x, current_path) for x in obj]
        else:
            return obj

    result = {}
    for k, v in json_schema.items():
        if k == "$defs":
            continue
        result[k] = dereference(v)

    return result


# ============================================================================
# Type Resolution - Separated into focused functions
# ============================================================================


def _resolve_union_type(union_options: list[dict[str, Any]]) -> type:
    """Resolve anyOf/oneOf to Union type."""
    union_types = []
    for option in union_options:
        if option.get("type") == "null":
            union_types.append(type(None))
        else:
            union_types.append(parse_field(option))

    if len(union_types) == 1:
        return union_types[0]
    elif len(union_types) == 2 and type(None) in union_types:
        # Optional type: T | None
        non_none_type = next(t for t in union_types if t is not type(None))
        return non_none_type | None  # type: ignore[return-value]
    else:
        from typing import Union

        return Union[tuple(union_types)]  # type: ignore[return-value]


def _resolve_basic_type(type_name: str, field_def: dict[str, Any]) -> type:
    """Resolve basic JSON schema types to Python types."""
    type_mapping = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "null": type(None),
    }

    if type_name in type_mapping:
        return type_mapping[type_name]  # type: ignore[return-value]
    elif type_name == "object":
        # Check if it's a nested model or generic dict
        if "properties" in field_def:
            return deserialize_base_model(field_def)
        else:
            return dict
    elif type_name == "array":
        if "items" in field_def:
            inner_type = parse_field(field_def["items"])
            return list[inner_type]
        else:
            return list[Any]
    else:
        raise ValueError(f"Unsupported type: {type_name}")


def parse_field(field_def: dict[str, Any]) -> type:
    """Parse a JSON schema field definition to a Python type.

    Simplified version with clear separation of concerns.
    """
    # Handle union types
    if "anyOf" in field_def:
        return _resolve_union_type(field_def["anyOf"])
    if "oneOf" in field_def:
        return _resolve_union_type(field_def["oneOf"])

    # Handle basic types
    if "type" not in field_def:
        return Any  # type: ignore[return-value]

    return _resolve_basic_type(field_def["type"], field_def)


# ============================================================================
# Field Information Creation - Centralized logic
# ============================================================================


def _create_field_info(description: str | None, default_value: Any, is_required: bool) -> Field:  # type: ignore[type-arg]
    """Create Field info with consistent logic."""
    if is_required and default_value is None:
        # Required field without default
        return Field(description=description) if description else Field()
    else:
        # Optional field or field with default
        return Field(default=default_value, description=description) if description else Field(default=default_value)


def _make_optional_if_needed(field_type: type, is_required: bool, has_default: bool) -> type:
    """Make field type optional if needed."""
    if is_required or has_default:
        return field_type

    # Check if already nullable
    from typing import Union

    if hasattr(field_type, "__origin__") and field_type.__origin__ is Union and type(None) in field_type.__args__:
        return field_type

    # Make optional
    return field_type | None  # type: ignore[return-value]


# ============================================================================
# Field Processing - Separated enum and regular field logic
# ============================================================================


def _process_enum_field(field_name: str, field_def: dict[str, Any], is_required: bool) -> tuple[type, Field]:  # type: ignore[type-arg]
    """Process enum field with Literal type."""
    enum_values = field_def["enum"]

    # Create Literal type
    if len(enum_values) == 1:
        literal_type = Literal[enum_values[0]]
    else:
        literal_type = Literal[tuple(enum_values)]

    # Handle optionality
    description = field_def.get("description")
    default_value = field_def.get("default")
    has_default = default_value is not None

    if not is_required and not has_default:
        literal_type = literal_type | None  # type: ignore[assignment]
        default_value = None

    field_info = _create_field_info(description, default_value, is_required)
    return literal_type, field_info  # type: ignore[return-value]


def _process_regular_field(field_name: str, field_def: dict[str, Any], is_required: bool) -> tuple[type, Field]:  # type: ignore[type-arg]
    """Process regular (non-enum) field."""
    field_type = parse_field(field_def)
    description = field_def.get("description")
    default_value = field_def.get("default")
    has_default = default_value is not None

    # Handle optionality
    field_type = _make_optional_if_needed(field_type, is_required, has_default)

    if not is_required and not has_default:
        default_value = None

    field_info = _create_field_info(description, default_value, is_required)
    return field_type, field_info


# ============================================================================
# Main Schema Processing - Clean and focused
# ============================================================================


def deserialize_base_model(json_schema: dict[str, Any]) -> type[BaseModel]:
    """Deserialize a JSON schema to a Pydantic BaseModel class.

    Refactored version with clear separation of concerns and simplified logic.
    """
    # Basic setup
    title = json_schema.get("title", "DynamicModel")
    dereferenced_schema = dereference_json_schema(json_schema)
    properties = dereferenced_schema.get("properties", {})
    required_fields = set(dereferenced_schema.get("required", []))

    # Process each field
    fields = {}
    for field_name, field_def in properties.items():
        is_required = field_name in required_fields

        if "enum" in field_def:
            field_type, field_info = _process_enum_field(field_name, field_def, is_required)
        else:
            field_type, field_info = _process_regular_field(field_name, field_def, is_required)

        fields[field_name] = (field_type, field_info)

    return create_model(title, **fields)
```
openaivec/{util.py → _util.py}
RENAMED

```diff
@@ -2,12 +2,14 @@ import asyncio
 import functools
 import re
 import time
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass
-from typing import
+from typing import TypeVar
 
 import numpy as np
 import tiktoken
 
+__all__ = []
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -34,24 +36,28 @@ def get_exponential_with_cutoff(scale: float) -> float:
     return v
 
 
-def backoff(
+def backoff(
+    exceptions: list[type[Exception]],
+    scale: int | None = None,
+    max_retries: int | None = None,
+) -> Callable[..., V]:
     """Decorator implementing exponential back‑off retry logic.
 
     Args:
-
+        exceptions (list[type[Exception]]): List of exception types that trigger a retry.
         scale (int | None): Initial scale parameter for the exponential jitter.
             This scale is used as the mean for the first delay's exponential
             distribution and doubles with each subsequent retry. If ``None``,
            an initial scale of 1.0 is used.
-        max_retries (
+        max_retries (int | None): Maximum number of retries. ``None`` means
            retry indefinitely.
 
     Returns:
         Callable[..., V]: A decorated function that retries on the specified
-
+            exceptions with exponential back‑off.
 
     Raises:
-
+        Exception: Re‑raised when the maximum number of retries is exceeded.
     """
 
     def decorator(func: Callable[..., V]) -> Callable[..., V]:
@@ -65,7 +71,7 @@ def backoff(exception: type[Exception], scale: int | None = None, max_retries: i
             while True:
                 try:
                     return func(*args, **kwargs)
-                except
+                except tuple(exceptions):
                     attempt += 1
                     if max_retries is not None and attempt >= max_retries:
                         raise
@@ -79,16 +85,18 @@ def backoff(exception: type[Exception], scale: int | None = None, max_retries: i
 
         return wrapper
 
-    return decorator
+    return decorator  # type: ignore[return-value]
 
 
 def backoff_async(
-
+    exceptions: list[type[Exception]],
+    scale: int | None = None,
+    max_retries: int | None = None,
 ) -> Callable[..., Awaitable[V]]:
     """Asynchronous version of the backoff decorator.
 
     Args:
-
+        exceptions (list[type[Exception]]): List of exception types that trigger a retry.
         scale (int | None): Initial scale parameter for the exponential jitter.
             This scale is used as the mean for the first delay's exponential
             distribution and doubles with each subsequent retry. If ``None``,
@@ -98,10 +106,10 @@ def backoff_async(
 
     Returns:
         Callable[..., Awaitable[V]]: A decorated asynchronous function that
-            retries on the specified
+            retries on the specified exceptions with exponential back‑off.
 
     Raises:
-
+        Exception: Re‑raised when the maximum number of retries is exceeded.
     """
 
     def decorator(func: Callable[..., Awaitable[V]]) -> Callable[..., Awaitable[V]]:
@@ -115,7 +123,7 @@ def backoff_async(
             while True:
                 try:
                     return await func(*args, **kwargs)
-                except
+                except tuple(exceptions):
                     attempt += 1
                     if max_retries is not None and attempt >= max_retries:
                         raise
@@ -129,7 +137,7 @@ def backoff_async(
 
         return wrapper
 
-    return decorator
+    return decorator  # type: ignore[return-value]
 
 
 @dataclass(frozen=True)
@@ -138,7 +146,7 @@ class TextChunker:
 
     enc: tiktoken.Encoding
 
-    def split(self, original: str, max_tokens: int, sep:
+    def split(self, original: str, max_tokens: int, sep: list[str]) -> list[str]:
         """Token‑aware sentence segmentation.
 
         The text is first split by the given separators, then greedily packed
@@ -147,11 +155,11 @@ class TextChunker:
         Args:
            original (str): Original text to split.
            max_tokens (int): Maximum number of tokens allowed per chunk.
-            sep (
+            sep (list[str]): List of separator patterns used by
                :pyfunc:`re.split`.
 
         Returns:
-
+            list[str]: List of text chunks respecting the ``max_tokens`` limit.
         """
         sentences = re.split(f"({'|'.join(sep)})", original)
         sentences = [s.strip() for s in sentences if s.strip()]
@@ -174,71 +182,3 @@ class TextChunker:
            chunks.append(sentence)
 
         return chunks
-
-
-async def map_async(inputs: List[T], f: Callable[[List[T]], Awaitable[List[U]]], batch_size: int = 128) -> List[U]:
-    """Asynchronously map a function `f` over a list of inputs in batches.
-
-    This function divides the input list into smaller batches and applies the
-    asynchronous function `f` to each batch concurrently. It gathers the results
-    and returns them in the same order as the original inputs.
-
-    Args:
-        inputs (List[T]): List of inputs to be processed.
-        f (Callable[[List[T]], Awaitable[List[U]]]): Asynchronous function to apply.
-            It takes a batch of inputs (List[T]) and must return a list of
-            corresponding outputs (List[U]) of the same size.
-        batch_size (int): Size of each batch for processing.
-
-    Returns:
-        List[U]: List of outputs corresponding to the original inputs, in order.
-    """
-    original_hashes: List[int] = [hash(str(v)) for v in inputs]  # Use str(v) for hash if T is not hashable
-    hash_inputs: Dict[int, T] = {k: v for k, v in zip(original_hashes, inputs)}
-    unique_hashes: List[int] = list(hash_inputs.keys())
-    unique_inputs: List[T] = list(hash_inputs.values())
-    input_batches: List[List[T]] = [unique_inputs[i : i + batch_size] for i in range(0, len(unique_inputs), batch_size)]
-    # Ensure f is awaited correctly within gather
-    tasks = [f(batch) for batch in input_batches]
-    output_batches: List[List[U]] = await asyncio.gather(*tasks)
-    unique_outputs: List[U] = [u for batch in output_batches for u in batch]
-    if len(unique_hashes) != len(unique_outputs):
-        raise ValueError(
-            f"Number of unique inputs ({len(unique_hashes)}) does not match number of unique outputs ({len(unique_outputs)}). Check the function f."
-        )
-    hash_outputs: Dict[int, U] = {k: v for k, v in zip(unique_hashes, unique_outputs)}
-    outputs: List[U] = [hash_outputs[k] for k in original_hashes]
-    return outputs
-
-
-def map(inputs: List[T], f: Callable[[List[T]], List[U]], batch_size: int = 128) -> List[U]:
-    """Map a function `f` over a list of inputs in batches.
-
-    This function divides the input list into smaller batches and applies the
-    function `f` to each batch. It gathers the results and returns them in the
-    same order as the original inputs.
-
-    Args:
-        inputs (List[T]): List of inputs to be processed.
-        f (Callable[[List[T]], List[U]]): Function to apply. It takes a batch of
-            inputs (List[T]) and must return a list of corresponding outputs
-            (List[U]) of the same size.
-        batch_size (int): Size of each batch for processing.
-
-    Returns:
-        List[U]: List of outputs corresponding to the original inputs, in order.
-    """
-    original_hashes: List[int] = [hash(str(v)) for v in inputs]  # Use str(v) for hash if T is not hashable
-    hash_inputs: Dict[int, T] = {k: v for k, v in zip(original_hashes, inputs)}
-    unique_hashes: List[int] = list(hash_inputs.keys())
-    unique_inputs: List[T] = list(hash_inputs.values())
-    input_batches: List[List[T]] = [unique_inputs[i : i + batch_size] for i in range(0, len(unique_inputs), batch_size)]
-    output_batches: List[List[U]] = [f(batch) for batch in input_batches]
-    unique_outputs: List[U] = [u for batch in output_batches for u in batch]
-    if len(unique_hashes) != len(unique_outputs):
-        raise ValueError(
-            f"Number of unique inputs ({len(unique_hashes)}) does not match number of unique outputs ({len(unique_outputs)}). Check the function f."
-        )
-    hash_outputs: Dict[int, U] = {k: v for k, v in zip(unique_hashes, unique_outputs)}
-    outputs: List[U] = [hash_outputs[k] for k in original_hashes]
-    return outputs
```