openaivec 0.13.3__tar.gz → 0.13.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. openaivec-0.13.4/.github/copilot-instructions.md +199 -0
  2. {openaivec-0.13.3 → openaivec-0.13.4}/PKG-INFO +1 -1
  3. {openaivec-0.13.3 → openaivec-0.13.4}/mkdocs.yml +1 -0
  4. openaivec-0.13.4/src/openaivec/optimize.py +108 -0
  5. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/proxy.py +73 -90
  6. openaivec-0.13.4/tests/test_optimize.py +318 -0
  7. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_proxy.py +4 -6
  8. openaivec-0.13.4/tests/test_proxy_suggester.py +201 -0
  9. {openaivec-0.13.3 → openaivec-0.13.4}/.env.example +0 -0
  10. {openaivec-0.13.3 → openaivec-0.13.4}/.github/workflows/python-mkdocs.yml +0 -0
  11. {openaivec-0.13.3 → openaivec-0.13.4}/.github/workflows/python-package.yml +0 -0
  12. {openaivec-0.13.3 → openaivec-0.13.4}/.github/workflows/python-test.yml +0 -0
  13. {openaivec-0.13.3 → openaivec-0.13.4}/.github/workflows/python-update.yml +0 -0
  14. {openaivec-0.13.3 → openaivec-0.13.4}/.gitignore +0 -0
  15. {openaivec-0.13.3 → openaivec-0.13.4}/CODE_OF_CONDUCT.md +0 -0
  16. {openaivec-0.13.3 → openaivec-0.13.4}/LICENSE +0 -0
  17. {openaivec-0.13.3 → openaivec-0.13.4}/README.md +0 -0
  18. {openaivec-0.13.3 → openaivec-0.13.4}/SECURITY.md +0 -0
  19. {openaivec-0.13.3 → openaivec-0.13.4}/SUPPORT.md +0 -0
  20. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/di.md +0 -0
  21. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/embeddings.md +0 -0
  22. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/pandas_ext.md +0 -0
  23. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/prompt.md +0 -0
  24. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/proxy.md +0 -0
  25. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/responses.md +0 -0
  26. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/spark.md +0 -0
  27. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/task.md +0 -0
  28. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/customer_sentiment.md +0 -0
  29. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/inquiry_classification.md +0 -0
  30. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/inquiry_summary.md +0 -0
  31. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/intent_analysis.md +0 -0
  32. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/response_suggestion.md +0 -0
  33. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/customer_support/urgency_analysis.md +0 -0
  34. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/dependency_parsing.md +0 -0
  35. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/keyword_extraction.md +0 -0
  36. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/morphological_analysis.md +0 -0
  37. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/named_entity_recognition.md +0 -0
  38. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/sentiment_analysis.md +0 -0
  39. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/tasks/nlp/translation.md +0 -0
  40. {openaivec-0.13.3 → openaivec-0.13.4}/docs/api/util.md +0 -0
  41. {openaivec-0.13.3 → openaivec-0.13.4}/docs/index.md +0 -0
  42. {openaivec-0.13.3 → openaivec-0.13.4}/docs/robots.txt +0 -0
  43. {openaivec-0.13.3 → openaivec-0.13.4}/pyproject.toml +0 -0
  44. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/__init__.py +0 -0
  45. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/di.py +0 -0
  46. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/embeddings.py +0 -0
  47. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/log.py +0 -0
  48. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/model.py +0 -0
  49. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/pandas_ext.py +0 -0
  50. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/prompt.py +0 -0
  51. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/provider.py +0 -0
  52. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/responses.py +0 -0
  53. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/serialize.py +0 -0
  54. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/spark.py +0 -0
  55. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/__init__.py +0 -0
  56. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/__init__.py +0 -0
  57. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/customer_sentiment.py +0 -0
  58. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/inquiry_classification.py +0 -0
  59. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/inquiry_summary.py +0 -0
  60. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/intent_analysis.py +0 -0
  61. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/response_suggestion.py +0 -0
  62. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/customer_support/urgency_analysis.py +0 -0
  63. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/__init__.py +0 -0
  64. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/dependency_parsing.py +0 -0
  65. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/keyword_extraction.py +0 -0
  66. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/morphological_analysis.py +0 -0
  67. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/named_entity_recognition.py +0 -0
  68. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/sentiment_analysis.py +0 -0
  69. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/nlp/translation.py +0 -0
  70. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/table/__init__.py +0 -0
  71. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/task/table/fillna.py +0 -0
  72. {openaivec-0.13.3 → openaivec-0.13.4}/src/openaivec/util.py +0 -0
  73. {openaivec-0.13.3 → openaivec-0.13.4}/tests/__init__.py +0 -0
  74. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_di.py +0 -0
  75. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_embeddings.py +0 -0
  76. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_pandas_ext.py +0 -0
  77. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_prompt.py +0 -0
  78. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_provider.py +0 -0
  79. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_responses.py +0 -0
  80. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_serialize.py +0 -0
  81. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_spark.py +0 -0
  82. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_task.py +0 -0
  83. {openaivec-0.13.3 → openaivec-0.13.4}/tests/test_util.py +0 -0
  84. {openaivec-0.13.3 → openaivec-0.13.4}/uv.lock +0 -0
@@ -0,0 +1,199 @@
+ # Copilot instructions for openaivec
+
+ This repository-wide guide tells GitHub Copilot how to propose code that fits our architecture, APIs, style, tests, and docs. Prefer these rules when completing or generating code.
+
+ ## Project overview
+
+ - Goal: Provide a vectorized (batched) interface to OpenAI/Azure OpenAI so pandas/Spark can process large text corpora with high throughput.
+ - Public API exports (`src/openaivec/__init__.py`):
+   - Responses: `BatchResponses`, `AsyncBatchResponses`
+   - Embeddings: `BatchEmbeddings`, `AsyncBatchEmbeddings`
+ - First-class pandas extensions (`.ai` / `.aio`) and Spark UDF builders
+ - Azure OpenAI is supported with the same APIs (use the deployment name as the “model” for Azure)
+
+ ## Architecture and roles
+
+ - `src/openaivec/proxy.py`
+   - Core batching, deduplication, order preservation, and caching
+   - `BatchingMapProxy[S, T]` (sync) / `AsyncBatchingMapProxy[S, T]` (async)
+   - The map_func contract is strict: return a list of the same length and order as the inputs (sketched after this list)
+   - Progress bars only in notebook environments via `tqdm.auto`, gated by `show_progress=True`
+ - `src/openaivec/responses.py`
+   - Batched wrapper over the OpenAI Responses JSON-mode API
+   - `BatchResponses` / `AsyncBatchResponses` use the proxy internally
+   - Retries via `backoff`/`backoff_async` for transient errors (RateLimit, 5xx)
+   - Reasoning models (o1/o3 family) must use `temperature=None`; helpful guidance on errors
+ - `src/openaivec/embeddings.py`
+   - Batched embeddings (sync/async)
+ - `src/openaivec/pandas_ext.py`
+   - `Series.ai` / `Series.aio` entry points for responses/embeddings
+   - Uses the DI container (`provider.CONTAINER`) to get the client and model names
+   - Supports batch size, progress, and cache sharing (`*_with_cache`)
+ - `src/openaivec/spark.py`
+   - UDF builders: `responses_udf` / `task_udf` / `embeddings_udf` / `count_tokens_udf` / `split_to_chunks_udf`
+   - Per-partition duplicate caching to reduce API calls
+   - Pydantic → Spark StructType schema conversion
+ - `src/openaivec/provider.py`
+   - DI container and automatic OpenAI/Azure OpenAI client provisioning
+   - Warns if the Azure base URL isn’t in v1 format
+ - `src/openaivec/util.py`
+   - `backoff` / `backoff_async` and `TextChunker`
+ - Additional modules (see CLAUDE.md)
+   - `src/openaivec/di.py`: lightweight DI container
+   - `src/openaivec/log.py`: logging/observe helpers
+   - `src/openaivec/prompt.py`: few-shot prompt building
+   - `src/openaivec/serialize.py`: Pydantic schema (de)serialization
+   - `src/openaivec/task/`: pre-built, structured task library
+
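+ A minimal sketch of that contract and the dedup/order guarantees (illustrative only; `double` stands in for a real batched API call):
+
+ ```python
+ from openaivec.proxy import BatchingMapProxy
+
+ def double(xs: list[int]) -> list[int]:
+     # Contract: same length and order as xs
+     return [x * 2 for x in xs]
+
+ proxy = BatchingMapProxy[int, int](batch_size=2)
+ # Duplicates are computed once; results come back in input order
+ assert proxy.map([1, 2, 2, 3], double) == [2, 4, 4, 6]
+ ```
+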
+ ## Dev commands (uv)
+
+ ```bash
+ # Install all dependencies (dev + extras)
+ uv sync --all-extras --dev
+
+ # Install in editable mode
+ uv pip install -e .
+
+ # Lint and format
+ uv run ruff check . --fix && uv run ruff format .
+
+ # Run tests
+ uv run pytest
+
+ # Build/serve docs
+ uv run mkdocs serve
+ ```
+
+ ## Coding standards (Ruff/types/style)
+
+ - Python ≥ 3.10
+ - Lint/format via Ruff (line-length=120, target=py310)
+ - Imports: absolute only (enforced by TID252), except `__init__.py` may re-export relatively
+ - Type hints are required, using modern syntax (`str | None` over `Optional[str]`)
+ - Public APIs should document return values and exceptions (Google-style docstrings preferred)
+ - Favor `@dataclass` for simple data contracts; separate mutable state cleanly (example below)
+ - Don’t swallow errors broadly; raise `ValueError` etc. on contract violations
+
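+ For instance, a data contract in this style (illustrative, not an existing class):
+
+ ```python
+ from dataclasses import dataclass
+
+ @dataclass(frozen=True)
+ class ChunkStats:
+     text: str
+     token_count: int | None = None  # modern union syntax, not Optional[int]
+ ```
+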
+ ## API contracts and critical rules
+
+ - Proxy (`BatchingMapProxy` / `AsyncBatchingMapProxy`)
+   - map_func must return a list with the same length and order as its inputs; on mismatch, release events and raise `ValueError`
+   - Inputs are de-duplicated while preserving first-occurrence order; outputs are restored to the original order
+   - Progress is only shown in notebooks when `show_progress=True`
+   - The async version enforces `max_concurrency` via `asyncio.Semaphore`
+ - Responses
+   - Use OpenAI Responses JSON mode (`responses.parse`)
+   - For reasoning models (o1/o3 families), you MUST set `temperature=None`; helpful error messaging is built in
+   - Strongly prefer structured outputs with Pydantic models
+   - Retries with exponential backoff for RateLimit/5xx
+ - Embeddings
+   - Return NumPy float32 arrays
+ - pandas extensions
+   - `.ai.responses` / `.ai.embeddings` strictly preserve Series index and length
+   - `.aio` provides async variants; tune with `max_concurrency` and `batch_size` (see the sketch below)
+   - `*_with_cache` variants let callers share external caches across ops
+ - Spark UDFs
+   - Cache duplicates within each partition to minimize API cost
+   - Convert Pydantic models to Spark schemas; treat Enum/Literal as strings
+   - Reasoning models require `temperature=None`
+   - Provide token counting and text chunking helpers
+ - Provider/DI and Azure
+   - Auto-detect OpenAI vs Azure OpenAI from env vars
+   - Azure requires a v1 base URL (warn otherwise) and uses the deployment name as the “model”
+
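+ A hedged sketch of the async pandas entry point (the exact keyword names for batch size and concurrency are assumptions; confirm against `pandas_ext.py`):
+
+ ```python
+ import asyncio
+
+ import pandas as pd
+ from openaivec import pandas_ext  # noqa: F401  # registers the .ai / .aio accessors
+
+ async def main() -> None:
+     s = pd.Series(["good", "bad", "good"])
+     # Async variant of .ai.responses; index and length are preserved
+     out = await s.aio.responses("Classify sentiment", batch_size=64, max_concurrency=8)
+     assert out.index.equals(s.index)
+
+ asyncio.run(main())
+ ```
+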
+ ## Preferred patterns (Do) and Avoid (Don’t)
+
+ Do
+
+ - Batch through the Proxy rather than per-item loops
+ - Attach `backoff`/`backoff_async` to external API calls (RateLimit, 5xx)
+ - Preserve index/order/schema for pandas/Spark APIs
+ - Clarify Azure specifics (“deployment name” vs “model name”); respect `_check_azure_v1_api_url`
+ - When changing public APIs, update `__all__` and the docs in `docs/`
+
+ Don’t
+
+ - Break the Proxy contract (same-length, ordered result)
+ - Fire one API request per item; always batch via the Proxy
+ - Show progress outside notebook contexts or ignore `show_progress`
+ - Use relative imports (except `__init__.py` re-exports)
+ - Hit real external APIs in unit tests (prefer mocks/stubs)
+
+ ## Performance guidance
+
+ - Typical batch size ranges: Responses 32–128, Embeddings 64–256 (defaults are 128 in code)
+ - Async `max_concurrency` is commonly 4–12 per process/partition; scale with rate limits in mind (see the sketch below)
+ - Partition-level caching (Spark) and cross-op cache sharing (pandas `*_with_cache`) greatly reduce costs
+
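+ For instance, a proxy tuned along these lines (values are illustrative starting points, not universal recommendations):
+
+ ```python
+ from openaivec.proxy import AsyncBatchingMapProxy
+
+ # Fixed batch size; leaving batch_size=None instead lets the built-in
+ # BatchSizeSuggester adapt it from observed call durations (0.13.4+).
+ proxy = AsyncBatchingMapProxy[str, str](batch_size=128, max_concurrency=8)
+ ```
+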
+ ## Testing strategy (pytest)
+
+ - Tests live in `tests/`; cover both sync and async where applicable
+ - Prefer mocks/stubs for external API calls; keep data small and deterministic (see the sketch below)
+ - Focus areas:
+   - Order/length preservation
+   - Deduplication and cache reuse
+   - Event release on exceptions (deadlock prevention)
+   - `max_concurrency` is not exceeded
+   - Reasoning model guidance (`temperature=None`)
+ - Use `asyncio.run` in async tests (mirrors existing tests)
+ - Optional integration tests can run with valid API keys; keep unit tests independent of network
+
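+ A minimal stub-based test in this spirit (hypothetical; not an existing test in `tests/`):
+
+ ```python
+ from openaivec.proxy import BatchingMapProxy
+
+ def test_map_dedupes_and_preserves_order() -> None:
+     calls: list[list[str]] = []
+
+     def fake_api(xs: list[str]) -> list[str]:
+         calls.append(xs)  # record what the proxy actually sends
+         return [x.upper() for x in xs]
+
+     proxy = BatchingMapProxy[str, str](batch_size=10)
+     assert proxy.map(["a", "b", "a"], fake_api) == ["A", "B", "A"]
+     assert calls == [["a", "b"]]  # duplicates collapsed before the call
+ ```
+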
+ ## Documentation (MkDocs)
+
+ - For new developer-facing APIs, update `docs/api/` and consider a short example under `docs/examples/`
+ - Keep pandas/Spark examples concise to minimize the learning curve
+ - Update the `mkdocs.yml` navigation when adding modules or examples
+
+ ## PR checklist (pre-merge)
+
+ - [ ] Ruff check/format passes (line-length 120, absolute imports)
+ - [ ] Public API contracts (order/length/types) are satisfied
+ - [ ] Large-scale processing is batched via the Proxy
+ - [ ] Reasoning models use `temperature=None` where applicable
+ - [ ] Tests added/updated without calling live external APIs
+ - [ ] Docs updated if needed (`docs/` and/or examples)
+
+ ## Common snippets (what to suggest)
+
+ - New batched API wrapper (sync)
+
+ ```python
+ @observe(_LOGGER)
+ @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
+ def _unit_of_work(self, xs: list[str]) -> list[TOut]:
+     resp = self.client.api(xs)  # real API call
+     return convert(resp)  # same length/order as xs
+
+ def create(self, inputs: list[str]) -> list[TOut]:
+     return self.cache.map(inputs, self._unit_of_work)
+ ```
+
+ - Reasoning model temperature
+
+ ```python
+ # o1/o3 and similar reasoning models must use None
+ temperature=None
+ ```
+
+ - pandas `.ai` with shared cache
+
+ ```python
+ from openaivec.proxy import BatchingMapProxy
+
+ shared = BatchingMapProxy[str, str](batch_size=64)
+ df["text"].ai.responses_with_cache("instructions", cache=shared)
+ ```
+
+ - Spark UDF (structured output)
+
+ ```python
+ from pydantic import BaseModel
+ from openaivec.spark import responses_udf
+
+ class R(BaseModel):
+     value: str
+
+ udf = responses_udf("do something", response_format=R, batch_size=64, max_concurrency=8)
+ ```
+
+ ---
+
+ By following this guide, Copilot suggestions will match the repository’s design, performance goals, and testing standards. When in doubt, read the implementations in `proxy.py`, `responses.py`, `pandas_ext.py`, and `spark.py`, and the tests under `tests/`.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: openaivec
- Version: 0.13.3
+ Version: 0.13.4
  Summary: Generative mutation for tabular calculation
  Project-URL: Homepage, https://microsoft.github.io/openaivec/
  Project-URL: Repository, https://github.com/microsoft/openaivec
@@ -61,6 +61,7 @@ nav:
      - Async Workflows: examples/aio.ipynb
      - Prompt Engineering: examples/prompt.ipynb
      - FAQ Generation: examples/generate_faq.ipynb
+     - Token Count and Processing Time: examples/batch_size.ipynb
  - API Reference:
      - di: api/di.md
      - pandas_ext: api/pandas_ext.md
@@ -0,0 +1,108 @@
+ import threading
+ import time
+ from contextlib import contextmanager
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import List
+
+
+ @dataclass(frozen=True)
+ class PerformanceMetric:
+     duration: float
+     batch_size: int
+     executed_at: datetime
+     exception: BaseException | None = None
+
+
+ @dataclass
+ class BatchSizeSuggester:
+     current_batch_size: int = 10
+     min_batch_size: int = 10
+     min_duration: float = 30.0
+     max_duration: float = 60.0
+     step_ratio: float = 0.1
+     sample_size: int = 10
+     _history: List[PerformanceMetric] = field(default_factory=list)
+     _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
+     _batch_size_changed_at: datetime | None = field(default=None, init=False)
+
+     def __post_init__(self) -> None:
+         if self.min_batch_size <= 0:
+             raise ValueError("min_batch_size must be > 0")
+         if self.current_batch_size < self.min_batch_size:
+             raise ValueError("current_batch_size must be >= min_batch_size")
+         if self.sample_size <= 0:
+             raise ValueError("sample_size must be > 0")
+         if self.step_ratio <= 0:
+             raise ValueError("step_ratio must be > 0")
+         if self.min_duration <= 0 or self.max_duration <= 0:
+             raise ValueError("min_duration and max_duration must be > 0")
+         if self.min_duration >= self.max_duration:
+             raise ValueError("min_duration must be < max_duration")
+
+     @contextmanager
+     def record(self, batch_size: int):
+         start_time = time.perf_counter()
+         executed_at = datetime.now(timezone.utc)
+         caught_exception: BaseException | None = None
+         try:
+             yield
+         except BaseException as e:
+             caught_exception = e
+             raise
+         finally:
+             duration = time.perf_counter() - start_time
+             with self._lock:
+                 self._history.append(
+                     PerformanceMetric(
+                         duration=duration,
+                         batch_size=batch_size,
+                         executed_at=executed_at,
+                         exception=caught_exception,
+                     )
+                 )
+
+     @property
+     def samples(self) -> List[PerformanceMetric]:
+         with self._lock:
+             selected: List[PerformanceMetric] = []
+             for metric in reversed(self._history):
+                 if metric.exception is not None:
+                     continue
+                 if self._batch_size_changed_at and metric.executed_at < self._batch_size_changed_at:
+                     continue
+                 selected.append(metric)
+                 if len(selected) >= self.sample_size:
+                     break
+             return list(reversed(selected))
+
+     def clear_history(self):
+         with self._lock:
+             self._history.clear()
+
+     def suggest_batch_size(self) -> int:
+         selected = self.samples
+
+         if len(selected) < self.sample_size:
+             with self._lock:
+                 return self.current_batch_size
+
+         average_duration = sum(m.duration for m in selected) / len(selected)
+
+         with self._lock:
+             current_size = self.current_batch_size
+
+             if average_duration < self.min_duration:
+                 new_batch_size = int(current_size * (1 + self.step_ratio))
+             elif average_duration > self.max_duration:
+                 new_batch_size = int(current_size * (1 - self.step_ratio))
+             else:
+                 new_batch_size = current_size
+
+             new_batch_size = max(new_batch_size, self.min_batch_size)
+
+             if new_batch_size != self.current_batch_size:
+                 self._batch_size_changed_at = datetime.now(timezone.utc)
+                 self.current_batch_size = new_batch_size
+
+             return self.current_batch_size
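The suggester can be exercised standalone; a minimal sketch based on the code above (durations simulated with `time.sleep`, thresholds shrunk so the demo runs in seconds):

```python
import time

from openaivec.optimize import BatchSizeSuggester

suggester = BatchSizeSuggester(current_batch_size=10, min_duration=0.1, max_duration=0.5)

# record() is a context manager: it times the wrapped block and appends a
# PerformanceMetric; failed calls are recorded but excluded from `samples`.
for _ in range(suggester.sample_size):
    with suggester.record(batch_size=suggester.current_batch_size):
        time.sleep(0.2)  # stand-in for one batched API call

# The ~0.2 s average falls inside [min_duration, max_duration], so the size
# is kept; faster runs would grow it by step_ratio, slower runs shrink it.
print(suggester.suggest_batch_size())  # -> 10
```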
@@ -4,6 +4,8 @@ from collections.abc import Hashable
  from dataclasses import dataclass, field
  from typing import Awaitable, Callable, Dict, Generic, List, Optional, TypeVar
 
+ from openaivec.optimize import BatchSizeSuggester
+
  S = TypeVar("S", bound=Hashable)
  T = TypeVar("T")
 
@@ -22,6 +24,7 @@ class ProxyBase(Generic[S, T]):
 
      batch_size: Optional[int] = None  # subclasses may override via dataclass
      show_progress: bool = False  # Enable progress bar display
+     suggester: BatchSizeSuggester = None  # Batch size optimization, initialized by subclasses
 
      def _is_notebook_environment(self) -> bool:
          """Check if running in a Jupyter notebook environment.
@@ -125,7 +128,7 @@ class ProxyBase(Generic[S, T]):
              progress_bar.close()
 
      @staticmethod
-     def __unique_in_order(seq: List[S]) -> List[S]:
+     def _unique_in_order(seq: List[S]) -> List[S]:
          """Return unique items preserving their first-occurrence order.
 
          Args:
@@ -143,11 +146,11 @@
                  out.append(x)
          return out
 
-     def __normalized_batch_size(self, total: int) -> int:
+     def _normalized_batch_size(self, total: int) -> int:
          """Compute the effective batch size used for processing.
 
-         If ``batch_size`` is not set or non-positive, the entire ``total`` is
-         processed in a single call.
+         If ``batch_size`` is None, use the suggester to determine optimal batch size.
+         If ``batch_size`` is non-positive, process the entire ``total`` in a single call.
 
          Args:
              total (int): Number of items intended to be processed.
@@ -155,7 +158,15 @@
 
          Returns:
              int: The positive batch size to use.
-         return self.batch_size if (self.batch_size and self.batch_size > 0) else total
+         if self.batch_size and self.batch_size > 0:
+             return self.batch_size
+         elif self.batch_size is None:
+             # Use suggester to determine optimal batch size
+             suggested = self.suggester.suggest_batch_size()
+             return min(suggested, total)  # Don't exceed total items
+         else:
+             # batch_size is 0 or negative, process all at once
+             return total
 
 
  @dataclass
@@ -180,19 +191,13 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
      # Number of items to process per call to map_func. If None or <= 0, process all at once.
      batch_size: Optional[int] = None
      show_progress: bool = False
+     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
+
+     # internals
      __cache: Dict[S, T] = field(default_factory=dict)
-     # Thread-safety primitives (not part of public API)
      __lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
      __inflight: Dict[S, threading.Event] = field(default_factory=dict, repr=False)
 
-     # ---- private helpers -------------------------------------------------
-     # expose base helpers under subclass private names for compatibility
-     __unique_in_order = staticmethod(ProxyBase._ProxyBase__unique_in_order)
-     __normalized_batch_size = ProxyBase._ProxyBase__normalized_batch_size
-     _create_progress_bar = ProxyBase._create_progress_bar
-     _update_progress_bar = ProxyBase._update_progress_bar
-     _close_progress_bar = ProxyBase._close_progress_bar
-
      def __all_cached(self, items: List[S]) -> bool:
          """Check whether all items are present in the cache.
 
@@ -320,16 +325,17 @@
          """
          if not owned:
              return
-         batch_size = self.__normalized_batch_size(len(owned))
+         # Setup progress bar
+         progress_bar = self._create_progress_bar(len(owned))
 
          # Accumulate uncached items to maximize batch size utilization
         pending_to_call: List[S] = []
 
-         # Setup progress bar
-         progress_bar = self._create_progress_bar(len(owned))
-
-         for i in range(0, len(owned), batch_size):
-             batch = owned[i : i + batch_size]
+         i = 0
+         while i < len(owned):
+             # Get dynamic batch size for each iteration
+             current_batch_size = self._normalized_batch_size(len(owned))
+             batch = owned[i : i + current_batch_size]
              # Double-check cache right before processing
              with self.__lock:
                  uncached_in_batch = [x for x in batch if x not in self.__cache]
@@ -337,14 +343,16 @@
              pending_to_call.extend(uncached_in_batch)
 
              # Process accumulated items when we reach batch_size or at the end
-             is_last_batch = i + batch_size >= len(owned)
-             if len(pending_to_call) >= batch_size or (is_last_batch and pending_to_call):
+             is_last_batch = i + current_batch_size >= len(owned)
+             if len(pending_to_call) >= current_batch_size or (is_last_batch and pending_to_call):
                  # Take up to batch_size items to process
-                 to_call = pending_to_call[:batch_size]
-                 pending_to_call = pending_to_call[batch_size:]
+                 to_call = pending_to_call[:current_batch_size]
+                 pending_to_call = pending_to_call[current_batch_size:]
 
                  try:
-                     results = map_func(to_call)
+                     # Always measure execution time using suggester
+                     with self.suggester.record(len(to_call)):
+                         results = map_func(to_call)
                  except Exception:
                      self.__finalize_failure(to_call)
                      raise
@@ -353,13 +361,19 @@
                  # Update progress bar
                  self._update_progress_bar(progress_bar, len(to_call))
 
+             # Move to next batch
+             i += current_batch_size
+
          # Process any remaining items
          while pending_to_call:
-             to_call = pending_to_call[:batch_size]
-             pending_to_call = pending_to_call[batch_size:]
+             # Get dynamic batch size for remaining items
+             remaining_batch_size = self._normalized_batch_size(len(pending_to_call))
+             to_call = pending_to_call[:remaining_batch_size]
+             pending_to_call = pending_to_call[remaining_batch_size:]
 
              try:
-                 results = map_func(to_call)
+                 with self.suggester.record(len(to_call)):
+                     results = map_func(to_call)
              except Exception:
                  self.__finalize_failure(to_call)
                  raise
@@ -430,7 +444,7 @@
          if self.__all_cached(items):
              return self.__values(items)
 
-         unique_items = self.__unique_in_order(items)
+         unique_items = self._unique_in_order(items)
          owned, wait_for = self.__acquire_ownership(unique_items)
 
          self.__process_owned(owned, map_func)
@@ -465,6 +479,7 @@
      batch_size: Optional[int] = None
      max_concurrency: int = 8
      show_progress: bool = False
+     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
 
      # internals
      __cache: Dict[S, T] = field(default_factory=dict, repr=False)
@@ -490,14 +505,6 @@
          else:
              self.__sema = None
 
-     # ---- private helpers -------------------------------------------------
-     # expose base helpers under subclass private names for compatibility
-     __unique_in_order = staticmethod(ProxyBase._ProxyBase__unique_in_order)
-     __normalized_batch_size = ProxyBase._ProxyBase__normalized_batch_size
-     _create_progress_bar = ProxyBase._create_progress_bar
-     _update_progress_bar = ProxyBase._update_progress_bar
-     _close_progress_bar = ProxyBase._close_progress_bar
-
      async def __all_cached(self, items: List[S]) -> bool:
          """Check whether all items are present in the cache.
 
@@ -602,69 +609,43 @@
          await self.clear()
 
      async def __process_owned(self, owned: List[S], map_func: Callable[[List[S]], Awaitable[List[T]]]) -> None:
-         """Process owned keys in mini-batches, re-checking cache before awaits.
-
-         Before calling ``map_func`` for each batch, the cache is re-checked to
-         skip any keys that may have been filled in the meantime. Items
-         are accumulated across multiple original batches to maximize batch
-         size utilization when some items are cached. On exceptions raised
-         by ``map_func``, all corresponding in-flight events are released
-         to prevent deadlocks, and the exception is propagated.
+         """Process owned keys using Producer-Consumer pattern with dynamic batch sizing.
 
          Args:
-             owned (list[S]): Items for which this coroutine holds computation
-                 ownership.
+             owned (list[S]): Items for which this coroutine holds computation ownership.
 
          Raises:
              Exception: Propagates any exception raised by ``map_func``.
          """
          if not owned:
              return
-         batch_size = self.__normalized_batch_size(len(owned))
 
-         # Accumulate uncached items to maximize batch size utilization
-         pending_to_call: List[S] = []
-
-         # Setup progress bar
          progress_bar = self._create_progress_bar(len(owned))
+         batch_queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_concurrency)
+
+         async def producer():
+             index = 0
+             while index < len(owned):
+                 batch_size = self._normalized_batch_size(len(owned) - index)
+                 batch = owned[index : index + batch_size]
+                 await batch_queue.put(batch)
+                 index += batch_size
+             # Send completion signals
+             for _ in range(self.max_concurrency):
+                 await batch_queue.put(None)
+
+         async def consumer():
+             while True:
+                 batch = await batch_queue.get()
+                 try:
+                     if batch is None:
+                         break
+                     await self.__process_single_batch(batch, map_func, progress_bar)
+                 finally:
+                     batch_queue.task_done()
 
-         # Collect all batches to process
-         batches_to_process: List[List[S]] = []
-
-         for i in range(0, len(owned), batch_size):
-             batch = owned[i : i + batch_size]
-             async with self.__lock:
-                 uncached_in_batch = [x for x in batch if x not in self.__cache]
-
-             pending_to_call.extend(uncached_in_batch)
-
-             # Process accumulated items when we reach batch_size or at the end
-             is_last_batch = i + batch_size >= len(owned)
-             if len(pending_to_call) >= batch_size or (is_last_batch and pending_to_call):
-                 # Take up to batch_size items to process
-                 to_call = pending_to_call[:batch_size]
-                 pending_to_call = pending_to_call[batch_size:]
-                 if to_call:  # Only add non-empty batches
-                     batches_to_process.append(to_call)
-
-         # Process any remaining items
-         while pending_to_call:
-             to_call = pending_to_call[:batch_size]
-             pending_to_call = pending_to_call[batch_size:]
-             if to_call:  # Only add non-empty batches
-                 batches_to_process.append(to_call)
-
-         # Process all batches concurrently
-         if batches_to_process:
-             tasks = []
-             for batch in batches_to_process:
-                 task = self.__process_single_batch(batch, map_func, progress_bar)
-                 tasks.append(task)
-
-             # Wait for all batches to complete
-             await asyncio.gather(*tasks)
+         await asyncio.gather(producer(), *[consumer() for _ in range(self.max_concurrency)])
 
-         # Close progress bar
          self._close_progress_bar(progress_bar)
 
      async def __process_single_batch(
@@ -676,7 +657,9 @@
              if self.__sema:
                  await self.__sema.acquire()
                  acquired = True
-             results = await map_func(to_call)
+             # Measure async map_func execution using suggester
+             with self.suggester.record(len(to_call)):
+                 results = await map_func(to_call)
          except Exception:
              await self.__finalize_failure(to_call)
              raise
@@ -737,7 +720,7 @@
          if await self.__all_cached(items):
              return await self.__values(items)
 
-         unique_items = self.__unique_in_order(items)
+         unique_items = self._unique_in_order(items)
          owned, wait_for = await self.__acquire_ownership(unique_items)
 
          await self.__process_owned(owned, map_func)
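Net effect of the proxy changes, sketched (a reading of the diff above, not documented API):

```python
from openaivec.proxy import BatchingMapProxy

# batch_size=None: _normalized_batch_size now consults the BatchSizeSuggester,
# which starts at 10 items and adapts toward its 30-60 s duration window.
adaptive = BatchingMapProxy[str, str](batch_size=None)

# batch_size=0 or negative: everything is processed in a single call, as before.
single_shot = BatchingMapProxy[str, str](batch_size=0)
```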