fabricatio 0.2.1.dev4__cp312-cp312-win_amd64.whl → 0.2.3__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +8 -0
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/capabilities/rag.py +310 -0
- fabricatio/capabilities/rating.py +79 -1
- fabricatio/config.py +52 -0
- fabricatio/core.py +33 -19
- fabricatio/models/action.py +6 -2
- fabricatio/models/generic.py +107 -1
- fabricatio/models/kwargs_types.py +23 -0
- fabricatio/models/task.py +69 -17
- fabricatio/models/usages.py +77 -70
- fabricatio/models/utils.py +50 -1
- fabricatio-0.2.3.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.1.dev4.dist-info → fabricatio-0.2.3.dist-info}/METADATA +42 -38
- {fabricatio-0.2.1.dev4.dist-info → fabricatio-0.2.3.dist-info}/RECORD +17 -16
- fabricatio-0.2.1.dev4.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.1.dev4.dist-info → fabricatio-0.2.3.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.1.dev4.dist-info → fabricatio-0.2.3.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/generic.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""This module defines generic classes for models in the Fabricatio library."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Callable, List, Self
|
4
|
+
from typing import Callable, Iterable, List, Optional, Self, Union, final
|
5
5
|
|
6
6
|
import orjson
|
7
7
|
from fabricatio._rust import blake3_hash
|
@@ -12,6 +12,11 @@ from pydantic import (
|
|
12
12
|
BaseModel,
|
13
13
|
ConfigDict,
|
14
14
|
Field,
|
15
|
+
HttpUrl,
|
16
|
+
NonNegativeFloat,
|
17
|
+
PositiveFloat,
|
18
|
+
PositiveInt,
|
19
|
+
SecretStr,
|
15
20
|
)
|
16
21
|
|
17
22
|
|
@@ -150,3 +155,104 @@ class WithDependency(Base):
|
|
150
155
|
for p in self.dependencies
|
151
156
|
},
|
152
157
|
)
|
158
|
+
|
159
|
+
|
160
|
+
class ScopedConfig(Base):
|
161
|
+
"""Class that manages a scoped configuration."""
|
162
|
+
|
163
|
+
llm_api_endpoint: Optional[HttpUrl] = None
|
164
|
+
"""The OpenAI API endpoint."""
|
165
|
+
|
166
|
+
llm_api_key: Optional[SecretStr] = None
|
167
|
+
"""The OpenAI API key."""
|
168
|
+
|
169
|
+
llm_timeout: Optional[PositiveInt] = None
|
170
|
+
"""The timeout of the LLM model."""
|
171
|
+
|
172
|
+
llm_max_retries: Optional[PositiveInt] = None
|
173
|
+
"""The maximum number of retries."""
|
174
|
+
|
175
|
+
llm_model: Optional[str] = None
|
176
|
+
"""The LLM model name."""
|
177
|
+
|
178
|
+
llm_temperature: Optional[NonNegativeFloat] = None
|
179
|
+
"""The temperature of the LLM model."""
|
180
|
+
|
181
|
+
llm_stop_sign: Optional[str | List[str]] = None
|
182
|
+
"""The stop sign of the LLM model."""
|
183
|
+
|
184
|
+
llm_top_p: Optional[NonNegativeFloat] = None
|
185
|
+
"""The top p of the LLM model."""
|
186
|
+
|
187
|
+
llm_generation_count: Optional[PositiveInt] = None
|
188
|
+
"""The number of generations to generate."""
|
189
|
+
|
190
|
+
llm_stream: Optional[bool] = None
|
191
|
+
"""Whether to stream the LLM model's response."""
|
192
|
+
|
193
|
+
llm_max_tokens: Optional[PositiveInt] = None
|
194
|
+
"""The maximum number of tokens to generate."""
|
195
|
+
|
196
|
+
embedding_api_endpoint: Optional[HttpUrl] = None
|
197
|
+
"""The OpenAI API endpoint."""
|
198
|
+
|
199
|
+
embedding_api_key: Optional[SecretStr] = None
|
200
|
+
"""The OpenAI API key."""
|
201
|
+
|
202
|
+
embedding_timeout: Optional[PositiveInt] = None
|
203
|
+
"""The timeout of the LLM model."""
|
204
|
+
|
205
|
+
embedding_model: Optional[str] = None
|
206
|
+
"""The LLM model name."""
|
207
|
+
|
208
|
+
embedding_dimensions: Optional[PositiveInt] = None
|
209
|
+
"""The dimensions of the embedding."""
|
210
|
+
embedding_caching: Optional[bool] = False
|
211
|
+
"""Whether to cache the embedding result."""
|
212
|
+
|
213
|
+
milvus_uri: Optional[HttpUrl] = Field(default=None)
|
214
|
+
"""The URI of the Milvus server."""
|
215
|
+
milvus_token: Optional[SecretStr] = Field(default=None)
|
216
|
+
"""The token for the Milvus server."""
|
217
|
+
milvus_timeout: Optional[PositiveFloat] = Field(default=None)
|
218
|
+
"""The timeout for the Milvus server."""
|
219
|
+
milvus_dimensions: Optional[PositiveInt] = Field(default=None)
|
220
|
+
"""The dimensions of the Milvus server."""
|
221
|
+
|
222
|
+
@final
|
223
|
+
def fallback_to(self, other: "ScopedConfig") -> Self:
|
224
|
+
"""Fallback to another instance's attribute values if the current instance's attributes are None.
|
225
|
+
|
226
|
+
Args:
|
227
|
+
other (LLMUsage): Another instance from which to copy attribute values.
|
228
|
+
|
229
|
+
Returns:
|
230
|
+
Self: The current instance, allowing for method chaining.
|
231
|
+
"""
|
232
|
+
# Iterate over the attribute names and copy values from 'other' to 'self' where applicable
|
233
|
+
# noinspection PydanticTypeChecker,PyTypeChecker
|
234
|
+
for attr_name in ScopedConfig.model_fields:
|
235
|
+
# Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
|
236
|
+
if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
|
237
|
+
setattr(self, attr_name, attr)
|
238
|
+
|
239
|
+
# Return the current instance to allow for method chaining
|
240
|
+
return self
|
241
|
+
|
242
|
+
@final
|
243
|
+
def hold_to(self, others: Union["ScopedConfig", Iterable["ScopedConfig"]]) -> Self:
|
244
|
+
"""Hold to another instance's attribute values if the current instance's attributes are None.
|
245
|
+
|
246
|
+
Args:
|
247
|
+
others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
|
248
|
+
|
249
|
+
Returns:
|
250
|
+
Self: The current instance, allowing for method chaining.
|
251
|
+
"""
|
252
|
+
if not isinstance(others, Iterable):
|
253
|
+
others = [others]
|
254
|
+
for other in others:
|
255
|
+
# noinspection PyTypeChecker,PydanticTypeChecker
|
256
|
+
for attr_name in ScopedConfig.model_fields:
|
257
|
+
if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
|
258
|
+
setattr(other, attr_name, attr)
|
@@ -5,6 +5,29 @@ from typing import List, NotRequired, TypedDict
|
|
5
5
|
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
|
6
6
|
|
7
7
|
|
8
|
+
class CollectionSimpleConfigKwargs(TypedDict):
|
9
|
+
"""A type representing the configuration for a collection."""
|
10
|
+
|
11
|
+
dimension: NotRequired[int]
|
12
|
+
timeout: NotRequired[float]
|
13
|
+
|
14
|
+
|
15
|
+
class FetchKwargs(TypedDict):
|
16
|
+
"""A type representing the keyword arguments for the fetch method."""
|
17
|
+
|
18
|
+
similarity_threshold: NotRequired[float]
|
19
|
+
result_per_query: NotRequired[int]
|
20
|
+
|
21
|
+
|
22
|
+
class EmbeddingKwargs(TypedDict):
|
23
|
+
"""A type representing the keyword arguments for the embedding method."""
|
24
|
+
|
25
|
+
model: NotRequired[str]
|
26
|
+
dimensions: NotRequired[int]
|
27
|
+
timeout: NotRequired[PositiveInt]
|
28
|
+
caching: NotRequired[bool]
|
29
|
+
|
30
|
+
|
8
31
|
class LLMKwargs(TypedDict):
|
9
32
|
"""A type representing the keyword arguments for the LLM (Large Language Model) usage."""
|
10
33
|
|
fabricatio/models/task.py
CHANGED
@@ -46,21 +46,21 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
46
46
|
"""
|
47
47
|
|
48
48
|
name: str = Field(...)
|
49
|
-
"""The name of the task, which should be
|
49
|
+
"""The name of the task, which should be concise and descriptive."""
|
50
50
|
|
51
51
|
description: str = Field(default="")
|
52
|
-
"""
|
52
|
+
"""A detailed explanation of the task that includes all necessary information. Should be clear and answer what, why, when, where, who, and how questions."""
|
53
53
|
|
54
|
-
|
55
|
-
"""
|
54
|
+
goals: List[str] = Field(default=[])
|
55
|
+
"""A list of objectives that the task aims to accomplish. Each goal should be clear and specific. Complex tasks should be broken into multiple smaller goals."""
|
56
56
|
|
57
57
|
namespace: List[str] = Field(default_factory=list)
|
58
|
-
"""
|
58
|
+
"""A list of string segments that identify the task's location in the system. If not specified, defaults to an empty list."""
|
59
59
|
|
60
60
|
dependencies: List[str] = Field(default_factory=list)
|
61
|
-
"""A list of file paths
|
61
|
+
"""A list of file paths that are needed (either reading or writing) to complete this task. If not specified, defaults to an empty list."""
|
62
62
|
|
63
|
-
_output: Queue = PrivateAttr(default_factory=
|
63
|
+
_output: Queue[T | None] = PrivateAttr(default_factory=Queue)
|
64
64
|
"""The output queue of the task."""
|
65
65
|
|
66
66
|
_status: TaskStatus = PrivateAttr(default=TaskStatus.Pending)
|
@@ -113,7 +113,7 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
113
113
|
Returns:
|
114
114
|
Task: A new instance of the `Task` class.
|
115
115
|
"""
|
116
|
-
return cls(name=name,
|
116
|
+
return cls(name=name, goals=goal, description=description)
|
117
117
|
|
118
118
|
def update_task(self, goal: Optional[List[str] | str] = None, description: Optional[str] = None) -> Self:
|
119
119
|
"""Update the goal and description of the task.
|
@@ -126,12 +126,12 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
126
126
|
Task: The updated instance of the `Task` class.
|
127
127
|
"""
|
128
128
|
if goal:
|
129
|
-
self.
|
129
|
+
self.goals = goal if isinstance(goal, list) else [goal]
|
130
130
|
if description:
|
131
131
|
self.description = description
|
132
132
|
return self
|
133
133
|
|
134
|
-
async def get_output(self) -> T:
|
134
|
+
async def get_output(self) -> T | None:
|
135
135
|
"""Get the output of the task.
|
136
136
|
|
137
137
|
Returns:
|
@@ -232,6 +232,7 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
232
232
|
"""
|
233
233
|
logger.info(f"Cancelling task `{self.name}`")
|
234
234
|
self._status = TaskStatus.Cancelled
|
235
|
+
await self._output.put(None)
|
235
236
|
await env.emit_async(self.cancelled_label, self)
|
236
237
|
return self
|
237
238
|
|
@@ -243,27 +244,38 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
243
244
|
"""
|
244
245
|
logger.info(f"Failing task `{self.name}`")
|
245
246
|
self._status = TaskStatus.Failed
|
247
|
+
await self._output.put(None)
|
246
248
|
await env.emit_async(self.failed_label, self)
|
247
249
|
return self
|
248
250
|
|
249
|
-
|
251
|
+
def publish(self, new_namespace: Optional[EventLike] = None) -> Self:
|
250
252
|
"""Publish the task to the event bus.
|
251
253
|
|
254
|
+
Args:
|
255
|
+
new_namespace(EventLike, optional): The new namespace to move the task to.
|
256
|
+
|
252
257
|
Returns:
|
253
|
-
Task: The published instance of the `Task` class
|
258
|
+
Task: The published instance of the `Task` class.
|
254
259
|
"""
|
260
|
+
if new_namespace:
|
261
|
+
self.move_to(new_namespace)
|
255
262
|
logger.info(f"Publishing task `{(label := self.pending_label)}`")
|
256
|
-
|
263
|
+
env.emit_future(label, self)
|
257
264
|
return self
|
258
265
|
|
259
|
-
async def delegate(self) -> T:
|
260
|
-
"""Delegate the task to the event
|
266
|
+
async def delegate(self, new_namespace: Optional[EventLike] = None) -> T | None:
|
267
|
+
"""Delegate the task to the event.
|
268
|
+
|
269
|
+
Args:
|
270
|
+
new_namespace(EventLike, optional): The new namespace to move the task to.
|
261
271
|
|
262
272
|
Returns:
|
263
|
-
T: The output of the task
|
273
|
+
T|None: The output of the task.
|
264
274
|
"""
|
275
|
+
if new_namespace:
|
276
|
+
self.move_to(new_namespace)
|
265
277
|
logger.info(f"Delegating task `{(label := self.pending_label)}`")
|
266
|
-
|
278
|
+
env.emit_future(label, self)
|
267
279
|
return await self.get_output()
|
268
280
|
|
269
281
|
@property
|
@@ -277,3 +289,43 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
277
289
|
configs.templates.task_briefing_template,
|
278
290
|
self.model_dump(),
|
279
291
|
)
|
292
|
+
|
293
|
+
def is_running(self) -> bool:
|
294
|
+
"""Check if the task is running.
|
295
|
+
|
296
|
+
Returns:
|
297
|
+
bool: True if the task is running, False otherwise.
|
298
|
+
"""
|
299
|
+
return self._status == TaskStatus.Running
|
300
|
+
|
301
|
+
def is_finished(self) -> bool:
|
302
|
+
"""Check if the task is finished.
|
303
|
+
|
304
|
+
Returns:
|
305
|
+
bool: True if the task is finished, False otherwise.
|
306
|
+
"""
|
307
|
+
return self._status == TaskStatus.Finished
|
308
|
+
|
309
|
+
def is_failed(self) -> bool:
|
310
|
+
"""Check if the task is failed.
|
311
|
+
|
312
|
+
Returns:
|
313
|
+
bool: True if the task is failed, False otherwise.
|
314
|
+
"""
|
315
|
+
return self._status == TaskStatus.Failed
|
316
|
+
|
317
|
+
def is_cancelled(self) -> bool:
|
318
|
+
"""Check if the task is cancelled.
|
319
|
+
|
320
|
+
Returns:
|
321
|
+
bool: True if the task is cancelled, False otherwise.
|
322
|
+
"""
|
323
|
+
return self._status == TaskStatus.Cancelled
|
324
|
+
|
325
|
+
def is_pending(self) -> bool:
|
326
|
+
"""Check if the task is pending.
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
bool: True if the task is pending, False otherwise.
|
330
|
+
"""
|
331
|
+
return self._status == TaskStatus.Pending
|
fabricatio/models/usages.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""This module contains classes that manage the usage of language models and tools in tasks."""
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
|
-
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack, overload
|
4
|
+
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
|
5
5
|
|
6
6
|
import asyncstdlib
|
7
7
|
import litellm
|
@@ -9,8 +9,8 @@ import orjson
|
|
9
9
|
from fabricatio._rust_instances import template_manager
|
10
10
|
from fabricatio.config import configs
|
11
11
|
from fabricatio.journal import logger
|
12
|
-
from fabricatio.models.generic import
|
13
|
-
from fabricatio.models.kwargs_types import ChooseKwargs, GenerateKwargs, LLMKwargs
|
12
|
+
from fabricatio.models.generic import ScopedConfig, WithBriefing
|
13
|
+
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
|
14
14
|
from fabricatio.models.task import Task
|
15
15
|
from fabricatio.models.tool import Tool, ToolBox
|
16
16
|
from fabricatio.models.utils import Messages
|
@@ -18,48 +18,20 @@ from fabricatio.parser import JsonCapture
|
|
18
18
|
from litellm import stream_chunk_builder
|
19
19
|
from litellm.types.utils import (
|
20
20
|
Choices,
|
21
|
+
EmbeddingResponse,
|
21
22
|
ModelResponse,
|
22
23
|
StreamingChoices,
|
23
24
|
)
|
24
25
|
from litellm.utils import CustomStreamWrapper
|
25
|
-
from pydantic import Field,
|
26
|
+
from pydantic import Field, NonNegativeInt, PositiveInt
|
26
27
|
|
27
28
|
|
28
|
-
class LLMUsage(
|
29
|
+
class LLMUsage(ScopedConfig):
|
29
30
|
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
30
31
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
llm_api_key: Optional[SecretStr] = None
|
35
|
-
"""The OpenAI API key."""
|
36
|
-
|
37
|
-
llm_timeout: Optional[PositiveInt] = None
|
38
|
-
"""The timeout of the LLM model."""
|
39
|
-
|
40
|
-
llm_max_retries: Optional[PositiveInt] = None
|
41
|
-
"""The maximum number of retries."""
|
42
|
-
|
43
|
-
llm_model: Optional[str] = None
|
44
|
-
"""The LLM model name."""
|
45
|
-
|
46
|
-
llm_temperature: Optional[NonNegativeFloat] = None
|
47
|
-
"""The temperature of the LLM model."""
|
48
|
-
|
49
|
-
llm_stop_sign: Optional[str | List[str]] = None
|
50
|
-
"""The stop sign of the LLM model."""
|
51
|
-
|
52
|
-
llm_top_p: Optional[NonNegativeFloat] = None
|
53
|
-
"""The top p of the LLM model."""
|
54
|
-
|
55
|
-
llm_generation_count: Optional[PositiveInt] = None
|
56
|
-
"""The number of generations to generate."""
|
57
|
-
|
58
|
-
llm_stream: Optional[bool] = None
|
59
|
-
"""Whether to stream the LLM model's response."""
|
60
|
-
|
61
|
-
llm_max_tokens: Optional[PositiveInt] = None
|
62
|
-
"""The maximum number of tokens to generate."""
|
32
|
+
@classmethod
|
33
|
+
def _scoped_model(cls) -> Type["LLMUsage"]:
|
34
|
+
return LLMUsage
|
63
35
|
|
64
36
|
async def aquery(
|
65
37
|
self,
|
@@ -89,10 +61,8 @@ class LLMUsage(Base):
|
|
89
61
|
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
90
62
|
timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
|
91
63
|
max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
|
92
|
-
api_key=
|
93
|
-
base_url=self.llm_api_endpoint.unicode_string()
|
94
|
-
if self.llm_api_endpoint
|
95
|
-
else configs.llm.api_endpoint.unicode_string(),
|
64
|
+
api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
|
65
|
+
base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
|
96
66
|
)
|
97
67
|
|
98
68
|
async def ainvoke(
|
@@ -121,13 +91,13 @@ class LLMUsage(Base):
|
|
121
91
|
if isinstance(resp, ModelResponse):
|
122
92
|
return resp.choices
|
123
93
|
if isinstance(resp, CustomStreamWrapper):
|
124
|
-
if configs.debug.streaming_visible:
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
return stream_chunk_builder(
|
94
|
+
if not configs.debug.streaming_visible:
|
95
|
+
return stream_chunk_builder(await asyncstdlib.list()).choices
|
96
|
+
chunks = []
|
97
|
+
async for chunk in resp:
|
98
|
+
chunks.append(chunk)
|
99
|
+
print(chunk.choices[0].delta.content or "", end="") # noqa: T201
|
100
|
+
return stream_chunk_builder(chunks).choices
|
131
101
|
logger.critical(err := f"Unexpected response type: {type(resp)}")
|
132
102
|
raise ValueError(err)
|
133
103
|
|
@@ -383,39 +353,76 @@ class LLMUsage(Base):
|
|
383
353
|
**kwargs,
|
384
354
|
)
|
385
355
|
|
386
|
-
|
387
|
-
|
356
|
+
|
357
|
+
class EmbeddingUsage(LLMUsage):
|
358
|
+
"""A class representing the embedding model."""
|
359
|
+
|
360
|
+
async def aembedding(
|
361
|
+
self,
|
362
|
+
input_text: List[str],
|
363
|
+
model: Optional[str] = None,
|
364
|
+
dimensions: Optional[int] = None,
|
365
|
+
timeout: Optional[PositiveInt] = None,
|
366
|
+
caching: Optional[bool] = False,
|
367
|
+
) -> EmbeddingResponse:
|
368
|
+
"""Asynchronously generates embeddings for the given input text.
|
388
369
|
|
389
370
|
Args:
|
390
|
-
|
371
|
+
input_text (List[str]): A list of strings to generate embeddings for.
|
372
|
+
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
373
|
+
dimensions (Optional[int]): The dimensions of the embedding output should have, which is used to validate the result. Defaults to None.
|
374
|
+
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
375
|
+
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
376
|
+
|
391
377
|
|
392
378
|
Returns:
|
393
|
-
|
379
|
+
EmbeddingResponse: The response containing the embeddings.
|
394
380
|
"""
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
381
|
+
return await litellm.aembedding(
|
382
|
+
input=input_text,
|
383
|
+
caching=caching or self.embedding_caching or configs.embedding.caching,
|
384
|
+
dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
|
385
|
+
model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
|
386
|
+
timeout=timeout
|
387
|
+
or self.embedding_timeout
|
388
|
+
or configs.embedding.timeout
|
389
|
+
or self.llm_timeout
|
390
|
+
or configs.llm.timeout,
|
391
|
+
api_key=(
|
392
|
+
self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
|
393
|
+
).get_secret_value(),
|
394
|
+
api_base=(
|
395
|
+
self.embedding_api_endpoint
|
396
|
+
or configs.embedding.api_endpoint
|
397
|
+
or self.llm_api_endpoint
|
398
|
+
or configs.llm.api_endpoint
|
399
|
+
)
|
400
|
+
.unicode_string()
|
401
|
+
.rstrip("/"),
|
402
|
+
# seems embedding function takes no base_url end with a slash
|
403
|
+
)
|
404
404
|
|
405
|
-
|
406
|
-
|
405
|
+
@overload
|
406
|
+
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
407
|
+
@overload
|
408
|
+
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
409
|
+
|
410
|
+
async def vectorize(
|
411
|
+
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
412
|
+
) -> List[List[float]] | List[float]:
|
413
|
+
"""Asynchronously generates vector embeddings for the given input text.
|
407
414
|
|
408
415
|
Args:
|
409
|
-
|
416
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
417
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
410
418
|
|
411
419
|
Returns:
|
412
|
-
|
420
|
+
List[List[float]] | List[float]: The generated embeddings.
|
413
421
|
"""
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
setattr(other, attr_name, attr)
|
422
|
+
if isinstance(input_text, str):
|
423
|
+
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
424
|
+
|
425
|
+
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
419
426
|
|
420
427
|
|
421
428
|
class ToolBoxUsage(LLMUsage):
|
fabricatio/models/utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""A module containing utility classes for the models."""
|
2
2
|
|
3
|
-
from typing import Dict, List, Literal, Self
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, Self
|
4
4
|
|
5
5
|
from pydantic import BaseModel, ConfigDict, Field
|
6
6
|
|
@@ -76,3 +76,52 @@ class Messages(list):
|
|
76
76
|
list[dict]: A list of dictionaries representing the messages.
|
77
77
|
"""
|
78
78
|
return [message.model_dump() for message in self]
|
79
|
+
|
80
|
+
|
81
|
+
class MilvusData(BaseModel):
|
82
|
+
"""A class representing data stored in Milvus."""
|
83
|
+
|
84
|
+
model_config = ConfigDict(use_attribute_docstrings=True)
|
85
|
+
id: Optional[int] = Field(default=None)
|
86
|
+
"""The identifier of the data."""
|
87
|
+
|
88
|
+
vector: List[float]
|
89
|
+
"""The vector representation of the data."""
|
90
|
+
|
91
|
+
text: str
|
92
|
+
"""The text representation of the data."""
|
93
|
+
|
94
|
+
subject: Optional[str] = Field(default=None)
|
95
|
+
"""A subject label that we use to demo metadata filtering later."""
|
96
|
+
|
97
|
+
def prepare_insertion(self) -> Dict[str, Any]:
|
98
|
+
"""Prepares the data for insertion into Milvus.
|
99
|
+
|
100
|
+
Returns:
|
101
|
+
dict: A dictionary containing the data to be inserted into Milvus.
|
102
|
+
"""
|
103
|
+
return self.model_dump(exclude_none=True)
|
104
|
+
|
105
|
+
def update_subject(self, new_subject: str) -> Self:
|
106
|
+
"""Updates the subject label of the data.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
new_subject (str): The new subject label.
|
110
|
+
|
111
|
+
Returns:
|
112
|
+
Self: The updated instance of MilvusData.
|
113
|
+
"""
|
114
|
+
self.subject = new_subject
|
115
|
+
return self
|
116
|
+
|
117
|
+
def update_id(self, new_id: int) -> Self:
|
118
|
+
"""Updates the identifier of the data.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
new_id (int): The new identifier.
|
122
|
+
|
123
|
+
Returns:
|
124
|
+
Self: The updated instance of MilvusData.
|
125
|
+
"""
|
126
|
+
self.id = new_id
|
127
|
+
return self
|
Binary file
|