aisberg-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aisberg/__init__.py +7 -0
- aisberg/abstract/__init__.py +0 -0
- aisberg/abstract/modules.py +57 -0
- aisberg/api/__init__.py +0 -0
- aisberg/api/async_endpoints.py +333 -0
- aisberg/api/endpoints.py +328 -0
- aisberg/async_client.py +107 -0
- aisberg/client.py +108 -0
- aisberg/config.py +17 -0
- aisberg/exceptions.py +22 -0
- aisberg/models/__init__.py +0 -0
- aisberg/models/chat.py +143 -0
- aisberg/models/collections.py +36 -0
- aisberg/models/embeddings.py +92 -0
- aisberg/models/models.py +39 -0
- aisberg/models/requests.py +11 -0
- aisberg/models/token.py +11 -0
- aisberg/models/tools.py +73 -0
- aisberg/models/workflows.py +66 -0
- aisberg/modules/__init__.py +23 -0
- aisberg/modules/chat.py +403 -0
- aisberg/modules/collections.py +117 -0
- aisberg/modules/document.py +117 -0
- aisberg/modules/embeddings.py +309 -0
- aisberg/modules/me.py +77 -0
- aisberg/modules/models.py +108 -0
- aisberg/modules/tools.py +78 -0
- aisberg/modules/workflows.py +140 -0
- aisberg/requests/__init__.py +0 -0
- aisberg/requests/async_requests.py +85 -0
- aisberg/requests/sync_requests.py +85 -0
- aisberg/utils.py +111 -0
- aisberg-0.1.0.dist-info/METADATA +212 -0
- aisberg-0.1.0.dist-info/RECORD +43 -0
- aisberg-0.1.0.dist-info/WHEEL +5 -0
- aisberg-0.1.0.dist-info/licenses/LICENSE +9 -0
- aisberg-0.1.0.dist-info/top_level.txt +3 -0
- tests/integration/test_collections_integration.py +115 -0
- tests/unit/test_collections_sync.py +104 -0
- tmp/test.py +33 -0
- tmp/test_async.py +126 -0
- tmp/test_doc_parse.py +12 -0
- tmp/test_sync.py +146 -0
aisberg/abstract/__init__.py
ADDED
File without changes

aisberg/abstract/modules.py
ADDED
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from abc import ABC
+from typing import TYPE_CHECKING, Union, Optional
+
+if TYPE_CHECKING:
+    from ..client import AisbergClient
+    from ..async_client import AisbergAsyncClient
+    from httpx import Client as HttpClient
+    from httpx import AsyncClient as AsyncHttpClient
+
+
+class BaseModule(ABC):
+    """Abstract base class for modules in the Aisberg framework."""
+
+    def __init__(
+        self,
+        parent: Union["AisbergClient", "AisbergAsyncClient"],
+        http_client: Optional[Union["HttpClient", "AsyncHttpClient"]] = None,
+    ):
+        """
+        Initialize the BaseModule.
+
+        Args:
+            parent (Any): Parent client or module.
+            http_client (Any): HTTP client for making requests.
+        """
+        self._parent = parent
+        self._client = http_client
+
+
+class SyncModule(BaseModule):
+    """Abstract base class for synchronous modules in the Aisberg framework."""
+
+    def __init__(self, parent: "AisbergClient", http_client: "HttpClient"):
+        """
+        Initialize the SyncModule.
+
+        Args:
+            parent (Any): Parent client or module.
+            http_client (Any): HTTP client for making requests.
+        """
+        super().__init__(parent, http_client)
+
+
+class AsyncModule(BaseModule):
+    """Abstract base class for asynchronous modules in the Aisberg framework."""
+
+    def __init__(self, parent: "AisbergAsyncClient", http_client: "AsyncHttpClient"):
+        """
+        Initialize the AsyncModule.
+
+        Args:
+            parent (Any): Parent client or module.
+            http_client (Any): HTTP client for making requests.
+        """
+        super().__init__(parent, http_client)
aisberg/api/__init__.py
ADDED
File without changes

