ai-lib-python 0.5.0 (ai_lib_python-0.5.0-py3-none-any.whl)
This diff shows the contents of publicly released package versions as they appear in the supported public registries, and is provided for informational purposes only.
- ai_lib_python/__init__.py +43 -0
- ai_lib_python/batch/__init__.py +15 -0
- ai_lib_python/batch/collector.py +244 -0
- ai_lib_python/batch/executor.py +224 -0
- ai_lib_python/cache/__init__.py +26 -0
- ai_lib_python/cache/backends.py +380 -0
- ai_lib_python/cache/key.py +237 -0
- ai_lib_python/cache/manager.py +332 -0
- ai_lib_python/client/__init__.py +37 -0
- ai_lib_python/client/builder.py +528 -0
- ai_lib_python/client/cancel.py +368 -0
- ai_lib_python/client/core.py +433 -0
- ai_lib_python/client/response.py +134 -0
- ai_lib_python/embeddings/__init__.py +36 -0
- ai_lib_python/embeddings/client.py +339 -0
- ai_lib_python/embeddings/types.py +234 -0
- ai_lib_python/embeddings/vectors.py +246 -0
- ai_lib_python/errors/__init__.py +41 -0
- ai_lib_python/errors/base.py +316 -0
- ai_lib_python/errors/classification.py +210 -0
- ai_lib_python/guardrails/__init__.py +35 -0
- ai_lib_python/guardrails/base.py +336 -0
- ai_lib_python/guardrails/filters.py +583 -0
- ai_lib_python/guardrails/validators.py +475 -0
- ai_lib_python/pipeline/__init__.py +55 -0
- ai_lib_python/pipeline/accumulate.py +248 -0
- ai_lib_python/pipeline/base.py +240 -0
- ai_lib_python/pipeline/decode.py +281 -0
- ai_lib_python/pipeline/event_map.py +506 -0
- ai_lib_python/pipeline/fan_out.py +284 -0
- ai_lib_python/pipeline/select.py +297 -0
- ai_lib_python/plugins/__init__.py +32 -0
- ai_lib_python/plugins/base.py +294 -0
- ai_lib_python/plugins/hooks.py +296 -0
- ai_lib_python/plugins/middleware.py +285 -0
- ai_lib_python/plugins/registry.py +294 -0
- ai_lib_python/protocol/__init__.py +71 -0
- ai_lib_python/protocol/loader.py +317 -0
- ai_lib_python/protocol/manifest.py +385 -0
- ai_lib_python/protocol/validator.py +460 -0
- ai_lib_python/py.typed +1 -0
- ai_lib_python/resilience/__init__.py +102 -0
- ai_lib_python/resilience/backpressure.py +225 -0
- ai_lib_python/resilience/circuit_breaker.py +318 -0
- ai_lib_python/resilience/executor.py +343 -0
- ai_lib_python/resilience/fallback.py +341 -0
- ai_lib_python/resilience/preflight.py +413 -0
- ai_lib_python/resilience/rate_limiter.py +291 -0
- ai_lib_python/resilience/retry.py +299 -0
- ai_lib_python/resilience/signals.py +283 -0
- ai_lib_python/routing/__init__.py +118 -0
- ai_lib_python/routing/manager.py +593 -0
- ai_lib_python/routing/strategy.py +345 -0
- ai_lib_python/routing/types.py +397 -0
- ai_lib_python/structured/__init__.py +33 -0
- ai_lib_python/structured/json_mode.py +281 -0
- ai_lib_python/structured/schema.py +316 -0
- ai_lib_python/structured/validator.py +334 -0
- ai_lib_python/telemetry/__init__.py +127 -0
- ai_lib_python/telemetry/exporters/__init__.py +9 -0
- ai_lib_python/telemetry/exporters/prometheus.py +111 -0
- ai_lib_python/telemetry/feedback.py +446 -0
- ai_lib_python/telemetry/health.py +409 -0
- ai_lib_python/telemetry/logger.py +389 -0
- ai_lib_python/telemetry/metrics.py +496 -0
- ai_lib_python/telemetry/tracer.py +473 -0
- ai_lib_python/tokens/__init__.py +25 -0
- ai_lib_python/tokens/counter.py +282 -0
- ai_lib_python/tokens/estimator.py +286 -0
- ai_lib_python/transport/__init__.py +34 -0
- ai_lib_python/transport/auth.py +141 -0
- ai_lib_python/transport/http.py +364 -0
- ai_lib_python/transport/pool.py +425 -0
- ai_lib_python/types/__init__.py +41 -0
- ai_lib_python/types/events.py +343 -0
- ai_lib_python/types/message.py +332 -0
- ai_lib_python/types/tool.py +191 -0
- ai_lib_python/utils/__init__.py +21 -0
- ai_lib_python/utils/tool_call_assembler.py +317 -0
- ai_lib_python-0.5.0.dist-info/METADATA +837 -0
- ai_lib_python-0.5.0.dist-info/RECORD +84 -0
- ai_lib_python-0.5.0.dist-info/WHEEL +4 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-APACHE +201 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-MIT +21 -0
ai_lib_python/embeddings/client.py
@@ -0,0 +1,339 @@
"""
Embedding client for generating embeddings.

Provides a unified interface for embedding generation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from ai_lib_python.embeddings.types import (
    Embedding,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingUsage,
)
from ai_lib_python.protocol import ProtocolLoader
from ai_lib_python.transport import HttpTransport

if TYPE_CHECKING:
    from ai_lib_python.protocol.manifest import ProtocolManifest


class EmbeddingClient:
    """Client for generating embeddings.

    Supports batch embedding generation with automatic chunking
    and rate limit handling.

    Example:
        >>> client = await EmbeddingClient.create("openai/text-embedding-3-small")
        >>> response = await client.embed("Hello, world!")
        >>> print(response.first.vector[:5])

        >>> # Batch embedding
        >>> texts = ["Hello", "World", "Test"]
        >>> response = await client.embed_batch(texts)
        >>> for emb in response.embeddings:
        ...     print(f"Text {emb.index}: {len(emb.vector)} dimensions")
    """

    def __init__(
        self,
        manifest: ProtocolManifest,
        transport: HttpTransport,
        model_id: str,
    ) -> None:
        """Initialize embedding client.

        Args:
            manifest: Protocol manifest
            transport: HTTP transport
            model_id: Model identifier
        """
        self._manifest = manifest
        self._transport = transport
        self._model_id = model_id

    @classmethod
    async def create(
        cls,
        model: str,
        api_key: str | None = None,
        base_url: str | None = None,
        dimensions: int | None = None,
    ) -> EmbeddingClient:
        """Create an embedding client.

        Args:
            model: Model identifier (e.g., "openai/text-embedding-3-small")
            api_key: API key (optional, uses environment)
            base_url: Base URL override
            dimensions: Output dimensions override

        Returns:
            EmbeddingClient instance
        """
        return await (
            EmbeddingClientBuilder()
            .model(model)
            .api_key(api_key)
            .base_url(base_url)
            .dimensions(dimensions)
            .build()
        )

    @classmethod
    def builder(cls) -> EmbeddingClientBuilder:
        """Get a builder for creating embedding clients.

        Returns:
            EmbeddingClientBuilder instance
        """
        return EmbeddingClientBuilder()

    async def embed(
        self,
        text: str,
        dimensions: int | None = None,
    ) -> EmbeddingResponse:
        """Generate embedding for a single text.

        Args:
            text: Text to embed
            dimensions: Output dimensions (if supported)

        Returns:
            EmbeddingResponse with single embedding
        """
        request = EmbeddingRequest(
            input=text,
            model=self._model_id,
            dimensions=dimensions,
        )
        return await self._execute(request)

    async def embed_batch(
        self,
        texts: list[str],
        dimensions: int | None = None,
        batch_size: int = 100,
    ) -> EmbeddingResponse:
        """Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed
            dimensions: Output dimensions (if supported)
            batch_size: Maximum texts per API call

        Returns:
            EmbeddingResponse with all embeddings
        """
        if len(texts) <= batch_size:
            request = EmbeddingRequest(
                input=texts,
                model=self._model_id,
                dimensions=dimensions,
            )
            return await self._execute(request)

        # Process in batches
        all_embeddings: list[Embedding] = []
        total_usage = EmbeddingUsage()

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            request = EmbeddingRequest(
                input=batch,
                model=self._model_id,
                dimensions=dimensions,
            )
            response = await self._execute(request)

            # Adjust indices for combined result
            for emb in response.embeddings:
                all_embeddings.append(
                    Embedding(
                        index=i + emb.index,
                        vector=emb.vector,
                        object_type=emb.object_type,
                    )
                )

            total_usage.prompt_tokens += response.usage.prompt_tokens
            total_usage.total_tokens += response.usage.total_tokens

        return EmbeddingResponse(
            embeddings=all_embeddings,
            model=self._model_id,
            usage=total_usage,
        )

    async def _execute(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Execute an embedding request.

        Args:
            request: Embedding request

        Returns:
            EmbeddingResponse
        """
        endpoint = self._get_embedding_endpoint()
        payload = request.to_dict()

        response = await self._transport.post(endpoint, json=payload)
        data = response.json()

        return EmbeddingResponse.from_openai_format(data)

    def _get_embedding_endpoint(self) -> str:
        """Get the embedding API endpoint.

        Returns:
            Endpoint path
        """
        # Try to get from manifest, default to OpenAI-style
        if hasattr(self._manifest, "embedding_endpoint"):
            return self._manifest.embedding_endpoint
        return "/v1/embeddings"

    @property
    def model(self) -> str:
        """Get the model identifier."""
        return self._model_id

    @property
    def provider(self) -> str:
        """Get the provider identifier."""
        return self._manifest.id

    async def close(self) -> None:
        """Close the client."""
        await self._transport.close()

    async def __aenter__(self) -> EmbeddingClient:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self, exc_type: Any, exc_val: Any, exc_tb: Any
    ) -> None:
        """Async context manager exit."""
        await self.close()


class EmbeddingClientBuilder:
    """Builder for creating EmbeddingClient instances.

    Example:
        >>> client = await (
        ...     EmbeddingClient.builder()
        ...     .model("openai/text-embedding-3-small")
        ...     .dimensions(512)
        ...     .build()
        ... )
    """

    def __init__(self) -> None:
        """Initialize builder."""
        self._model: str | None = None
        self._api_key: str | None = None
        self._base_url: str | None = None
        self._dimensions: int | None = None
        self._timeout: float | None = None

    def model(self, model: str) -> EmbeddingClientBuilder:
        """Set the model.

        Args:
            model: Model identifier (e.g., "openai/text-embedding-3-small")

        Returns:
            Self for chaining
        """
        self._model = model
        return self

    def api_key(self, api_key: str | None) -> EmbeddingClientBuilder:
        """Set the API key.

        Args:
            api_key: API key

        Returns:
            Self for chaining
        """
        self._api_key = api_key
        return self

    def base_url(self, base_url: str | None) -> EmbeddingClientBuilder:
        """Set the base URL.

        Args:
            base_url: Base URL override

        Returns:
            Self for chaining
        """
        self._base_url = base_url
        return self

    def dimensions(self, dimensions: int | None) -> EmbeddingClientBuilder:
        """Set the output dimensions.

        Args:
            dimensions: Output dimensions

        Returns:
            Self for chaining
        """
        self._dimensions = dimensions
        return self

    def timeout(self, timeout: float) -> EmbeddingClientBuilder:
        """Set the request timeout.

        Args:
            timeout: Timeout in seconds

        Returns:
            Self for chaining
        """
        self._timeout = timeout
        return self

    async def build(self) -> EmbeddingClient:
        """Build the embedding client.

        Returns:
            EmbeddingClient instance

        Raises:
            ValueError: If model is not set
        """
        if not self._model:
            raise ValueError("Model must be set before building")

        # Parse model identifier
        parts = self._model.split("/")
        provider_id = parts[0] if len(parts) >= 2 else "openai"
        model_id = parts[-1]

        # Load protocol manifest
        loader = ProtocolLoader()
        manifest = await loader.load_provider(provider_id)

        # Create transport
        transport = HttpTransport.from_manifest(
            manifest,
            api_key=self._api_key,
            base_url_override=self._base_url,
            timeout=self._timeout,
        )

        return EmbeddingClient(
            manifest=manifest,
            transport=transport,
            model_id=model_id,
        )
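For reference, a minimal usage sketch of the client above. It assumes an API key for the "openai" provider is available in the environment, that the bundled protocol manifests can resolve that provider, and that network access is available; none of that is shown in this diff.

import asyncio

from ai_lib_python.embeddings.client import EmbeddingClient


async def main() -> None:
    # Build the client via the fluent builder, as in the class docstring.
    client = await (
        EmbeddingClient.builder()
        .model("openai/text-embedding-3-small")
        .dimensions(512)
        .build()
    )
    # The client is an async context manager; close() runs on exit.
    async with client:
        # Single text
        single = await client.embed("Hello, world!")
        print(single.first.dimensions if single.first else 0)

        # Three texts with batch_size=2 forces embed_batch() to issue two
        # requests, re-index the combined results, and sum the usage counters.
        batch = await client.embed_batch(["Hello", "World", "Test"], batch_size=2)
        print(len(batch.embeddings), batch.usage.total_tokens)


asyncio.run(main())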
ai_lib_python/embeddings/types.py
@@ -0,0 +1,234 @@
"""
Embedding types and data models.

Defines the core types for embedding operations.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class EmbeddingModel(str, Enum):
    """Standard embedding models."""

    # OpenAI
    TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
    TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
    TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"

    # Cohere
    EMBED_ENGLISH_V3 = "embed-english-v3.0"
    EMBED_MULTILINGUAL_V3 = "embed-multilingual-v3.0"
    EMBED_ENGLISH_LIGHT_V3 = "embed-english-light-v3.0"

    # Voyage
    VOYAGE_LARGE_2 = "voyage-large-2"
    VOYAGE_CODE_2 = "voyage-code-2"

    # Google
    TEXT_EMBEDDING_004 = "text-embedding-004"

    @property
    def dimensions(self) -> int:
        """Get default dimensions for the model."""
        dimensions_map = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536,
            "embed-english-v3.0": 1024,
            "embed-multilingual-v3.0": 1024,
            "embed-english-light-v3.0": 384,
            "voyage-large-2": 1536,
            "voyage-code-2": 1536,
            "text-embedding-004": 768,
        }
        return dimensions_map.get(self.value, 1536)

    @property
    def max_tokens(self) -> int:
        """Get maximum input tokens for the model."""
        max_tokens_map = {
            "text-embedding-3-small": 8191,
            "text-embedding-3-large": 8191,
            "text-embedding-ada-002": 8191,
            "embed-english-v3.0": 512,
            "embed-multilingual-v3.0": 512,
            "embed-english-light-v3.0": 512,
            "voyage-large-2": 16000,
            "voyage-code-2": 16000,
            "text-embedding-004": 2048,
        }
        return max_tokens_map.get(self.value, 8191)


@dataclass
class EmbeddingUsage:
    """Token usage information for embedding request.

    Attributes:
        prompt_tokens: Number of tokens in the input
        total_tokens: Total tokens used
    """

    prompt_tokens: int = 0
    total_tokens: int = 0

    def to_dict(self) -> dict[str, int]:
        """Convert to dictionary."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "total_tokens": self.total_tokens,
        }


@dataclass
class Embedding:
    """A single embedding result.

    Attributes:
        index: Index in the batch
        vector: The embedding vector
        object_type: Object type (always "embedding")
    """

    index: int
    vector: list[float]
    object_type: str = "embedding"

    @property
    def dimensions(self) -> int:
        """Get the dimensionality of the embedding."""
        return len(self.vector)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "index": self.index,
            "embedding": self.vector,
            "object": self.object_type,
        }

    @classmethod
    def from_openai_format(cls, data: dict[str, Any]) -> Embedding:
        """Create from OpenAI API response format.

        Args:
            data: OpenAI embedding object

        Returns:
            Embedding instance
        """
        return cls(
            index=data.get("index", 0),
            vector=data.get("embedding", []),
            object_type=data.get("object", "embedding"),
        )


@dataclass
class EmbeddingRequest:
    """Request for generating embeddings.

    Attributes:
        input: Text or list of texts to embed
        model: Model to use
        dimensions: Output dimensions (if supported)
        encoding_format: Output format ("float" or "base64")
        user: User identifier for tracking
    """

    input: str | list[str]
    model: str = "text-embedding-3-small"
    dimensions: int | None = None
    encoding_format: str = "float"
    user: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to API request format."""
        data: dict[str, Any] = {
            "input": self.input,
            "model": self.model,
        }
        if self.dimensions is not None:
            data["dimensions"] = self.dimensions
        if self.encoding_format != "float":
            data["encoding_format"] = self.encoding_format
        if self.user:
            data["user"] = self.user
        return data

    @property
    def is_batch(self) -> bool:
        """Check if this is a batch request."""
        return isinstance(self.input, list)

    @property
    def batch_size(self) -> int:
        """Get the number of inputs."""
        if isinstance(self.input, list):
            return len(self.input)
        return 1


@dataclass
class EmbeddingResponse:
    """Response from embedding generation.

    Attributes:
        embeddings: List of embedding results
        model: Model used
        usage: Token usage information
        object_type: Object type (always "list")
    """

    embeddings: list[Embedding] = field(default_factory=list)
    model: str = ""
    usage: EmbeddingUsage = field(default_factory=EmbeddingUsage)
    object_type: str = "list"

    @property
    def first(self) -> Embedding | None:
        """Get the first embedding."""
        return self.embeddings[0] if self.embeddings else None

    @property
    def vectors(self) -> list[list[float]]:
        """Get all vectors."""
        return [e.vector for e in self.embeddings]

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "object": self.object_type,
            "data": [e.to_dict() for e in self.embeddings],
            "model": self.model,
            "usage": self.usage.to_dict(),
        }

    @classmethod
    def from_openai_format(cls, data: dict[str, Any]) -> EmbeddingResponse:
        """Create from OpenAI API response format.

        Args:
            data: OpenAI embeddings response

        Returns:
            EmbeddingResponse instance
        """
        embeddings = [
            Embedding.from_openai_format(e) for e in data.get("data", [])
        ]
        usage_data = data.get("usage", {})
        usage = EmbeddingUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )
        return cls(
            embeddings=embeddings,
            model=data.get("model", ""),
            usage=usage,
            object_type=data.get("object", "list"),
        )