fabricatio 0.2.2__cp312-cp312-win_amd64.whl → 0.2.3__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +8 -0
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/capabilities/rag.py +310 -0
- fabricatio/config.py +49 -0
- fabricatio/core.py +33 -19
- fabricatio/models/action.py +6 -2
- fabricatio/models/generic.py +107 -1
- fabricatio/models/kwargs_types.py +23 -0
- fabricatio/models/task.py +69 -17
- fabricatio/models/usages.py +77 -70
- fabricatio/models/utils.py +50 -1
- fabricatio-0.2.3.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.2.dist-info → fabricatio-0.2.3.dist-info}/METADATA +42 -38
- {fabricatio-0.2.2.dist-info → fabricatio-0.2.3.dist-info}/RECORD +16 -15
- fabricatio-0.2.2.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.2.dist-info → fabricatio-0.2.3.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.2.dist-info → fabricatio-0.2.3.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""Fabricatio is a Python library for building llm app using event-based agent structure."""
|
2
2
|
|
3
|
+
from importlib.util import find_spec
|
4
|
+
|
3
5
|
from fabricatio._rust_instances import template_manager
|
4
6
|
from fabricatio.core import env
|
5
7
|
from fabricatio.fs import magika
|
@@ -35,3 +37,9 @@ __all__ = [
|
|
35
37
|
"task_toolbox",
|
36
38
|
"template_manager",
|
37
39
|
]
|
40
|
+
|
41
|
+
|
42
|
+
# RAG relies on the optional `pymilvus` dependency, so it is only imported and
# exported when that package can be located in the current environment.
if find_spec("pymilvus") is not None:
    from fabricatio.capabilities.rag import RAG

    __all__.append("RAG")
|
Binary file
|
@@ -0,0 +1,310 @@
|
|
1
|
+
"""A module for the RAG (Retrieval Augmented Generation) model."""
|
2
|
+
|
3
|
+
try:
|
4
|
+
from pymilvus import MilvusClient
|
5
|
+
except ImportError as e:
|
6
|
+
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
7
|
+
from functools import lru_cache
|
8
|
+
from operator import itemgetter
|
9
|
+
from os import PathLike
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
|
12
|
+
|
13
|
+
from fabricatio._rust_instances import template_manager
|
14
|
+
from fabricatio.config import configs
|
15
|
+
from fabricatio.journal import logger
|
16
|
+
from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
|
17
|
+
from fabricatio.models.usages import EmbeddingUsage
|
18
|
+
from fabricatio.models.utils import MilvusData
|
19
|
+
from more_itertools.recipes import flatten
|
20
|
+
from pydantic import Field, PrivateAttr
|
21
|
+
|
22
|
+
|
23
|
+
@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Build a `MilvusClient` for the given connection settings.

    The function is wrapped in `lru_cache`, so repeated calls with the same
    arguments return the same client object instead of constructing a new one.

    Args:
        uri (str): The URI passed to `MilvusClient`.
        token (str): The token passed to `MilvusClient`. Defaults to an empty string.
        timeout (Optional[float]): The timeout passed to `MilvusClient`. Defaults to None.

    Returns:
        MilvusClient: The (possibly cached) client.
    """
    connection = {"uri": uri, "token": token, "timeout": timeout}
    return MilvusClient(**connection)
