aisberg 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. aisberg/__init__.py +7 -0
  2. aisberg/abstract/__init__.py +0 -0
  3. aisberg/abstract/modules.py +57 -0
  4. aisberg/api/__init__.py +0 -0
  5. aisberg/api/async_endpoints.py +333 -0
  6. aisberg/api/endpoints.py +328 -0
  7. aisberg/async_client.py +107 -0
  8. aisberg/client.py +108 -0
  9. aisberg/config.py +17 -0
  10. aisberg/exceptions.py +22 -0
  11. aisberg/models/__init__.py +0 -0
  12. aisberg/models/chat.py +143 -0
  13. aisberg/models/collections.py +36 -0
  14. aisberg/models/embeddings.py +92 -0
  15. aisberg/models/models.py +39 -0
  16. aisberg/models/requests.py +11 -0
  17. aisberg/models/token.py +11 -0
  18. aisberg/models/tools.py +73 -0
  19. aisberg/models/workflows.py +66 -0
  20. aisberg/modules/__init__.py +23 -0
  21. aisberg/modules/chat.py +403 -0
  22. aisberg/modules/collections.py +117 -0
  23. aisberg/modules/document.py +117 -0
  24. aisberg/modules/embeddings.py +309 -0
  25. aisberg/modules/me.py +77 -0
  26. aisberg/modules/models.py +108 -0
  27. aisberg/modules/tools.py +78 -0
  28. aisberg/modules/workflows.py +140 -0
  29. aisberg/requests/__init__.py +0 -0
  30. aisberg/requests/async_requests.py +85 -0
  31. aisberg/requests/sync_requests.py +85 -0
  32. aisberg/utils.py +111 -0
  33. aisberg-0.1.0.dist-info/METADATA +212 -0
  34. aisberg-0.1.0.dist-info/RECORD +43 -0
  35. aisberg-0.1.0.dist-info/WHEEL +5 -0
  36. aisberg-0.1.0.dist-info/licenses/LICENSE +9 -0
  37. aisberg-0.1.0.dist-info/top_level.txt +3 -0
  38. tests/integration/test_collections_integration.py +115 -0
  39. tests/unit/test_collections_sync.py +104 -0
  40. tmp/test.py +33 -0
  41. tmp/test_async.py +126 -0
  42. tmp/test_doc_parse.py +12 -0
  43. tmp/test_sync.py +146 -0
aisberg/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .client import AisbergClient
2
+ from .async_client import AisbergAsyncClient
3
+
4
+ __all__ = [
5
+ "AisbergClient",
6
+ "AisbergAsyncClient",
7
+ ]
File without changes
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from typing import TYPE_CHECKING, Union, Optional
5
+
6
+ if TYPE_CHECKING:
7
+ from ..client import AisbergClient
8
+ from ..async_client import AisbergAsyncClient
9
+ from httpx import Client as HttpClient
10
+ from httpx import AsyncClient as AsyncHttpClient
11
+
12
+
13
class BaseModule(ABC):
    """Common ancestor of all Aisberg framework modules.

    Holds a reference to the owning client and the HTTP client used to
    perform requests.
    """

    def __init__(
        self,
        parent: Union["AisbergClient", "AisbergAsyncClient"],
        http_client: Optional[Union["HttpClient", "AsyncHttpClient"]] = None,
    ):
        """
        Initialize the BaseModule.

        Args:
            parent: Owning client instance (sync or async).
            http_client: HTTP client used for requests, if any.
        """
        # Keep both references private; subclasses access them directly.
        self._parent = parent
        self._client = http_client
30
+
31
+
32
class SyncModule(BaseModule):
    """Base class for synchronous modules in the Aisberg framework."""

    def __init__(self, parent: "AisbergClient", http_client: "HttpClient"):
        """
        Initialize the SyncModule.

        Args:
            parent: Owning synchronous client.
            http_client: Synchronous HTTP client used for requests.
        """
        super().__init__(parent, http_client)
44
+
45
+
46
class AsyncModule(BaseModule):
    """Base class for asynchronous modules in the Aisberg framework."""

    def __init__(self, parent: "AisbergAsyncClient", http_client: "AsyncHttpClient"):
        """
        Initialize the AsyncModule.

        Args:
            parent: Owning asynchronous client.
            http_client: Asynchronous HTTP client used for requests.
        """
        super().__init__(parent, http_client)
