aisberg 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aisberg/api/async_endpoints.py +138 -20
- aisberg/api/endpoints.py +136 -18
- aisberg/async_client.py +8 -0
- aisberg/client.py +8 -0
- aisberg/config.py +6 -0
- aisberg/models/collections.py +15 -1
- aisberg/models/documents.py +46 -0
- aisberg/models/requests.py +5 -1
- aisberg/modules/__init__.py +5 -0
- aisberg/modules/chat.py +11 -3
- aisberg/modules/collections.py +360 -7
- aisberg/modules/documents.py +168 -0
- aisberg/modules/embeddings.py +11 -3
- aisberg/modules/me.py +1 -1
- aisberg/modules/models.py +3 -3
- aisberg/modules/s3.py +316 -0
- aisberg/modules/workflows.py +3 -3
- {aisberg-0.1.0.dist-info → aisberg-0.2.0.dist-info}/METADATA +16 -3
- {aisberg-0.1.0.dist-info → aisberg-0.2.0.dist-info}/RECORD +24 -21
- tmp/test_collection.py +65 -0
- tmp/test_doc_parse.py +31 -7
- aisberg/modules/document.py +0 -117
- {aisberg-0.1.0.dist-info → aisberg-0.2.0.dist-info}/WHEEL +0 -0
- {aisberg-0.1.0.dist-info → aisberg-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {aisberg-0.1.0.dist-info → aisberg-0.2.0.dist-info}/top_level.txt +0 -0
aisberg/modules/collections.py
CHANGED
@@ -1,11 +1,19 @@
-from typing import List
+from typing import List, Union, Optional
 from abc import ABC
 
-
+import json
+from ..models.collections import (
+    GroupCollections,
+    Collection,
+    CollectionDetails,
+    CollectionDataset,
+)
 
 from abc import abstractmethod
 from ..abstract.modules import SyncModule, AsyncModule
 from ..api import endpoints, async_endpoints
+from ..models.requests import HttpxFileField
+from io import BytesIO
 
 
 class AbstractCollectionsModule(ABC):
@@ -25,7 +33,7 @@ class AbstractCollectionsModule(ABC):
             ValueError: If no collections are found.
             Exception: If there is an error fetching the collections.
         """
-
+        ...
 
     @abstractmethod
     def get_by_group(self, group_id: str) -> List[Collection]:
@@ -42,7 +50,7 @@ class AbstractCollectionsModule(ABC):
             ValueError: If no collections are found for the specified group ID.
             Exception: If there is an error fetching the collections.
         """
-
+        ...
 
     @abstractmethod
     def details(self, collection_id: str, group_id: str) -> CollectionDetails:
@@ -59,7 +67,126 @@ class AbstractCollectionsModule(ABC):
         Raises:
             ValueError: If the specified collection is not found.
         """
-
+        ...
+
+    @abstractmethod
+    def delete(self, name: str, **kwargs) -> bool:
+        """
+        Delete a collection by name and group ID.
+
+        Args:
+            name (str): The name of the collection to delete.
+            **kwargs: Additional keyword arguments, such as group ID.
+
+        Returns:
+            bool: True if the deletion was successful, False otherwise.
+
+        Raises:
+            ValueError: If the collection could not be deleted.
+            Exception: If there is an error during the deletion process.
+        """
+        ...
+
+    @abstractmethod
+    def create(
+        self,
+        name: str,
+        data: Union[dict, CollectionDataset, str],
+        embedding_model: Optional[str] = "BAAI/bge-m3",
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        """
+        Create a new collection.
+
+        Args:
+            name (str): The name of the collection to create.
+            data (Union[dict, CollectionDataset, str]): The data to insert into the collection.
+                Can be a dict, a CollectionDataset object, or a string representing the file path.
+            embedding_model (Optional[str]): The embedding model to use for the collection.
+                Defaults to "BAAI/bge-m3".
+            normalize (bool): Whether to normalize the data before inserting it into the collection. Defaults to False.
+            **kwargs: Additional keyword arguments, such as group ID.
+
+        Returns:
+            CollectionDetails: The details of the created collection.
+
+        Raises:
+            ValueError: If the collection could not be created.
+            Exception: If there is an error during the creation process.
+        """
+        ...
+
+    @abstractmethod
+    def insert_points(
+        self,
+        collection_name: str,
+        data: Union[dict, CollectionDataset, str],
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        """
+        Insert points into an existing collection. Existing points in the collection are not deleted.
+        This method is used to add new data to an existing collection without removing the previous data.
+
+        Args:
+            collection_name (str): The name of the collection to insert points into.
+            data (Union[dict, CollectionDataset, str]): The data to insert into the collection.
+                Can be a dict, a CollectionDataset object, or a string representing the file path.
+            normalize (bool): If the collection already has points, this parameter is ignored. Defaults to False.
+            **kwargs: Additional keyword arguments, such as group ID.
+
+        Returns:
+            CollectionDetails: The details of the updated collection.
+
+        Raises:
+            ValueError: If the points could not be inserted.
+            Exception: If there is an error during the insertion process.
+        """
+        ...
+
+    @abstractmethod
+    def delete_points(
+        self,
+        collection_name: str,
+        points: List[str],
+        **kwargs,
+    ) -> CollectionDetails:
+        """
+        Delete points from an existing collection. Points with the specified IDs will be removed from the collection.
+
+        Args:
+            collection_name (str): The name of the collection to delete points from.
+            points (List[str]): The list of point IDs to delete from the collection.
+            **kwargs: Additional keyword arguments, such as group ID.
+
+        Returns:
+            CollectionDetails: The details of the updated collection.
+
+        Raises:
+            ValueError: If the points could not be deleted.
+            Exception: If there is an error during the deletion process.
+        """
+        ...
+
+    @abstractmethod
+    def clear(
+        self,
+        collection_name: str,
+        **kwargs,
+    ) -> CollectionDetails:
+        """
+        Delete ALL points from an existing collection. All points are removed, but the collection
+        itself is not deleted, so you can still insert new points without creating a new collection.
+
+        Args:
+            collection_name (str): The name of the collection to clear.
+            **kwargs: Additional keyword arguments, such as group ID.
+
+        Returns:
+            CollectionDetails: The details of the cleared collection.
+
+        Raises:
+            ValueError: If the points could not be deleted.
+            Exception: If there is an error during the deletion process.
+        """
+        ...
 
     @staticmethod
     def _get_collections_by_group(
@@ -70,6 +197,43 @@ class AbstractCollectionsModule(ABC):
                 return group.collections
         raise ValueError("No collections found for group ID")
 
+    @staticmethod
+    def _data_to_httpx_file(
+        data: Union[dict, CollectionDataset, str],
+    ) -> HttpxFileField:
+        """
+        Prepare a JSON payload as an HTTPX file field (for multipart upload).
+
+        Args:
+            data (dict | CollectionDataset | str): The dataset as dict/obj or a path to a JSON file.
+
+        Returns:
+            HttpxFileField: List suitable for HTTPX multipart upload.
+        """
+        if isinstance(data, str):
+            with open(data, "r", encoding="utf-8") as f:
+                coll_dict = json.load(f)
+            filename = data.split("/")[-1]
+        elif isinstance(data, CollectionDataset):
+            coll_dict = data.model_dump()
+            filename = "collection.json"
+        elif isinstance(data, dict):
+            if "chunks" in data and "metadata" in data:
+                coll_dict = data
+                filename = "collection.json"
+            else:
+                raise ValueError(
+                    "data must be a dict with 'chunks' and 'metadata' keys"
+                )
+        else:
+            raise ValueError(
+                "data must be a dict, CollectionDataset, or file path string"
+            )
+
+        json_bytes = json.dumps(coll_dict, ensure_ascii=False).encode("utf-8")
+        file_tuple = ("files", (filename, BytesIO(json_bytes), "application/json"))
+        return [file_tuple]
+
 
 class SyncCollectionsModule(SyncModule, AbstractCollectionsModule):
     def __init__(self, parent, client):
@@ -83,7 +247,9 @@ class SyncCollectionsModule(SyncModule, AbstractCollectionsModule):
         collections = self.list()
         return self._get_collections_by_group(collections, group_id)
 
-    def details(self, collection_id: str, group_id: str) -> CollectionDetails:
+    def details(
+        self, collection_id: str, group_id: Optional[str] = None
+    ) -> CollectionDetails:
         points = endpoints.collection(self._client, collection_id, group_id)
         if points is None:
             raise ValueError("No collection found")
@@ -93,6 +259,98 @@ class SyncCollectionsModule(SyncModule, AbstractCollectionsModule):
             points=points,
         )
 
+    def delete(self, name: str, **kwargs) -> bool:
+        response = endpoints.delete_collection(self._client, name, **kwargs)
+        if response is None:
+            raise ValueError("Collection could not be deleted")
+        return True
+
+    def create(
+        self,
+        name: str,
+        data: Union[dict, CollectionDataset, str],
+        embedding_model: Optional[str] = "BAAI/bge-m3",
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        create = endpoints.create_collection(
+            self._client, name, embedding_model, **kwargs
+        )
+        if create.message != "Creation started":
+            raise ValueError("Collection could not be created")
+
+        insert = endpoints.insert_points_in_collection(
+            self._client,
+            name,
+            self._data_to_httpx_file(data),
+            normalize,
+            **kwargs,
+        )
+        if insert.message != f"Documents inserted in {name}.":
+            raise ValueError("Points could not be inserted into the collection")
+
+        return self.details(name, kwargs.get("group", None))
+
+    def insert_points(
+        self,
+        collection_name: str,
+        data: Union[dict, CollectionDataset, str],
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        insert = endpoints.insert_points_in_collection(
+            self._client,
+            collection_name,
+            self._data_to_httpx_file(data),
+            normalize,
+        )
+        if insert.message != f"Documents inserted in {collection_name}.":
+            raise ValueError(
+                f"Points could not be inserted into the collection: {insert.model_dump_json()}"
+            )
+        return self.details(collection_name, kwargs.get("group", None))
+
+    def delete_points(
+        self,
+        collection_name: str,
+        points: List[str],
+        **kwargs,
+    ) -> CollectionDetails:
+        delete = endpoints.delete_points_in_collection(
+            self._client,
+            points,
+            collection_name,
+            **kwargs,
+        )
+        if (
+            f'{len(points)} points deleted from collection "{collection_name}"'
+            not in delete.message
+        ):
+            raise ValueError(
+                f"Points could not be deleted from the collection: {delete.model_dump_json()}"
+            )
+
+        return self.details(collection_name, kwargs.get("group", None))
+
+    def clear(
+        self,
+        collection_name: str,
+        **kwargs,
+    ) -> CollectionDetails:
+        clear = endpoints.delete_all_points_in_collection(
+            self._client,
+            collection_name,
+            **kwargs,
+        )
+        if (
+            f'All points deleted from collection "{collection_name}" for group'
+            not in clear.message
+        ):
+            raise ValueError(
+                f"Points could not be deleted from the collection: {clear.model_dump_json()}"
+            )
+        return self.details(collection_name, kwargs.get("group", None))
+
 
 class AsyncCollectionsModule(AsyncModule, AbstractCollectionsModule):
     def __init__(self, parent, client):
@@ -106,7 +364,9 @@ class AsyncCollectionsModule(AsyncModule, AbstractCollectionsModule):
         collections = await self.list()
        return self._get_collections_by_group(collections, group_id)
 
-    async def details(self, collection_id: str, group_id: str) -> CollectionDetails:
+    async def details(
+        self, collection_id: str, group_id: Optional[str] = None
+    ) -> CollectionDetails:
         points = await async_endpoints.collection(self._client, collection_id, group_id)
         if points is None:
             raise ValueError("No collection found")
@@ -115,3 +375,96 @@ class AsyncCollectionsModule(AsyncModule, AbstractCollectionsModule):
             group=group_id,
             points=points,
         )
+
+    async def delete(self, name: str, **kwargs) -> bool:
+        response = await async_endpoints.delete_collection(self._client, name, **kwargs)
+        if response is None:
+            raise ValueError("Collection could not be deleted")
+        return True
+
+    async def create(
+        self,
+        name: str,
+        data: Union[dict, CollectionDataset, str],
+        embedding_model: Optional[str] = "BAAI/bge-m3",
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        create = await async_endpoints.create_collection(
+            self._client, name, embedding_model, **kwargs
+        )
+        if create.message != "Creation started":
+            raise ValueError("Collection could not be created")
+
+        insert = await async_endpoints.insert_points_in_collection(
+            self._client,
+            name,
+            self._data_to_httpx_file(data),
+            normalize,
+            **kwargs,
+        )
+        if insert.message != f"Documents inserted in {name}.":
+            raise ValueError("Points could not be inserted into the collection")
+
+        return await self.details(name, kwargs.get("group", None))
+
+    async def insert_points(
+        self,
+        collection_name: str,
+        data: Union[dict, CollectionDataset, str],
+        normalize: bool = False,
+        **kwargs,
+    ) -> CollectionDetails:
+        insert = await async_endpoints.insert_points_in_collection(
+            self._client,
+            collection_name,
+            self._data_to_httpx_file(data),
+            normalize,
+        )
+        if insert.message != f"Documents inserted in {collection_name}.":
+            raise ValueError(
+                f"Points could not be inserted into the collection: {insert.model_dump_json()}"
+            )
+
+        return await self.details(collection_name, kwargs.get("group", None))
+
+    async def delete_points(
+        self,
+        collection_name: str,
+        points: List[str],
+        **kwargs,
+    ) -> CollectionDetails:
+        delete = await async_endpoints.delete_points_in_collection(
+            self._client,
+            points,
+            collection_name,
+            **kwargs,
+        )
+        if (
+            f'{len(points)} points deleted from collection "{collection_name}"'
+            not in delete.message
+        ):
+            raise ValueError(
+                f"Points could not be deleted from the collection: {delete.model_dump_json()}"
+            )
+
+        return await self.details(collection_name, kwargs.get("group", None))
+
+    async def clear(
+        self,
+        collection_name: str,
+        **kwargs,
+    ) -> CollectionDetails:
+        clear = await async_endpoints.delete_all_points_in_collection(
+            self._client,
+            collection_name,
+            **kwargs,
+        )
+        if (
+            f'All points deleted from collection "{collection_name}" for group'
+            not in clear.message
+        ):
+            raise ValueError(
+                f"Points could not be deleted from the collection: {clear.model_dump_json()}"
+            )
+        return await self.details(collection_name, kwargs.get("group", None))
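Taken together, the new methods give the collections module a full lifecycle. Below is a minimal usage sketch for the sync API, assuming the client exposes this module as `client.collections` and that `AisbergClient` is the entry-point class (neither name is shown in this diff); the method names and the required "chunks"/"metadata" keys come from the code above:

    from aisberg import AisbergClient  # assumed import path; not shown in this diff

    client = AisbergClient()

    # _data_to_httpx_file requires dict payloads to carry "chunks" and "metadata"
    # keys; their exact schemas are presumably those of CollectionDataset in
    # aisberg/models/collections.py.
    dataset = {
        "chunks": ["Aisberg manages vector collections.", "Collections hold points."],
        "metadata": [{"source": "intro.md"}, {"source": "intro.md"}],
    }

    # Create the collection and upload the initial points in one call.
    details = client.collections.create("docs", dataset, normalize=True)

    # Append more points later; existing points are kept. A JSON file path works too.
    details = client.collections.insert_points("docs", "extra_points.json")

    # Remove specific points by ID, or wipe the collection while keeping it usable.
    client.collections.delete_points("docs", points=["point-id-1"])  # hypothetical ID
    client.collections.clear("docs")

    # Drop the collection entirely.
    client.collections.delete("docs")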
aisberg/modules/documents.py
ADDED
@@ -0,0 +1,168 @@
+from abc import ABC, abstractmethod
+from ..api import endpoints, async_endpoints
+from ..abstract.modules import SyncModule, AsyncModule
+from ..models.documents import (
+    FileObject,
+    DocumentParserFileInput,
+    ParsedDocument,
+)
+from typing import List
+import json
+from io import BytesIO
+import logging
+
+from ..models.requests import HttpxFileField
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractDocumentsModule(ABC):
+    def __init__(self, parent, client):
+        self._parent = parent
+        self._client = client
+
+    @abstractmethod
+    def parse(
+        self, files: DocumentParserFileInput, **kwargs
+    ) -> List[ParsedDocument]: ...
+
+    def _get_parsed_files_from_s3(
+        self, files: List[str], bucket_name: str
+    ) -> List[ParsedDocument]:
+        """
+        Download and parse a list of files from an S3 bucket.
+
+        Args:
+            files (List[str]): List of file names to download from S3.
+            bucket_name (str): Name of the S3 bucket.
+
+        Returns:
+            List[ParsedDocument]: Parsed documents as objects with content and metadata.
+
+        Raises:
+            Exception: If a file cannot be downloaded or parsed.
+        """
+        parsed_documents = []
+        for file_name in files:
+            if not file_name.endswith(".json"):
+                if '"type": "error"' in file_name:
+                    logger.error(f"[DOCUMENT PARSER] Parsing failed => {file_name}.")
+                continue
+
+            logger.debug(f"Downloading file {file_name} from bucket {bucket_name}")
+            # Download the file as a BytesIO
+            doc_bytesio = self._parent._s3.download_file(bucket_name, file_name)
+            try:
+                buffer = doc_bytesio.getvalue()
+                content_str = buffer.decode("utf-8")
+                content_json = json.loads(content_str)
+            finally:
+                doc_bytesio.close()
+            file_object = FileObject(name=file_name, buffer=buffer)
+            parsed_documents.append(
+                ParsedDocument(
+                    content=content_json, metadata={"name": file_object.name}
+                )
+            )
+        return parsed_documents
+
+    @staticmethod
+    def _prepare_files_payload(
+        files: DocumentParserFileInput,
+    ) -> HttpxFileField:
+        """
+        Prepares input files into a format compatible with HTTPX multipart uploads.
+
+        Args:
+            files (DocumentParserFileInput): Files to upload (see type for options).
+
+        Returns:
+            HttpxFileField: HTTPX-style list for multipart upload.
+
+        Raises:
+            TypeError: On unsupported type.
+        """
+
+        def to_file_tuple(item):
+            # FileObject case
+            if "FileObject" in globals() and isinstance(item, FileObject):
+                content = item.buffer
+                filename = item.name
+            # (bytes, filename) tuple
+            elif isinstance(item, tuple) and len(item) == 2:
+                content, filename = item
+            # bytes or BytesIO
+            elif isinstance(item, (bytes, BytesIO)):
+                content = item
+                filename = "file"
+            # str (filepath)
+            elif isinstance(item, str):
+                with open(item, "rb") as f:
+                    content = f.read()
+                filename = item.split("/")[-1]
+            else:
+                raise TypeError(
+                    f"Unsupported file input type: {type(item)}. "
+                    "Expected str, bytes, BytesIO, tuple, or FileObject."
+                )
+            # Normalize to BytesIO for HTTPX
+            if isinstance(content, bytes):
+                content = BytesIO(content)
+            elif isinstance(content, BytesIO):
+                content.seek(0)
+            else:
+                raise TypeError(
+                    f"File content must be bytes or BytesIO, got {type(content)}"
+                )
+            return (filename, content)
+
+        if isinstance(files, list):
+            if len(files) == 0:
+                raise ValueError("File list cannot be empty.")
+            elif len(files) > 10:
+                raise ValueError("Too many files provided. Maximum is 10.")
+
+            normalized = [to_file_tuple(f) for f in files]
+        else:
+            normalized = [to_file_tuple(files)]
+
+        # HTTPX format: [("files", (filename, fileobj, mimetype)), ...]
+        httpx_files = [
+            ("files", (filename, content, "application/octet-stream"))
+            for filename, content in normalized
+        ]
+        return httpx_files
+
+
+class SyncDocumentsModule(SyncModule, AbstractDocumentsModule):
+    def __init__(self, parent, client):
+        SyncModule.__init__(self, parent, client)
+        AbstractDocumentsModule.__init__(self, parent, client)
+
+    def parse(self, files, **kwargs) -> List[ParsedDocument]:
+        output = endpoints.parse_documents(
+            self._client,
+            self._prepare_files_payload(files),
+            **kwargs,
+        )
+        if output.message == "Files parsed successfully":
+            return self._get_parsed_files_from_s3(output.parsedFiles, output.bucketName)
+        else:
+            raise ValueError(f"Error parsing files: {output.message}")
+
+
+class AsyncDocumentsModule(AsyncModule, AbstractDocumentsModule):
+    def __init__(self, parent, client):
+        AsyncModule.__init__(self, parent, client)
+        AbstractDocumentsModule.__init__(self, parent, client)
+
+    async def parse(self, files, **kwargs) -> List[ParsedDocument]:
+        output = await async_endpoints.parse_documents(
+            self._client,
+            self._prepare_files_payload(files),
+            **kwargs,
+        )
+        if output.message == "Files parsed successfully":
+            return self._get_parsed_files_from_s3(output.parsedFiles, output.bucketName)
+        else:
+            raise ValueError(f"Error parsing files: {output.message}")
aisberg/modules/embeddings.py
CHANGED
@@ -50,7 +50,7 @@ class AbstractEmbeddingsModule(ABC):
         Returns:
             EncodingResponse: The response containing the encoded embeddings.
         """
-
+        ...
 
     @abstractmethod
     def retrieve(
@@ -75,7 +75,7 @@ class AbstractEmbeddingsModule(ABC):
         Returns:
             List[ChunkData]: A list of ChunkData objects containing the retrieved texts and their metadata.
         """
-
+        ...
 
     @abstractmethod
     def rerank(
@@ -104,7 +104,7 @@ class AbstractEmbeddingsModule(ABC):
             ValueError: If the documents list is empty or contains invalid document types.
             Exception: If the documents list is not of the expected type.
         """
-
+        ...
 
     @staticmethod
     def _format_collections_names(
@@ -192,6 +192,7 @@ class SyncEmbeddingsModule(AbstractEmbeddingsModule, SyncModule):
         score_threshold: float = 0.0,
         filters: List = None,
         beta: float = 0.7,
+        **kwargs,
     ) -> ChunksDataList:
         if filters is None:
             filters = []
@@ -204,6 +205,7 @@ class SyncEmbeddingsModule(AbstractEmbeddingsModule, SyncModule):
             score_threshold=score_threshold,
             filters=filters,
             beta=beta,
+            **kwargs,
         )
         return ChunksDataList.model_validate(resp)
 
@@ -215,6 +217,7 @@ class SyncEmbeddingsModule(AbstractEmbeddingsModule, SyncModule):
         top_n: int = 10,
         return_documents: bool = True,
         threshold: Optional[float] = None,
+        **kwargs,
     ) -> RerankerResponse:
         resp = endpoints.rerank(
             self._client,
@@ -223,6 +226,7 @@ class SyncEmbeddingsModule(AbstractEmbeddingsModule, SyncModule):
             model,
             top_n,
             return_documents,
+            **kwargs,
         )
         resp = RerankerResponse.model_validate(resp)
 
@@ -269,6 +273,7 @@ class AsyncEmbeddingsModule(AbstractEmbeddingsModule, AsyncModule):
         score_threshold: float = 0.0,
         filters: List = None,
         beta: float = 0.7,
+        **kwargs,
     ) -> ChunksDataList:
         if filters is None:
             filters = []
@@ -281,6 +286,7 @@ class AsyncEmbeddingsModule(AbstractEmbeddingsModule, AsyncModule):
             score_threshold=score_threshold,
             filters=filters,
             beta=beta,
+            **kwargs,
         )
         return ChunksDataList.model_validate(resp)
 
@@ -292,6 +298,7 @@ class AsyncEmbeddingsModule(AbstractEmbeddingsModule, AsyncModule):
         top_n: int = 10,
         return_documents: bool = True,
         threshold: Optional[float] = None,
+        **kwargs,
     ) -> RerankerResponse:
         resp = await async_endpoints.rerank(
             self._client,
@@ -300,6 +307,7 @@ class AsyncEmbeddingsModule(AbstractEmbeddingsModule, AsyncModule):
             model,
             top_n,
             return_documents,
+            **kwargs,
         )
         resp = RerankerResponse.model_validate(resp)
 
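The embeddings changes are purely additive: `retrieve` and `rerank` now forward `**kwargs` to the underlying endpoint calls. A sketch of what that enables, assuming a pass-through keyword such as `group` is understood downstream; the diff does not name the accepted keywords, and the positional parameters shown here are assumptions:

    chunks = client.embeddings.retrieve(
        "what is aisberg?",       # query (assumed positional parameter)
        ["docs"],                 # collection names (assumed positional parameter)
        score_threshold=0.2,
        beta=0.7,
        group="my-team",          # forwarded via the new **kwargs pass-through
    )

    reranked = client.embeddings.rerank(
        "what is aisberg?",           # query (assumed)
        [c.content for c in chunks],  # candidate documents (assumed field name)
        top_n=5,
        return_documents=True,
        group="my-team",              # forwarded via the new **kwargs pass-through
    )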