|
31
|
+
|
32
|
+
|
33
|
+
class RAG(EmbeddingUsage):
|
34
|
+
"""A class representing the RAG (Retrieval Augmented Generation) model."""
|
35
|
+
|
36
|
+
target_collection: Optional[str] = Field(default=None)
|
37
|
+
"""The name of the collection being viewed."""
|
38
|
+
|
39
|
+
_client: Optional[MilvusClient] = PrivateAttr(None)
|
40
|
+
"""The Milvus client used for the RAG model."""
|
41
|
+
|
42
|
+
@property
|
43
|
+
def client(self) -> MilvusClient:
|
44
|
+
"""Return the Milvus client."""
|
45
|
+
if self._client is None:
|
46
|
+
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
47
|
+
return self._client
|
48
|
+
|
49
|
+
def init_client(
|
50
|
+
self,
|
51
|
+
milvus_uri: Optional[str] = None,
|
52
|
+
milvus_token: Optional[str] = None,
|
53
|
+
milvus_timeout: Optional[float] = None,
|
54
|
+
) -> Self:
|
55
|
+
"""Initialize the Milvus client."""
|
56
|
+
self._client = create_client(
|
57
|
+
uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
|
58
|
+
token=milvus_token
|
59
|
+
or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
|
60
|
+
timeout=milvus_timeout or self.milvus_timeout,
|
61
|
+
)
|
62
|
+
return self
|
63
|
+
|
64
|
+
@overload
|
65
|
+
async def pack(
|
66
|
+
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
67
|
+
) -> List[MilvusData]: ...
|
68
|
+
@overload
|
69
|
+
async def pack(
|
70
|
+
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
71
|
+
) -> MilvusData: ...
|
72
|
+
|
73
|
+
async def pack(
|
74
|
+
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
75
|
+
) -> List[MilvusData] | MilvusData:
|
76
|
+
"""Asynchronously generates MilvusData objects for the given input text.
|
77
|
+
|
78
|
+
Args:
|
79
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
80
|
+
subject (Optional[str]): The subject of the input text. Defaults to None.
|
81
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
85
|
+
"""
|
86
|
+
if isinstance(input_text, str):
|
87
|
+
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
88
|
+
vecs = await self.vectorize(input_text, **kwargs)
|
89
|
+
return [
|
90
|
+
MilvusData(
|
91
|
+
vector=vec,
|
92
|
+
text=text,
|
93
|
+
subject=subject,
|
94
|
+
)
|
95
|
+
for text, vec in zip(input_text, vecs, strict=True)
|
96
|
+
]
|
97
|
+
|
98
|
+
def view(
|
99
|
+
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
|
100
|
+
) -> Self:
|
101
|
+
"""View the specified collection.
|
102
|
+
|
103
|
+
Args:
|
104
|
+
collection_name (str): The name of the collection.
|
105
|
+
create (bool): Whether to create the collection if it does not exist.
|
106
|
+
**kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
|
107
|
+
"""
|
108
|
+
if create and collection_name and not self._client.has_collection(collection_name):
|
109
|
+
kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
|
110
|
+
self._client.create_collection(collection_name, auto_id=True, **kwargs)
|
111
|
+
logger.info(f"Creating collection {collection_name}")
|
112
|
+
|
113
|
+
self.target_collection = collection_name
|
114
|
+
return self
|
115
|
+
|
116
|
+
def quit_viewing(self) -> Self:
|
117
|
+
"""Quit the current view.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
Self: The current instance, allowing for method chaining.
|
121
|
+
"""
|
122
|
+
return self.view(None)
|
123
|
+
|
124
|
+
@property
|
125
|
+
def safe_target_collection(self) -> str:
|
126
|
+
"""Get the name of the collection being viewed, raise an error if not viewing any collection.
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
str: The name of the collection being viewed.
|
130
|
+
"""
|
131
|
+
if self.target_collection is None:
|
132
|
+
raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
|
133
|
+
return self.target_collection
|
134
|
+
|
135
|
+
def add_document[D: Union[Dict[str, Any], MilvusData]](
|
136
|
+
self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
|
137
|
+
) -> Self:
|
138
|
+
"""Adds a document to the specified collection.
|
139
|
+
|
140
|
+
Args:
|
141
|
+
data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
|
142
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
143
|
+
flush (bool): Whether to flush the collection after insertion.
|
144
|
+
|
145
|
+
Returns:
|
146
|
+
Self: The current instance, allowing for method chaining.
|
147
|
+
"""
|
148
|
+
if isinstance(data, MilvusData):
|
149
|
+
data = data.prepare_insertion()
|
150
|
+
if isinstance(data, list):
|
151
|
+
data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
|
152
|
+
c_name = collection_name or self.safe_target_collection
|
153
|
+
self._client.insert(c_name, data)
|
154
|
+
|
155
|
+
if flush:
|
156
|
+
logger.debug(f"Flushing collection {c_name}")
|
157
|
+
self._client.flush(c_name)
|
158
|
+
return self
|
159
|
+
|
160
|
+
async def consume_file(
|
161
|
+
self,
|
162
|
+
source: List[PathLike] | PathLike,
|
163
|
+
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
164
|
+
collection_name: Optional[str] = None,
|
165
|
+
) -> Self:
|
166
|
+
"""Consume a file and add its content to the collection.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
source (PathLike): The path to the file to be consumed.
|
170
|
+
reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
|
171
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
172
|
+
|
173
|
+
Returns:
|
174
|
+
Self: The current instance, allowing for method chaining.
|
175
|
+
"""
|
176
|
+
if not isinstance(source, list):
|
177
|
+
source = [source]
|
178
|
+
return await self.consume_string([reader(s) for s in source], collection_name)
|
179
|
+
|
180
|
+
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
|
181
|
+
"""Consume a string and add it to the collection.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
text (List[str] | str): The text to be added to the collection.
|
185
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
Self: The current instance, allowing for method chaining.
|
189
|
+
"""
|
190
|
+
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
191
|
+
return self
|
192
|
+
|
193
|
+
async def afetch_document(
|
194
|
+
self,
|
195
|
+
vecs: List[List[float]],
|
196
|
+
desired_fields: List[str] | str,
|
197
|
+
collection_name: Optional[str] = None,
|
198
|
+
similarity_threshold: float = 0.37,
|
199
|
+
result_per_query: int = 10,
|
200
|
+
) -> List[Dict[str, Any]] | List[Any]:
|
201
|
+
"""Fetch data from the collection.
|
202
|
+
|
203
|
+
Args:
|
204
|
+
vecs (List[List[float]]): The vectors to search for.
|
205
|
+
desired_fields (List[str] | str): The fields to retrieve.
|
206
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
207
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
208
|
+
result_per_query (int): The number of results to return per query.
|
209
|
+
|
210
|
+
Returns:
|
211
|
+
List[Dict[str, Any]] | List[Any]: The retrieved data.
|
212
|
+
"""
|
213
|
+
# Step 1: Search for vectors
|
214
|
+
search_results = self._client.search(
|
215
|
+
collection_name or self.safe_target_collection,
|
216
|
+
vecs,
|
217
|
+
search_params={"radius": similarity_threshold},
|
218
|
+
output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
|
219
|
+
limit=result_per_query,
|
220
|
+
)
|
221
|
+
|
222
|
+
# Step 2: Flatten the search results
|
223
|
+
flattened_results = flatten(search_results)
|
224
|
+
|
225
|
+
# Step 3: Sort by distance (descending)
|
226
|
+
sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
|
227
|
+
|
228
|
+
logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
|
229
|
+
# Step 4: Extract the entities
|
230
|
+
resp = [result["entity"] for result in sorted_results]
|
231
|
+
|
232
|
+
if isinstance(desired_fields, list):
|
233
|
+
return resp
|
234
|
+
return [r.get(desired_fields) for r in resp]
|
235
|
+
|
236
|
+
async def aretrieve(
|
237
|
+
self,
|
238
|
+
query: List[str] | str,
|
239
|
+
collection_name: Optional[str] = None,
|
240
|
+
final_limit: int = 20,
|
241
|
+
**kwargs: Unpack[FetchKwargs],
|
242
|
+
) -> List[str]:
|
243
|
+
"""Retrieve data from the collection.
|
244
|
+
|
245
|
+
Args:
|
246
|
+
query (List[str] | str): The query to be used for retrieval.
|
247
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
248
|
+
final_limit (int): The final limit on the number of results to return.
|
249
|
+
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
250
|
+
|
251
|
+
Returns:
|
252
|
+
List[str]: A list of strings containing the retrieved data.
|
253
|
+
"""
|
254
|
+
if isinstance(query, str):
|
255
|
+
query = [query]
|
256
|
+
return (
|
257
|
+
await self.afetch_document(
|
258
|
+
vecs=(await self.vectorize(query)),
|
259
|
+
desired_fields="text",
|
260
|
+
collection_name=collection_name,
|
261
|
+
**kwargs,
|
262
|
+
)
|
263
|
+
)[:final_limit]
|
264
|
+
|
265
|
+
async def aask_retrieved(
|
266
|
+
self,
|
267
|
+
question: str | List[str],
|
268
|
+
query: List[str] | str,
|
269
|
+
collection_name: Optional[str] = None,
|
270
|
+
extra_system_message: str = "",
|
271
|
+
result_per_query: int = 10,
|
272
|
+
final_limit: int = 20,
|
273
|
+
similarity_threshold: float = 0.37,
|
274
|
+
**kwargs: Unpack[LLMKwargs],
|
275
|
+
) -> str:
|
276
|
+
"""Asks a question by retrieving relevant documents based on the provided query.
|
277
|
+
|
278
|
+
This method performs document retrieval using the given query, then asks the
|
279
|
+
specified question using the retrieved documents as context.
|
280
|
+
|
281
|
+
Args:
|
282
|
+
question (str | List[str]): The question or list of questions to be asked.
|
283
|
+
query (List[str] | str): The query or list of queries used for document retrieval.
|
284
|
+
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
285
|
+
If not provided, the currently viewed collection is used.
|
286
|
+
extra_system_message (str): An additional system message to be included in the prompt.
|
287
|
+
result_per_query (int): The number of results to return per query. Default is 10.
|
288
|
+
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
289
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
290
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
291
|
+
|
292
|
+
Returns:
|
293
|
+
str: A string response generated after asking with the context of retrieved documents.
|
294
|
+
"""
|
295
|
+
docs = await self.aretrieve(
|
296
|
+
query,
|
297
|
+
collection_name,
|
298
|
+
final_limit,
|
299
|
+
result_per_query=result_per_query,
|
300
|
+
similarity_threshold=similarity_threshold,
|
301
|
+
)
|
302
|
+
|
303
|
+
rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
|
304
|
+
|
305
|
+
logger.debug(f"Retrieved Documents: \n{rendered}")
|
306
|
+
return await self.aask(
|
307
|
+
question,
|
308
|
+
f"{rendered}\n\n{extra_system_message}",
|
309
|
+
**kwargs,
|
310
|
+
)
|
fabricatio/config.py
CHANGED
@@ -11,6 +11,7 @@ from pydantic import (
|
|
11
11
|
FilePath,
|
12
12
|
HttpUrl,
|
13
13
|
NonNegativeFloat,
|
14
|
+
PositiveFloat,
|
14
15
|
PositiveInt,
|
15
16
|
SecretStr,
|
16
17
|
)
|
@@ -79,6 +80,30 @@ class LLMConfig(BaseModel):
|
|
79
80
|
"""The maximum number of tokens to generate. Set to 8192 as per request."""
|
80
81
|
|
81
82
|
|
83
|
+
class EmbeddingConfig(BaseModel):
    """Embedding configuration class."""

    # With `use_attribute_docstrings=True` the string literal after each field becomes
    # that field's description, so the texts below must match the actual defaults.
    model_config = ConfigDict(use_attribute_docstrings=True)

    model: str = Field(default="text-embedding-ada-002")
    """The embedding model name."""

    dimensions: Optional[PositiveInt] = Field(default=None)
    """The dimensions of the embedding. Default is None."""

    timeout: Optional[PositiveInt] = Field(default=None)
    """The timeout of the embedding model in seconds. Default is None."""

    caching: bool = Field(default=False)
    """Whether to cache the embedding. Default is False."""

    api_endpoint: Optional[HttpUrl] = None
    """The OpenAI API endpoint."""

    api_key: Optional[SecretStr] = None
    """The OpenAI API key."""