File without changes
@@ -0,0 +1,333 @@
1
+ from io import BytesIO
2
+
3
+ import httpx
4
+ from ..models.chat import (
5
+ LanguageModelInput,
6
+ format_messages,
7
+ ChatCompletionResponse,
8
+ ChatCompletionChunk,
9
+ )
10
+ from typing import Optional, AsyncGenerator, Union, List, Any, Tuple
11
+
12
+ from ..models.collections import GroupCollections, PointDetails
13
+ from ..models.embeddings import (
14
+ EncodingFormat,
15
+ EncodingResponse,
16
+ ChunksDataList,
17
+ RerankerResponse,
18
+ )
19
+ from ..models.models import Model
20
+ from ..models.token import TokenInfo
21
+ from ..models.workflows import WorkflowDetails, Workflow
22
+ from ..utils import parse_chat_line, WorkflowLineParser
23
+ from ..requests.async_requests import areq, areq_stream
24
+ from ..models.requests import AnyDict, AnyList
25
+
26
+
27
async def models(client: httpx.AsyncClient) -> List[Model]:
    """
    Fetch the list of available models.

    Raises:
        ValueError: If the response payload is missing or not a list.
    """
    resp = await areq(client, "GET", "/v1/models", AnyDict)
    items = resp.data
    # Reject a missing, empty, or non-list payload before validating entries.
    if not isinstance(items, list) or not items:
        raise ValueError("Invalid response format for models")
    return [Model.model_validate(entry) for entry in items]
36
+
37
+
38
async def workflows(client: httpx.AsyncClient) -> List[Workflow]:
    """
    Fetch the list of available workflows.
    """
    listing = await areq(client, "GET", "/workflow/light", AnyList)
    return [Workflow.model_validate(entry) for entry in listing.root]
44
+
45
+
46
async def workflow(client: httpx.AsyncClient, workflow_id: str) -> WorkflowDetails:
    """
    Fetch the details of a single workflow.

    Args:
        client: Async HTTP client bound to the API base URL.
        workflow_id: Identifier of the workflow to look up.

    Raises:
        ValueError: If no workflow exists with the given ID (HTTP 404).
        httpx.HTTPStatusError: For any other HTTP error.
    """
    try:
        return await areq(
            client, "GET", f"/workflow/details/{workflow_id}", WorkflowDetails
        )
    except httpx.HTTPStatusError as err:
        # Translate a 404 into a friendlier domain error; re-raise the rest.
        if err.response.status_code == 404:
            raise ValueError(f"Workflow with ID {workflow_id} not found")
        raise err
59
+
60
+
61
async def collections(client: httpx.AsyncClient) -> List[GroupCollections]:
    """
    Fetch the list of available collections, grouped by owner group.
    """
    listing = await areq(client, "GET", "/collections", AnyList)
    return [GroupCollections.model_validate(entry) for entry in listing.root]
67
+
68
+
69
async def collection(
    client: httpx.AsyncClient, collection_id: str, group_id: str
) -> List[PointDetails]:
    """
    Fetch the points of a specific collection within a group.

    Args:
        client: Async HTTP client bound to the API base URL.
        collection_id: Identifier of the collection.
        group_id: Identifier of the group owning the collection.

    Raises:
        ValueError: If the collection does not exist in the group (HTTP 404).
        httpx.HTTPStatusError: For any other HTTP error.
    """
    try:
        listing = await areq(
            client, "GET", f"/collections/{collection_id}/{group_id}", AnyList
        )
    except httpx.HTTPStatusError as err:
        # Translate a 404 into a friendlier domain error; re-raise the rest.
        if err.response.status_code == 404:
            raise ValueError(
                f"Collection with ID {collection_id} not found in group {group_id}"
            )
        raise err
    return [PointDetails.model_validate(entry) for entry in listing.root]
86
+
87
+
88
async def me(client: httpx.AsyncClient) -> TokenInfo:
    """
    Fetch the details of the current user's token.
    """
    info = await areq(client, "GET", "/users/me", TokenInfo)
    return info