aisberg/api/async_endpoints.py
ADDED
@@ -0,0 +1,333 @@
+from io import BytesIO
+
+import httpx
+from ..models.chat import (
+    LanguageModelInput,
+    format_messages,
+    ChatCompletionResponse,
+    ChatCompletionChunk,
+)
+from typing import Optional, AsyncGenerator, Union, List, Any, Tuple
+
+from ..models.collections import GroupCollections, PointDetails
+from ..models.embeddings import (
+    EncodingFormat,
+    EncodingResponse,
+    ChunksDataList,
+    RerankerResponse,
+)
+from ..models.models import Model
+from ..models.token import TokenInfo
+from ..models.workflows import WorkflowDetails, Workflow
+from ..utils import parse_chat_line, WorkflowLineParser
+from ..requests.async_requests import areq, areq_stream
+from ..models.requests import AnyDict, AnyList
+
+
+async def models(client: httpx.AsyncClient) -> List[Model]:
+    """
+    Get the list of available models.
+    """
+    resp = await areq(client, "GET", "/v1/models", AnyDict)
+    data = resp.data
+    if not data or not isinstance(data, list):
+        raise ValueError("Invalid response format for models")
+    return [Model.model_validate(item) for item in data]
+
+
+async def workflows(client: httpx.AsyncClient) -> List[Workflow]:
+    """
+    Get the list of available workflows.
+    """
+    resp = await areq(client, "GET", "/workflow/light", AnyList)
+    return [Workflow.model_validate(item) for item in resp.root]
+
+
+async def workflow(client: httpx.AsyncClient, workflow_id: str) -> WorkflowDetails:
+    """
+    Get details of a specific workflow.
+    """
+    try:
+        resp = await areq(
+            client, "GET", f"/workflow/details/{workflow_id}", WorkflowDetails
+        )
+        return resp
+    except httpx.HTTPStatusError as e:
+        if e.response.status_code == 404:
+            raise ValueError(f"Workflow with ID {workflow_id} not found")
+        raise e
+
+
+async def collections(client: httpx.AsyncClient) -> List[GroupCollections]:
+    """
+    Get the list of available collections.
+    """
+    resp = await areq(client, "GET", "/collections", AnyList)
+    return [GroupCollections.model_validate(item) for item in resp.root]
+
+
+async def collection(
+    client: httpx.AsyncClient, collection_id: str, group_id: str
+) -> List[PointDetails]:
+    """
+    Get details of a specific collection.
+    """
+    try:
+        resp = await areq(
+            client, "GET", f"/collections/{collection_id}/{group_id}", AnyList
+        )
+        return [PointDetails.model_validate(item) for item in resp.root]
+    except httpx.HTTPStatusError as e:
+        if e.response.status_code == 404:
+            raise ValueError(
+                f"Collection with ID {collection_id} not found in group {group_id}"
+            )
+        raise e
+
+
+async def me(client: httpx.AsyncClient) -> TokenInfo:
+    """
+    Get the details of the current user.
+    """
+    return await areq(client, "GET", "/users/me", TokenInfo)
+
+
+async def chat(
+    client: httpx.AsyncClient,
+    input: LanguageModelInput,
+    model: Optional[str] = None,
+    temperature: float = 0.7,
+    tools: Optional[list] = None,
+    group: Optional[str] = None,
+    **kwargs,
+) -> ChatCompletionResponse:
+    """
+    Send a chat message and get a response from an LLM endpoint.
+    """
+    if model is None:
+        raise ValueError("Model must be specified")
+
+    formatted_messages = format_messages(input)
+
+    payload = {
+        "model": model,
+        "messages": formatted_messages,
+        "temperature": temperature,
+        "stream": False,
+        **kwargs,
+    }
+
+    if group is not None:
+        payload["group"] = group
+
+    if tools is not None:
+        payload["tools"] = tools
+
+    return await areq(
+        client,
+        "POST",
+        "/v1/chat/completions",
+        ChatCompletionResponse,
+        json=payload,
+    )
+
+
+async def chat_stream(
+    client: httpx.AsyncClient,
+    input: LanguageModelInput,
+    model: str,
+    temperature: float = 0.7,
+    full_chunk: bool = True,
+    group: Optional[str] = None,
+    **kwargs,
+) -> AsyncGenerator[Union[str, ChatCompletionChunk], None]:
+    """
+    Stream OpenAI-style chat completions.
+    - If `full_chunk` is True (the default): each yield is the complete JSON chunk.
+    - Otherwise: backward compatibility is preserved and only delta.content (plus markers) is yielded.
+    """
+    formatted_messages = format_messages(input)
+
+    payload = {
+        "model": model,
+        "messages": formatted_messages,
+        "temperature": temperature,
+        "stream": True,
+        **kwargs,
+    }
+
+    if group is not None:
+        payload["group"] = group
+
+    async for chunk in areq_stream(
+        client,
+        "POST",
+        "/v1/chat/completions",
+        parse_line=lambda line: parse_chat_line(line, full_chunk=full_chunk),
+        json=payload,
+    ):
+        # Skip lines the parser returned as None before validating.
+        if chunk is None:
+            continue
+
+        data = ChatCompletionChunk.model_validate(chunk)
+
+        if full_chunk:
+            yield data
+        else:
+            yield data.choices[0].delta.content if data.choices else ""
+
+
+async def embeddings(
+    client: httpx.AsyncClient,
+    input: str,
+    model: str,
+    encoding_format: EncodingFormat,
+    normalize: bool,
+    group: Optional[str] = None,
+    **kwargs,
+) -> EncodingResponse:
+    """
+    Get embeddings for a given input using the specified model.
+    """
+    payload = {
+        "model": model,
+        "input": input,
+        "encoding_format": encoding_format,
+        "normalize": normalize,
+        **kwargs,
+    }
+
+    if group is not None:
+        payload["group"] = group
+
+    return await areq(
+        client,
+        "POST",
+        "/v1/embeddings",
+        EncodingResponse,
+        json=payload,
+    )
+
+
+async def retrieve(
+    client: httpx.AsyncClient,
+    query: str,
+    collections_names: List[str],
+    limit: int,
+    score_threshold: float,
+    filters: list,
+    beta: float,
+    group: Optional[str] = None,
+    **kwargs,
+) -> ChunksDataList:
+    """
+    Retrieve the most relevant documents based on the given query from specified collections.
+    """
+    data = {
+        "query": query,
+        "collections_names": collections_names,
+        "limit": limit,
+        "score": score_threshold,
+        "filters": filters,
+        "beta": beta,
+        **kwargs,
+    }
+
+    if group is not None:
+        data["group"] = group
+
+    return await areq(
+        client,
+        "POST",
+        "/collections/run/search",
+        ChunksDataList,
+        json=data,
+    )
+
+
+async def rerank(
+    client: httpx.AsyncClient,
+    query: str,
+    documents: List[str],
+    model: str,
+    top_n: int,
+    return_documents: bool,
+    group: Optional[str] = None,
+    **kwargs,
+) -> RerankerResponse:
+    """
+    Rerank a list of documents based on their relevance to a given query using the specified model.
+    """
+    payload = {
+        "query": query,
+        "documents": documents,
+        "model": model,
+        "top_n": top_n,
+        "return_documents": return_documents,
+        **kwargs,
+    }
+
+    if group is not None:
+        payload["group"] = group
+
+    return await areq(
+        client,
+        "POST",
+        "/v1/rerank",
+        RerankerResponse,
+        json=payload,
+    )
+
+
+async def run_workflow(
+    client: httpx.AsyncClient,
+    workflow_id: str,
+    data: dict,
+) -> AsyncGenerator[Any, None]:
+    """
+    Run a specific workflow with the provided data.
+    """
+    try:
+        parser = WorkflowLineParser()
+        async for chunk in areq_stream(
+            client,
+            "POST",
+            f"/workflow/run/{workflow_id}",
+            parse_line=parser,
+            json=data,
+        ):
+            yield chunk
+    except httpx.HTTPStatusError as e:
+        if e.response.status_code == 404:
+            raise ValueError(f"Workflow with ID {workflow_id} not found")
+        raise e
+
+
+async def parse_document(
+    client: httpx.AsyncClient,
+    file: Tuple[bytes, str],
+    source: str,
+    group: Optional[str] = None,
+) -> AnyDict:
+    """
+    Parse a document via the document parser endpoint.
+    """
+    payload = {
+        "source": source,
+    }
+
+    if group is not None:
+        payload["group"] = group
+
+    files = {"file": (file[1], BytesIO(file[0]), "application/octet-stream")}
+
+    # Form fields must travel as multipart data alongside the file;
+    # httpx silently drops `json=` when `files=` is present.
+    response = await areq(
+        client,
+        "POST",
+        "/document-parser/parsing/parse",
+        AnyDict,
+        files=files,
+        data=payload,
+    )
+    return response