|
105
|
+
|
106
|
+
|
82
107
|
class PymitterConfig(BaseModel):
|
83
108
|
"""Pymitter configuration class.
|
84
109
|
|
@@ -175,6 +200,9 @@ class TemplateConfig(BaseModel):
|
|
175
200
|
draft_rating_weights_klee_template: str = Field(default="draft_rating_weights_klee")
|
176
201
|
"""The name of the draft rating weights klee template which will be used to draft rating weights with Klee method."""
|
177
202
|
|
203
|
+
retrieved_display_template: str = Field(default="retrieved_display")
|
204
|
+
"""The name of the retrieved display template which will be used to display retrieved documents."""
|
205
|
+
|
178
206
|
|
179
207
|
class MagikaConfig(BaseModel):
|
180
208
|
"""Magika configuration class."""
|
@@ -207,6 +235,21 @@ class ToolBoxConfig(BaseModel):
|
|
207
235
|
"""The name of the module containing the data."""
|
208
236
|
|
209
237
|
|
238
|
+
class RagConfig(BaseModel):
    """RAG configuration class."""

    # The string literal after each field is picked up as that field's description
    # because `use_attribute_docstrings=True` is set.
    model_config = ConfigDict(use_attribute_docstrings=True)

    # Used by `RAG.init_client` when neither an argument nor a scoped value supplies a URI.
    milvus_uri: HttpUrl = Field(default=HttpUrl("http://localhost:19530"))
    """The URI of the Milvus server."""
    # Optional; None means no timeout is configured here.
    milvus_timeout: Optional[PositiveFloat] = Field(default=None)
    """The timeout of the Milvus server."""
    # Stored as a SecretStr; `RAG.init_client` unwraps it with `get_secret_value()`.
    milvus_token: Optional[SecretStr] = Field(default=None)
    """The token of the Milvus server."""
    # Last fallback for the `dimension` argument when `RAG.view` creates a collection.
    milvus_dimensions: Optional[PositiveInt] = Field(default=None)
    """The dimensions of the Milvus server."""
|
251
|
+
|
252
|
+
|
210
253
|
class Settings(BaseSettings):
|
211
254
|
"""Application settings class.
|
212
255
|
|
@@ -232,6 +275,9 @@ class Settings(BaseSettings):
|
|
232
275
|
llm: LLMConfig = Field(default_factory=LLMConfig)
|
233
276
|
"""LLM Configuration"""
|
234
277
|
|
278
|
+
embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
|
279
|
+
"""Embedding Configuration"""
|
280
|
+
|
235
281
|
debug: DebugConfig = Field(default_factory=DebugConfig)
|
236
282
|
"""Debug Configuration"""
|
237
283
|
|
@@ -250,6 +296,9 @@ class Settings(BaseSettings):
|
|
250
296
|
toolbox: ToolBoxConfig = Field(default_factory=ToolBoxConfig)
|
251
297
|
"""Toolbox Configuration"""
|
252
298
|
|
299
|
+
rag: RagConfig = Field(default_factory=RagConfig)
|
300
|
+
"""RAG Configuration"""
|
301
|
+
|
253
302
|
@classmethod
|
254
303
|
def settings_customise_sources(
|
255
304
|
cls,
|
fabricatio/core.py
CHANGED
@@ -38,11 +38,11 @@ class Env(BaseModel):
|
|
38
38
|
|
39
39
|
@overload
|
40
40
|
def on[**P, R](
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
41
|
+
self,
|
42
|
+
event: str | Event,
|
43
|
+
func: Optional[Callable[P, R]] = None,
|
44
|
+
/,
|
45
|
+
ttl: int = -1,
|
46
46
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
47
47
|
"""
|
48
48
|
Registers an event listener with a specific function that listens indefinitely or for a specified number of times.
|
@@ -58,11 +58,11 @@ class Env(BaseModel):
|
|
58
58
|
...
|
59
59
|
|
60
60
|
def on[**P, R](
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
61
|
+
self,
|
62
|
+
event: str | Event,
|
63
|
+
func: Optional[Callable[P, R]] = None,
|
64
|
+
/,
|
65
|
+
ttl=-1,
|
66
66
|
) -> Callable[[Callable[P, R]], Callable[P, R]] | Self:
|
67
67
|
"""Registers an event listener with a specific function that listens indefinitely or for a specified number of times.
|
68
68
|
|
@@ -78,14 +78,13 @@ class Env(BaseModel):
|
|
78
78
|
event = event.collapse()
|
79
79
|
if func is None:
|
80
80
|
return self._ee.on(event, ttl=ttl)
|
81
|
-
|
82
81
|
self._ee.on(event, func, ttl=ttl)
|
83
82
|
return self
|
84
83
|
|
85
84
|
@overload
|
86
85
|
def once[**P, R](
|
87
|
-
|
88
|
-
|
86
|
+
self,
|
87
|
+
event: str | Event,
|
89
88
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
90
89
|
"""
|
91
90
|
Registers an event listener that listens only once.
|
@@ -100,9 +99,9 @@ class Env(BaseModel):
|
|
100
99
|
|
101
100
|
@overload
|
102
101
|
def once[**P, R](
|
103
|
-
|
104
|
-
|
105
|
-
|
102
|
+
self,
|
103
|
+
event: str | Event,
|
104
|
+
func: Callable[[Callable[P, R]], Callable[P, R]],
|
106
105
|
) -> Self:
|
107
106
|
"""
|
108
107
|
Registers an event listener with a specific function that listens only once.
|
@@ -117,9 +116,9 @@ class Env(BaseModel):
|
|
117
116
|
...
|
118
117
|
|
119
118
|
def once[**P, R](
|
120
|
-
|
121
|
-
|
122
|
-
|
119
|
+
self,
|
120
|
+
event: str | Event,
|
121
|
+
func: Optional[Callable[P, R]] = None,
|
123
122
|
) -> Callable[[Callable[P, R]], Callable[P, R]] | Self:
|
124
123
|
"""Registers an event listener with a specific function that listens only once.
|
125
124
|
|
@@ -163,5 +162,20 @@ class Env(BaseModel):
|
|
163
162
|
event = event.collapse()
|
164
163
|
return await self._ee.emit_async(event, *args, **kwargs)
|
165
164
|
|
165
|
+
def emit_future[**P](self, event: str | Event, *args: P.args, **kwargs: P.kwargs) -> None:
|
166
|
+
"""Emits an event to all registered listeners and returns a future object.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
event (str | Event): The event to emit.
|
170
|
+
*args: Positional arguments to pass to the listeners.
|
171
|
+
**kwargs: Keyword arguments to pass to the listeners.
|
172
|
+
|
173
|
+
Returns:
|
174
|
+
None: The future object.
|
175
|
+
"""
|
176
|
+
if isinstance(event, Event):
|
177
|
+
event = event.collapse()
|
178
|
+
return self._ee.emit_future(event, *args, **kwargs)
|
179
|
+
|
166
180
|
|
167
181
|
env = Env()
|
fabricatio/models/action.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import traceback
|
4
4
|
from abc import abstractmethod
|
5
|
-
from asyncio import Queue
|
5
|
+
from asyncio import Queue, create_task
|
6
6
|
from typing import Any, Dict, Self, Tuple, Type, Union, Unpack
|
7
7
|
|
8
8
|
from fabricatio.capabilities.rating import GiveRating
|
@@ -108,7 +108,11 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
108
108
|
try:
|
109
109
|
for step in self._instances:
|
110
110
|
logger.debug(f"Executing step: {step.name}")
|
111
|
-
|
111
|
+
act_task = create_task(step.act(await self._context.get()))
|
112
|
+
if task.is_cancelled():
|
113
|
+
act_task.cancel(f"Cancelled by task: {task.name}")
|
114
|
+
break
|
115
|
+
modified_ctx = await act_task
|
112
116
|
await self._context.put(modified_ctx)
|
113
117
|
current_action = step.name
|
114
118
|
logger.info(f"Finished executing workflow: {self.name}")
|
fabricatio/models/generic.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""This module defines generic classes for models in the Fabricatio library."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Callable, List, Self
|
4
|
+
from typing import Callable, Iterable, List, Optional, Self, Union, final
|
5
5
|
|
6
6
|
import orjson
|
7
7
|
from fabricatio._rust import blake3_hash
|
@@ -12,6 +12,11 @@ from pydantic import (
|
|
12
12
|
BaseModel,
|
13
13
|
ConfigDict,
|
14
14
|
Field,
|
15
|
+
HttpUrl,
|
16
|
+
NonNegativeFloat,
|
17
|
+
PositiveFloat,
|
18
|
+
PositiveInt,
|
19
|
+
SecretStr,
|
15
20
|
)
|
16
21
|
|
17
22
|
|
@@ -150,3 +155,104 @@ class WithDependency(Base):
|
|
150
155
|
for p in self.dependencies
|
151
156
|
},
|
152
157
|
)
|
158
|
+
|
159
|
+
|
160
|
+
class ScopedConfig(Base):
    """Class that manages a scoped configuration.

    Every field is optional; a value of None means "not set in this scope".
    `fallback_to` pulls missing values from another instance and `hold_to`
    pushes this instance's values into other instances.
    """

    llm_api_endpoint: Optional[HttpUrl] = None
    """The OpenAI API endpoint."""

    llm_api_key: Optional[SecretStr] = None
    """The OpenAI API key."""

    llm_timeout: Optional[PositiveInt] = None
    """The timeout of the LLM model."""

    llm_max_retries: Optional[PositiveInt] = None
    """The maximum number of retries."""

    llm_model: Optional[str] = None
    """The LLM model name."""

    llm_temperature: Optional[NonNegativeFloat] = None
    """The temperature of the LLM model."""

    llm_stop_sign: Optional[str | List[str]] = None
    """The stop sign of the LLM model."""

    llm_top_p: Optional[NonNegativeFloat] = None
    """The top p of the LLM model."""

    llm_generation_count: Optional[PositiveInt] = None
    """The number of generations to generate."""

    llm_stream: Optional[bool] = None
    """Whether to stream the LLM model's response."""

    llm_max_tokens: Optional[PositiveInt] = None
    """The maximum number of tokens to generate."""

    embedding_api_endpoint: Optional[HttpUrl] = None
    """The OpenAI API endpoint."""

    embedding_api_key: Optional[SecretStr] = None
    """The OpenAI API key."""

    embedding_timeout: Optional[PositiveInt] = None
    """The timeout of the embedding model."""

    embedding_model: Optional[str] = None
    """The embedding model name."""

    embedding_dimensions: Optional[PositiveInt] = None
    """The dimensions of the embedding."""
    # NOTE(review): unlike every other field this defaults to False rather than None, so
    # `fallback_to` never replaces it and `hold_to` always propagates it - confirm this is intended.
    embedding_caching: Optional[bool] = False
    """Whether to cache the embedding result."""

    milvus_uri: Optional[HttpUrl] = Field(default=None)
    """The URI of the Milvus server."""
    milvus_token: Optional[SecretStr] = Field(default=None)
    """The token for the Milvus server."""
    milvus_timeout: Optional[PositiveFloat] = Field(default=None)
    """The timeout for the Milvus server."""
    milvus_dimensions: Optional[PositiveInt] = Field(default=None)
    """The dimensions of the Milvus server."""

    @final
    def fallback_to(self, other: "ScopedConfig") -> Self:
        """Fallback to another instance's attribute values if the current instance's attributes are None.

        Args:
            other (ScopedConfig): Another instance from which to copy attribute values.

        Returns:
            Self: The current instance, allowing for method chaining.
        """
        # Only the fields declared on ScopedConfig take part, not fields added by subclasses.
        # noinspection PydanticTypeChecker,PyTypeChecker
        for attr_name in ScopedConfig.model_fields:
            # Copy only when 'self' has None and 'other' has a non-None value
            if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
                setattr(self, attr_name, attr)

        # Return the current instance to allow for method chaining
        return self

    @final
    def hold_to(self, others: Union["ScopedConfig", Iterable["ScopedConfig"]]) -> Self:
        """Push this instance's attribute values into other instances whose attributes are None.

        Args:
            others (ScopedConfig | Iterable[ScopedConfig]): Another instance or iterable of instances that receive the attribute values.

        Returns:
            Self: The current instance, allowing for method chaining.
        """
        # A single config must be detected by its own type: pydantic models define `__iter__`,
        # so an `isinstance(others, Iterable)` test would treat one model as a collection.
        targets = [others] if isinstance(others, ScopedConfig) else others
        for other in targets:
            # noinspection PyTypeChecker,PydanticTypeChecker
            for attr_name in ScopedConfig.model_fields:
                if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
                    setattr(other, attr_name, attr)

        # Return the current instance to allow for method chaining
        return self
|
@@ -5,6 +5,29 @@ from typing import List, NotRequired, TypedDict
|
|
5
5
|
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
|
6
6
|
|
7
7
|
|
8
|
+
class CollectionSimpleConfigKwargs(TypedDict):
    """A type representing the configuration for a collection."""

    # Both keys are optional (NotRequired); callers may pass any subset.
    # Vector dimension of the collection to create.
    dimension: NotRequired[int]
    # Timeout forwarded with the collection creation call; unit not shown here - presumably seconds.
    timeout: NotRequired[float]
|
13
|
+
|
14
|
+
|
15
|
+
class FetchKwargs(TypedDict):
    """A type representing the keyword arguments for the fetch method."""

    # Both keys are optional (NotRequired); callers may pass any subset.
    # Similarity threshold applied to the vector search.
    similarity_threshold: NotRequired[float]
    # Maximum number of hits requested for each query vector.
    result_per_query: NotRequired[int]
|
20
|
+
|
21
|
+
|
22
|
+
class EmbeddingKwargs(TypedDict):
    """A type representing the keyword arguments for the embedding method."""

    # All keys are optional (NotRequired); callers may pass any subset.
    # Name of the embedding model to use.
    model: NotRequired[str]
    # Requested dimensionality of the embedding vectors.
    dimensions: NotRequired[int]
    # Timeout for the embedding request; unit not shown here - presumably seconds.
    timeout: NotRequired[PositiveInt]
    # Whether the embedding result should be cached.
    caching: NotRequired[bool]
|
29
|
+
|
30
|
+
|
8
31
|
class LLMKwargs(TypedDict):
|
9
32
|
"""A type representing the keyword arguments for the LLM (Large Language Model) usage."""
|
10
33
|
|
fabricatio/models/task.py
CHANGED
@@ -46,21 +46,21 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
46
46
|
"""
|
47
47
|
|
48
48
|
name: str = Field(...)
|
49
|
-
"""The name of the task, which should be
|
49
|
+
"""The name of the task, which should be concise and descriptive."""
|
50
50
|
|
51
51
|
description: str = Field(default="")
|
52
|
-
"""
|
52
|
+
"""A detailed explanation of the task that includes all necessary information. Should be clear and answer what, why, when, where, who, and how questions."""
|
53
53
|
|
54
|
-
|
55
|
-
"""
|
54
|
+
goals: List[str] = Field(default=[])
|
55
|
+
"""A list of objectives that the task aims to accomplish. Each goal should be clear and specific. Complex tasks should be broken into multiple smaller goals."""
|
56
56
|
|
57
57
|
namespace: List[str] = Field(default_factory=list)
|
58
|
-
"""
|
58
|
+
"""A list of string segments that identify the task's location in the system. If not specified, defaults to an empty list."""
|
59
59
|
|
60
60
|
dependencies: List[str] = Field(default_factory=list)
|
61
|
-
"""A list of file paths
|
61
|
+
"""A list of file paths that are needed (either reading or writing) to complete this task. If not specified, defaults to an empty list."""
|
62
62
|
|
63
|
-
_output: Queue = PrivateAttr(default_factory=
|
63
|
+
_output: Queue[T | None] = PrivateAttr(default_factory=Queue)
|
64
64
|
"""The output queue of the task."""
|
65
65
|
|
66
66
|
_status: TaskStatus = PrivateAttr(default=TaskStatus.Pending)
|
@@ -113,7 +113,7 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
113
113
|
Returns:
|
114
114
|
Task: A new instance of the `Task` class.
|
115
115
|
"""
|
116
|
-
return cls(name=name,
|
116
|
+
return cls(name=name, goals=goal, description=description)
|
117
117
|
|
118
118
|
def update_task(self, goal: Optional[List[str] | str] = None, description: Optional[str] = None) -> Self:
|
119
119
|
"""Update the goal and description of the task.
|
@@ -126,12 +126,12 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
126
126
|
Task: The updated instance of the `Task` class.
|
127
127
|
"""
|
128
128
|
if goal:
|
129
|
-
self.
|
129
|
+
self.goals = goal if isinstance(goal, list) else [goal]
|
130
130
|
if description:
|
131
131
|
self.description = description
|
132
132
|
return self
|
133
133
|
|
134
|
-
async def get_output(self) -> T:
|
134
|
+
async def get_output(self) -> T | None:
|
135
135
|
"""Get the output of the task.
|
136
136
|
|
137
137
|
Returns:
|
@@ -232,6 +232,7 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
232
232
|
"""
|
233
233
|
logger.info(f"Cancelling task `{self.name}`")
|
234
234
|
self._status = TaskStatus.Cancelled
|
235
|
+
await self._output.put(None)
|
235
236
|
await env.emit_async(self.cancelled_label, self)
|
236
237
|
return self
|
237
238
|
|
@@ -243,27 +244,38 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
243
244
|
"""
|
244
245
|
logger.info(f"Failing task `{self.name}`")
|
245
246
|
self._status = TaskStatus.Failed
|
247
|
+
await self._output.put(None)
|
246
248
|
await env.emit_async(self.failed_label, self)
|
247
249
|
return self
|
248
250
|
|
249
|
-
|
251
|
+
def publish(self, new_namespace: Optional[EventLike] = None) -> Self:
|
250
252
|
"""Publish the task to the event bus.
|
251
253
|
|
254
|
+
Args:
|
255
|
+
new_namespace(EventLike, optional): The new namespace to move the task to.
|
256
|
+
|
252
257
|
Returns:
|
253
|
-
Task: The published instance of the `Task` class
|
258
|
+
Task: The published instance of the `Task` class.
|
254
259
|
"""
|
260
|
+
if new_namespace:
|
261
|
+
self.move_to(new_namespace)
|
255
262
|
logger.info(f"Publishing task `{(label := self.pending_label)}`")
|
256
|
-
|
263
|
+
env.emit_future(label, self)
|
257
264
|
return self
|
258
265
|
|
259
|
-
async def delegate(self) -> T:
|
260
|
-
"""Delegate the task to the event
|
266
|
+
async def delegate(self, new_namespace: Optional[EventLike] = None) -> T | None:
|
267
|
+
"""Delegate the task to the event.
|
268
|
+
|
269
|
+
Args:
|
270
|
+
new_namespace(EventLike, optional): The new namespace to move the task to.
|
261
271
|
|
262
272
|
Returns:
|
263
|
-
T: The output of the task
|
273
|
+
T|None: The output of the task.
|
264
274
|
"""
|
275
|
+
if new_namespace:
|
276
|
+
self.move_to(new_namespace)
|
265
277
|
logger.info(f"Delegating task `{(label := self.pending_label)}`")
|
266
|
-
|
278
|
+
env.emit_future(label, self)
|
267
279
|
return await self.get_output()
|
268
280
|
|
269
281
|
@property
|
@@ -277,3 +289,43 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
277
289
|
configs.templates.task_briefing_template,
|
278
290
|
self.model_dump(),
|
279
291
|
)
|
292
|
+
|
293
|
+
def is_running(self) -> bool:
|
294
|
+
"""Check if the task is running.
|
295
|
+
|
296
|
+
Returns:
|
297
|
+
bool: True if the task is running, False otherwise.
|
298
|
+
"""
|
299
|
+
return self._status == TaskStatus.Running
|
300
|
+
|
301
|
+
def is_finished(self) -> bool:
|
302
|
+
"""Check if the task is finished.
|
303
|
+
|
304
|
+
Returns:
|
305
|
+
bool: True if the task is finished, False otherwise.
|
306
|
+
"""
|
307
|
+
return self._status == TaskStatus.Finished
|
308
|
+
|
309
|
+
def is_failed(self) -> bool:
|
310
|
+
"""Check if the task is failed.
|
311
|
+
|
312
|
+
Returns:
|
313
|
+
bool: True if the task is failed, False otherwise.
|
314
|
+
"""
|
315
|
+
return self._status == TaskStatus.Failed
|
316
|
+
|
317
|
+
def is_cancelled(self) -> bool:
|
318
|
+
"""Check if the task is cancelled.
|
319
|
+
|
320
|
+
Returns:
|
321
|
+
bool: True if the task is cancelled, False otherwise.
|
322
|
+
"""
|
323
|
+
return self._status == TaskStatus.Cancelled
|
324
|
+
|
325
|
+
def is_pending(self) -> bool:
|
326
|
+
"""Check if the task is pending.
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
bool: True if the task is pending, False otherwise.
|
330
|
+
"""
|
331
|
+
return self._status == TaskStatus.Pending
|
fabricatio/models/usages.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""This module contains classes that manage the usage of language models and tools in tasks."""
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
|
-
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack, overload
|
4
|
+
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
|
5
5
|
|
6
6
|
import asyncstdlib
|
7
7
|
import litellm
|
@@ -9,8 +9,8 @@ import orjson
|
|
9
9
|
from fabricatio._rust_instances import template_manager
|
10
10
|
from fabricatio.config import configs
|
11
11
|
from fabricatio.journal import logger
|
12
|
-
from fabricatio.models.generic import
|
13
|
-
from fabricatio.models.kwargs_types import ChooseKwargs, GenerateKwargs, LLMKwargs
|
12
|
+
from fabricatio.models.generic import ScopedConfig, WithBriefing
|
13
|
+
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
|
14
14
|
from fabricatio.models.task import Task
|
15
15
|
from fabricatio.models.tool import Tool, ToolBox
|
16
16
|
from fabricatio.models.utils import Messages
|
@@ -18,48 +18,20 @@ from fabricatio.parser import JsonCapture
|
|
18
18
|
from litellm import stream_chunk_builder
|
19
19
|
from litellm.types.utils import (
|
20
20
|
Choices,
|
21
|
+
EmbeddingResponse,
|
21
22
|
ModelResponse,
|
22
23
|
StreamingChoices,
|
23
24
|
)
|
24
25
|
from litellm.utils import CustomStreamWrapper
|
25
|
-
from pydantic import Field,
|
26
|
+
from pydantic import Field, NonNegativeInt, PositiveInt
|
26
27
|
|
27
28
|
|
28
|
-
class LLMUsage(
|
29
|
+
class LLMUsage(ScopedConfig):
|
29
30
|
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
30
31
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
llm_api_key: Optional[SecretStr] = None
|
35
|
-
"""The OpenAI API key."""
|
36
|
-
|
37
|
-
llm_timeout: Optional[PositiveInt] = None
|
38
|
-
"""The timeout of the LLM model."""
|
39
|
-
|
40
|
-
llm_max_retries: Optional[PositiveInt] = None
|
41
|
-
"""The maximum number of retries."""
|
42
|
-
|
43
|
-
llm_model: Optional[str] = None
|
44
|
-
"""The LLM model name."""
|
45
|
-
|
46
|
-
llm_temperature: Optional[NonNegativeFloat] = None
|
47
|
-
"""The temperature of the LLM model."""
|
48
|
-
|
49
|
-
llm_stop_sign: Optional[str | List[str]] = None
|
50
|
-
"""The stop sign of the LLM model."""
|
51
|
-
|
52
|
-
llm_top_p: Optional[NonNegativeFloat] = None
|
53
|
-
"""The top p of the LLM model."""
|
54
|
-
|
55
|
-
llm_generation_count: Optional[PositiveInt] = None
|
56
|
-
"""The number of generations to generate."""
|
57
|
-
|
58
|
-
llm_stream: Optional[bool] = None
|
59
|
-
"""Whether to stream the LLM model's response."""
|
60
|
-
|
61
|
-
llm_max_tokens: Optional[PositiveInt] = None
|
62
|
-
"""The maximum number of tokens to generate."""
|
32
|
+
@classmethod
|
33
|
+
def _scoped_model(cls) -> Type["LLMUsage"]:
|
34
|
+
return LLMUsage
|
63
35
|
|
64
36
|
async def aquery(
|
65
37
|
self,
|
@@ -89,10 +61,8 @@ class LLMUsage(Base):
|
|
89
61
|
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
90
62
|
timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
|
91
63
|
max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
|
92
|
-
api_key=
|
93
|
-
base_url=self.llm_api_endpoint.unicode_string()
|
94
|
-
if self.llm_api_endpoint
|
95
|
-
else configs.llm.api_endpoint.unicode_string(),
|
64
|
+
api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
|
65
|
+
base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
|
96
66
|
)
|
97
67
|
|
98
68
|
async def ainvoke(
|
@@ -121,13 +91,13 @@ class LLMUsage(Base):
|
|
121
91
|
if isinstance(resp, ModelResponse):
|
122
92
|
return resp.choices
|
123
93
|
if isinstance(resp, CustomStreamWrapper):
|
124
|
-
if configs.debug.streaming_visible:
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
return stream_chunk_builder(
|
94
|
+
if not configs.debug.streaming_visible:
|
95
|
+
return stream_chunk_builder(await asyncstdlib.list()).choices
|
96
|
+
chunks = []
|
97
|
+
async for chunk in resp:
|
98
|
+
chunks.append(chunk)
|
99
|
+
print(chunk.choices[0].delta.content or "", end="") # noqa: T201
|
100
|
+
return stream_chunk_builder(chunks).choices
|
131
101
|
logger.critical(err := f"Unexpected response type: {type(resp)}")
|
132
102
|
raise ValueError(err)
|
133
103
|
|
@@ -383,39 +353,76 @@ class LLMUsage(Base):
|
|
383
353
|
**kwargs,
|
384
354
|
)
|
385
355
|
|
386
|
-
|
387
|
-
|
356
|
+
|
357
|
+
class EmbeddingUsage(LLMUsage):
|
358
|
+
"""A class representing the embedding model."""
|
359
|
+
|
360
|
+
async def aembedding(
|
361
|
+
self,
|
362
|
+
input_text: List[str],
|
363
|
+
model: Optional[str] = None,
|
364
|
+
dimensions: Optional[int] = None,
|
365
|
+
timeout: Optional[PositiveInt] = None,
|
366
|
+
caching: Optional[bool] = False,
|
367
|
+
) -> EmbeddingResponse:
|
368
|
+
"""Asynchronously generates embeddings for the given input text.
|
388
369
|
|
389
370
|
Args:
|
390
|
-
|
371
|
+
input_text (List[str]): A list of strings to generate embeddings for.
|
372
|
+
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
373
|
+
dimensions (Optional[int]): The dimensions of the embedding output should have, which is used to validate the result. Defaults to None.
|
374
|
+
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
375
|
+
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
376
|
+
|
391
377
|
|
392
378
|
Returns:
|
393
|
-
|
379
|
+
EmbeddingResponse: The response containing the embeddings.
|
394
380
|
"""
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
381
|
+
return await litellm.aembedding(
|
382
|
+
input=input_text,
|
383
|
+
caching=caching or self.embedding_caching or configs.embedding.caching,
|
384
|
+
dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
|
385
|
+
model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
|
386
|
+
timeout=timeout
|
387
|
+
or self.embedding_timeout
|
388
|
+
or configs.embedding.timeout
|
389
|
+
or self.llm_timeout
|
390
|
+
or configs.llm.timeout,
|
391
|
+
api_key=(
|
392
|
+
self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
|
393
|
+
).get_secret_value(),
|
394
|
+
api_base=(
|
395
|
+
self.embedding_api_endpoint
|
396
|
+
or configs.embedding.api_endpoint
|
397
|
+
or self.llm_api_endpoint
|
398
|
+
or configs.llm.api_endpoint
|
399
|
+
)
|
400
|
+
.unicode_string()
|
401
|
+
.rstrip("/"),
|
402
|
+
# seems embedding function takes no base_url end with a slash
|
403
|
+
)
|
404
404
|
|
405
|
-
|
406
|
-
|
405
|
+
@overload
|
406
|
+
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
407
|
+
@overload
|
408
|
+
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
409
|
+
|
410
|
+
async def vectorize(
|
411
|
+
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
412
|
+
) -> List[List[float]] | List[float]:
|
413
|
+
"""Asynchronously generates vector embeddings for the given input text.
|
407
414
|
|
408
415
|
Args:
|
409
|
-
|
416
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
417
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
410
418
|
|
411
419
|
Returns:
|
412
|
-
|
420
|
+
List[List[float]] | List[float]: The generated embeddings.
|
413
421
|
"""
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
setattr(other, attr_name, attr)
|
422
|
+
if isinstance(input_text, str):
|
423
|
+
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
424
|
+
|
425
|
+
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
419
426
|
|
420
427
|
|
421
428
|
class ToolBoxUsage(LLMUsage):
|
fabricatio/models/utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""A module containing utility classes for the models."""
|
2
2
|
|
3
|
-
from typing import Dict, List, Literal, Self
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, Self
|
4
4
|
|
5
5
|
from pydantic import BaseModel, ConfigDict, Field
|
6
6
|
|
@@ -76,3 +76,52 @@ class Messages(list):
|
|
76
76
|
list[dict]: A list of dictionaries representing the messages.
|
77
77
|
"""
|
78
78
|
return [message.model_dump() for message in self]
|
79
|
+
|
80
|
+
|
81
|
+
class MilvusData(BaseModel):
|
82
|
+
"""A class representing data stored in Milvus."""
|
83
|
+
|
84
|
+
model_config = ConfigDict(use_attribute_docstrings=True)
|
85
|
+
id: Optional[int] = Field(default=None)
|
86
|
+
"""The identifier of the data."""
|
87
|
+
|
88
|
+
vector: List[float]
|
89
|
+
"""The vector representation of the data."""
|
90
|
+
|
91
|
+
text: str
|
92
|
+
"""The text representation of the data."""
|
93
|
+
|
94
|
+
subject: Optional[str] = Field(default=None)
|
95
|
+
"""A subject label that we use to demo metadata filtering later."""
|
96
|
+
|
97
|
+
def prepare_insertion(self) -> Dict[str, Any]:
|
98
|
+
"""Prepares the data for insertion into Milvus.
|
99
|
+
|
100
|
+
Returns:
|
101
|
+
dict: A dictionary containing the data to be inserted into Milvus.
|
102
|
+
"""
|
103
|
+
return self.model_dump(exclude_none=True)
|
104
|
+
|
105
|
+
def update_subject(self, new_subject: str) -> Self:
|
106
|
+
"""Updates the subject label of the data.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
new_subject (str): The new subject label.
|
110
|
+
|
111
|
+
Returns:
|
112
|
+
Self: The updated instance of MilvusData.
|
113
|
+
"""
|
114
|
+
self.subject = new_subject
|
115
|
+
return self
|
116
|
+
|
117
|
+
def update_id(self, new_id: int) -> Self:
|
118
|
+
"""Updates the identifier of the data.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
new_id (int): The new identifier.
|
122
|
+
|
123
|
+
Returns:
|
124
|
+
Self: The updated instance of MilvusData.
|
125
|
+
"""
|
126
|
+
self.id = new_id
|
127
|
+
return self
|
Binary file
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: fabricatio
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.3
|
4
4
|
Classifier: License :: OSI Approved :: MIT License
|
5
5
|
Classifier: Programming Language :: Rust
|
6
6
|
Classifier: Programming Language :: Python :: 3.12
|
@@ -98,32 +98,32 @@ from fabricatio import Action, Role, Task, logger
|
|
98
98
|
|
99
99
|
|
100
100
|
class Hello(Action):
|
101
|
-
|
101
|
+
"""Action that says hello."""
|
102
102
|
|
103
|
-
|
104
|
-
|
103
|
+
name: str = "hello"
|
104
|
+
output_key: str = "task_output"
|
105
105
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
106
|
+
async def _execute(self, task_input: Task[str], **_) -> Any:
|
107
|
+
ret = "Hello fabricatio!"
|
108
|
+
logger.info("executing talk action")
|
109
|
+
return ret
|
110
110
|
|
111
111
|
|
112
112
|
async def main() -> None:
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
113
|
+
"""Main function."""
|
114
|
+
role = Role(
|
115
|
+
name="talker",
|
116
|
+
description="talker role",
|
117
|
+
registry={Task.pending_label: WorkFlow(name="talk", steps=(Hello,))}
|
118
|
+
)
|
119
119
|
|
120
|
-
|
121
|
-
|
122
|
-
|
120
|
+
task = Task(name="say hello", goals="say hello", description="say hello to the world")
|
121
|
+
result = await task.delegate()
|
122
|
+
logger.success(f"Result: {result}")
|
123
123
|
|
124
124
|
|
125
125
|
if __name__ == "__main__":
|
126
|
-
|
126
|
+
asyncio.run(main())
|
127
127
|
```
|
128
128
|
|
129
129
|
#### Writing and Dumping Code
|
@@ -311,17 +311,18 @@ from fabricatio.models.task import Task
|
|
311
311
|
|
312
312
|
toolbox_usage = ToolBoxUsage()
|
313
313
|
|
314
|
+
|
314
315
|
async def handle_security_vulnerabilities():
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
316
|
+
task = Task(
|
317
|
+
name="Security Check",
|
318
|
+
goals=["Identify security vulnerabilities"],
|
319
|
+
description="Perform a thorough security review on the project.",
|
320
|
+
dependencies=["./src/main.py"]
|
321
|
+
)
|
322
|
+
|
323
|
+
vulnerabilities = await toolbox_usage.gather_tools_fine_grind(task)
|
324
|
+
for vulnerability in vulnerabilities:
|
325
|
+
print(f"Found vulnerability: {vulnerability.name}")
|
325
326
|
```
|
326
327
|
|
327
328
|
#### Managing CTF Challenges
|
@@ -334,19 +335,22 @@ from fabricatio.models.task import Task
|
|
334
335
|
|
335
336
|
toolbox_usage = ToolBoxUsage()
|
336
337
|
|
338
|
+
|
337
339
|
async def solve_ctf_challenge(challenge_name: str, challenge_description: str, files: list[str]):
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
340
|
+
task = Task(
|
341
|
+
name=challenge_name,
|
342
|
+
goals=[f"Solve {challenge_name} challenge"],
|
343
|
+
description=challenge_description,
|
344
|
+
dependencies=files
|
345
|
+
)
|
346
|
+
|
347
|
+
solution = await toolbox_usage.gather_tools_fine_grind(task)
|
348
|
+
print(f"Challenge Solved: {solution}")
|
349
|
+
|
347
350
|
|
348
351
|
if __name__ == "__main__":
|
349
|
-
|
352
|
+
asyncio.run(
|
353
|
+
solve_ctf_challenge("Binary Exploitation", "CTF Binary Exploitation Challenge", ["./challenges/binary_exploit"]))
|
350
354
|
```
|
351
355
|
|
352
356
|
### Configuration
|
@@ -1,27 +1,28 @@
|
|
1
|
-
fabricatio-0.2.
|
2
|
-
fabricatio-0.2.
|
3
|
-
fabricatio-0.2.
|
1
|
+
fabricatio-0.2.3.dist-info/METADATA,sha256=uJGeKnZ9deejZ-g23WI67qGEi0AyVN0hJWZ1hw1hjWo,12291
|
2
|
+
fabricatio-0.2.3.dist-info/WHEEL,sha256=tpW5AN9B-9qsM9WW2FXG2r193YXiqexDadpKp0A2daI,96
|
3
|
+
fabricatio-0.2.3.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
|
4
4
|
fabricatio/actions/communication.py,sha256=NZxIIncKgJSDyBrqNebUtH_haqtxHa8ld2TZxT3CMdU,429
|
5
5
|
fabricatio/actions/transmission.py,sha256=xpvKqbXqgpi1BWy-vUUvmd8NZ1GhRNfsYUBp-l2jLyk,862
|
6
6
|
fabricatio/actions/__init__.py,sha256=eFmFVPQvtNgFynIXBVr3eP-vWQDWCPng60YY5LXvZgg,115
|
7
|
+
fabricatio/capabilities/rag.py,sha256=paq2zUOfw6whIBFkKDo1Kg5Ft5YXgWiJBNKq-6uGhuU,13295
|
7
8
|
fabricatio/capabilities/rating.py,sha256=zmTUvsUfxFgovRQzy4djL2zKRYTHmN6JY7A4lyT5uVQ,14907
|
8
9
|
fabricatio/capabilities/task.py,sha256=d2xtrwQxXWI40UskQCR5YhHarY7ST0ppr8TjY12uWQE,5327
|
9
|
-
fabricatio/config.py,sha256=
|
10
|
-
fabricatio/core.py,sha256=
|
10
|
+
fabricatio/config.py,sha256=j9dk9q_1-L2rIbunj8JzMGJUQLZ4-TAyLRYQ5BiP3B0,13232
|
11
|
+
fabricatio/core.py,sha256=VQ_JKgUGIy2gZ8xsTBZCdr_IP7wC5aPg0_bsOmjQ588,6458
|
11
12
|
fabricatio/decorators.py,sha256=uzsP4tFKQNjDHBkofsjjoJA0IUAaYOtt6YVedoyOqlo,6551
|
12
13
|
fabricatio/fs/curd.py,sha256=faMstgGUiQ4k2AW3OXfvvWWTldTtKXco7QINYaMjmyA,3981
|
13
14
|
fabricatio/fs/readers.py,sha256=Pz1-cdZYtmqr032dsroImlkFXAd0kCYY_9qVpD4UrG4,1045
|
14
15
|
fabricatio/fs/__init__.py,sha256=lWcKYg0v3mv2LnnSegOQaTtlVDODU0vtw_s6iKU5IqQ,122
|
15
16
|
fabricatio/journal.py,sha256=siqimKF0M_QaaOCMxtjr_BJVNyUIAQWILzE9Q4T6-7c,781
|
16
|
-
fabricatio/models/action.py,sha256=
|
17
|
+
fabricatio/models/action.py,sha256=NpklAVUHYO5JIY9YLwYowZ-U8R9CFf5aC10DhLF7gxQ,5924
|
17
18
|
fabricatio/models/events.py,sha256=mrihNEFgQ5o7qFWja1z_qX8dnaTLwPBoJdVlzxQV5oM,2719
|
18
|
-
fabricatio/models/generic.py,sha256=
|
19
|
-
fabricatio/models/kwargs_types.py,sha256=
|
19
|
+
fabricatio/models/generic.py,sha256=BXCweaYDSIxit3kqh0QshdvO7eRkF8RkNt1r-9rN76Q,9146
|
20
|
+
fabricatio/models/kwargs_types.py,sha256=Xhy5LcTB1oWBGVGipLf5y_dTb7tBzMO5QAQdEfZeI9I,1786
|
20
21
|
fabricatio/models/role.py,sha256=gYvleTeKUGDUNKPAC5B0EPMLC4jZ4vHsFHmHiVXkU6c,1830
|
21
|
-
fabricatio/models/task.py,sha256=
|
22
|
+
fabricatio/models/task.py,sha256=M6jeDFE3jX6cNV9bdOwhjHqgBHI3FKtFLWcmlqhYgcs,11419
|
22
23
|
fabricatio/models/tool.py,sha256=WTFnpF6xZ1nJbmIOonLsGQcM-kkDCeZiAFqyil9xg2U,6988
|
23
|
-
fabricatio/models/usages.py,sha256=
|
24
|
-
fabricatio/models/utils.py,sha256=
|
24
|
+
fabricatio/models/usages.py,sha256=Rh8zz-BUVpdsOKoV4gX1yrac-bFOfWED3rkWvKIzP0E,24569
|
25
|
+
fabricatio/models/utils.py,sha256=mXea76bd4r2jy_zx74GM4t5kCvkMu0JTOaw_VGvTCxk,3952
|
25
26
|
fabricatio/parser.py,sha256=uLabsvF07wRKW1PoTGuGEENCx3P4mhmuO8JkmOEkKko,3522
|
26
27
|
fabricatio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
28
|
fabricatio/toolboxes/arithmetic.py,sha256=WLqhY-Pikv11Y_0SGajwZx3WhsLNpHKf9drzAqOf_nY,1369
|
@@ -30,7 +31,7 @@ fabricatio/toolboxes/task.py,sha256=kU4a501awIDV7GwNDuSlK3_Ym-5OhCp5sS-insTmUmQ,
|
|
30
31
|
fabricatio/toolboxes/__init__.py,sha256=b13KmASO8q5fBLwew964fn9oH86ER5g-S1PgA4fZ_xs,482
|
31
32
|
fabricatio/_rust.pyi,sha256=0wCqtwWkVxxoqprvk8T27T8QYKIAKHS7xgsmdMNjQKc,1756
|
32
33
|
fabricatio/_rust_instances.py,sha256=dl0-yZ4UvT5g20tQgnPJpmqtkjFGXNG_YK4eLfi_ugQ,279
|
33
|
-
fabricatio/__init__.py,sha256=
|
34
|
-
fabricatio/_rust.cp312-win_amd64.pyd,sha256=
|
35
|
-
fabricatio-0.2.
|
36
|
-
fabricatio-0.2.
|
34
|
+
fabricatio/__init__.py,sha256=E-JoEkGpl543nTbES0JGo_qOHaj2R6fz1bUcNajveys,1246
|
35
|
+
fabricatio/_rust.cp312-win_amd64.pyd,sha256=miczlXr6tcmL9VUTUdQEUaI6GNupRPktAAqf--4wXds,1270784
|
36
|
+
fabricatio-0.2.3.data/scripts/tdown.exe,sha256=Shuxh1PdIObKTQtXi69ra0Dxi9yllTKjdv2LoXB-l7M,3398144
|
37
|
+
fabricatio-0.2.3.dist-info/RECORD,,
|
Binary file
|
File without changes
|
File without changes
|