93
+
94
+
95
async def chat(
    client: httpx.AsyncClient,
    input: LanguageModelInput,
    model: Optional[str] = None,
    temperature: float = 0.7,
    tools: Optional[list] = None,
    group: Optional[str] = None,
    **kwargs,
) -> ChatCompletionResponse:
    """
    Send a chat message and get a response from an LLM endpoint.

    Args:
        client: Async HTTP client bound to the API base URL.
        input: Messages in any form accepted by ``format_messages``.
        model: Model identifier. Effectively required: ``None`` raises.
        temperature: Sampling temperature (default 0.7).
        tools: Optional tool definitions forwarded in the payload.
        group: Optional group identifier added to the payload.
        **kwargs: Extra fields merged into the request payload
            (may override the defaults above).

    Returns:
        The parsed chat completion response.

    Raises:
        ValueError: If ``model`` is None.
    """
    if model is None:
        raise ValueError("Model must be specified")

    formatted_messages = format_messages(input)

    # Non-streaming request; see chat_stream for the streaming variant.
    payload = {
        "model": model,
        "messages": formatted_messages,
        "temperature": temperature,
        "stream": False,
        **kwargs,
    }

    if group is not None:
        payload["group"] = group

    if tools is not None:
        payload["tools"] = tools

    return await areq(
        client,
        "POST",
        "/v1/chat/completions",
        ChatCompletionResponse,
        json=payload,
    )
133
+
134
+
135
async def chat_stream(
    client: httpx.AsyncClient,
    input: LanguageModelInput,
    model: str,
    temperature: float = 0.7,
    full_chunk: bool = True,
    group: Optional[str] = None,
    **kwargs,
) -> AsyncGenerator[Union[str, ChatCompletionChunk], None]:
    """
    Stream OpenAI-style chat completions.

    Args:
        client: Async HTTP client bound to the API base URL.
        input: Messages in any form accepted by ``format_messages``.
        model: Model identifier.
        temperature: Sampling temperature (default 0.7).
        full_chunk: If True (default), each yield is the fully parsed chunk.
            Otherwise (backward compatibility), only the delta content is
            yielded.
        group: Optional group identifier added to the payload.
        **kwargs: Extra fields merged into the request payload.

    Yields:
        ``ChatCompletionChunk`` objects, or plain delta strings when
        ``full_chunk`` is False.
    """
    formatted_messages = format_messages(input)

    payload = {
        "model": model,
        "messages": formatted_messages,
        "temperature": temperature,
        "stream": True,
        **kwargs,
    }

    if group is not None:
        payload["group"] = group

    async for chunk in areq_stream(
        client,
        "POST",
        "/v1/chat/completions",
        parse_line=lambda line: parse_chat_line(line, full_chunk=full_chunk),
        json=payload,
    ):
        # BUGFIX: skip None chunks BEFORE validation. The original checked
        # `data is None` only after model_validate(chunk), which would have
        # already raised on a None chunk — the guard was dead code.
        if chunk is None:
            continue

        data = ChatCompletionChunk.model_validate(chunk)

        if full_chunk:
            yield data
        else:
            yield data.choices[0].delta.content if data.choices else ""
178
+
179
+
180
async def embeddings(
    client: httpx.AsyncClient,
    input: str,
    model: str,
    encoding_format: EncodingFormat,
    normalize: bool,
    group: Optional[str] = None,
    **kwargs,
) -> EncodingResponse:
    """
    Compute embeddings for the given input using the specified model.

    Args:
        client: Async HTTP client bound to the API base URL.
        input: Text to embed.
        model: Embedding model identifier.
        encoding_format: Desired encoding of the returned vectors.
        normalize: Whether the server should normalize the vectors.
        group: Optional group identifier added to the payload.
        **kwargs: Extra fields merged into the request payload
            (may override the defaults above).
    """
    # Dict literal so **kwargs may override earlier keys without error.
    body = {
        "model": model,
        "input": input,
        "encoding_format": encoding_format,
        "normalize": normalize,
        **kwargs,
    }

    if group is not None:
        body["group"] = group

    return await areq(
        client,
        "POST",
        "/v1/embeddings",
        EncodingResponse,
        json=body,
    )
210
+
211
+
212
async def retrieve(
    client: httpx.AsyncClient,
    query: str,
    collections_names: List[str],
    limit: int,
    score_threshold: float,
    filters: list,
    beta: float,
    group: Optional[str] = None,
    **kwargs,
) -> ChunksDataList:
    """
    Retrieve the most relevant documents for a query from the given collections.

    Args:
        client: Async HTTP client bound to the API base URL.
        query: Search query text.
        collections_names: Names of the collections to search.
        limit: Maximum number of results.
        score_threshold: Minimum relevance score (sent as ``score``).
        filters: Filter clauses forwarded to the search endpoint.
        beta: Search blending parameter forwarded as-is.
        group: Optional group identifier added to the payload.
        **kwargs: Extra fields merged into the request payload.
    """
    payload = {
        "query": query,
        "collections_names": collections_names,
        "limit": limit,
        # NOTE: the API field is named "score", not "score_threshold".
        "score": score_threshold,
        "filters": filters,
        "beta": beta,
        **kwargs,
    }

    if group is not None:
        payload["group"] = group

    return await areq(
        client,
        "POST",
        "/collections/run/search",
        ChunksDataList,
        json=payload,
    )
246
+
247
+
248
async def rerank(
    client: httpx.AsyncClient,
    query: str,
    documents: List[str],
    model: str,
    top_n: int,
    return_documents: bool,
    group: Optional[str] = None,
    **kwargs,
) -> RerankerResponse:
    """
    Rerank documents by relevance to a query using the specified model.

    Args:
        client: Async HTTP client bound to the API base URL.
        query: Query text the documents are scored against.
        documents: Candidate documents to rerank.
        model: Reranker model identifier.
        top_n: Number of top results to return.
        return_documents: Whether the response should echo the documents.
        group: Optional group identifier added to the payload.
        **kwargs: Extra fields merged into the request payload.
    """
    body = {
        "query": query,
        "documents": documents,
        "model": model,
        "top_n": top_n,
        "return_documents": return_documents,
        **kwargs,
    }

    if group is not None:
        body["group"] = group

    return await areq(
        client,
        "POST",
        "/v1/rerank",
        RerankerResponse,
        json=body,
    )
280
+
281
+
282
async def run_workflow(
    client: httpx.AsyncClient,
    workflow_id: str,
    data: dict,
) -> AsyncGenerator[Any, None]:
    """
    Run a specific workflow and stream its output.

    This is an async generator: it yields each parsed chunk emitted by the
    workflow stream as it arrives.

    Args:
        client: Async HTTP client bound to the API base URL.
        workflow_id: Identifier of the workflow to run.
        data: JSON payload forwarded to the workflow.

    Raises:
        ValueError: If the workflow does not exist (HTTP 404).
        httpx.HTTPStatusError: For any other HTTP error raised during the
            stream.
    """
    try:
        # A single parser instance is used for the whole stream —
        # presumably it keeps state across lines; confirm in utils.
        parser = WorkflowLineParser()
        async for chunk in areq_stream(
            client,
            "POST",
            f"/workflow/run/{workflow_id}",
            parse_line=parser,
            json=data,
        ):
            yield chunk
    except httpx.HTTPStatusError as e:
        # Translate a 404 into a friendlier domain error; re-raise the rest.
        if e.response.status_code == 404:
            raise ValueError(f"Workflow with ID {workflow_id} not found")
        raise e
304
+
305
+
306
async def parse_document(
    client: httpx.AsyncClient,
    file: Tuple[bytes, str],
    source: str,
    group: Optional[str] = None,
) -> str:
    """
    Parse a document via the document-parser endpoint.

    Args:
        client: Async HTTP client bound to the API base URL.
        file: Tuple of (raw file bytes, filename).
        source: Source identifier sent alongside the file.
        group: Optional group identifier added to the form fields.

    Returns:
        The parsed-document response.
    """
    form = {
        "source": source,
    }

    if group is not None:
        form["group"] = group

    files = {"file": (file[1], BytesIO(file[0]), "application/octet-stream")}

    # BUGFIX: the original called areq(...) without `await` inside an async
    # function, so the request coroutine was never executed and the raw
    # coroutine object was returned. It also left a debug print() in place.
    # Form fields are sent as multipart `data`: httpx does not combine a
    # `json` body with `files` in a single request.
    response = await areq(
        client,
        "POST",
        "/document-parser/parsing/parse",
        AnyDict,
        files=files,
        data=form,
    )
    return response