fabricatio 0.2.1.dev4__cp312-cp312-win_amd64.whl → 0.2.3__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/__init__.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """Fabricatio is a Python library for building llm app using event-based agent structure."""
2
2
 
3
+ from importlib.util import find_spec
4
+
3
5
  from fabricatio._rust_instances import template_manager
4
6
  from fabricatio.core import env
5
7
  from fabricatio.fs import magika
@@ -35,3 +37,9 @@ __all__ = [
35
37
  "task_toolbox",
36
38
  "template_manager",
37
39
  ]
40
+
41
+
42
if find_spec("pymilvus") is not None:
    # RAG depends on the optional `pymilvus` package (the `fabricatio[rag]` extra),
    # so it is only imported and exported when that package can be found.
    from fabricatio.capabilities.rag import RAG

    __all__.append("RAG")
Binary file
@@ -0,0 +1,310 @@
1
+ """A module for the RAG (Retrieval Augmented Generation) model."""
2
+
3
+ try:
4
+ from pymilvus import MilvusClient
5
+ except ImportError as e:
6
+ raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
7
+ from functools import lru_cache
8
+ from operator import itemgetter
9
+ from os import PathLike
10
+ from pathlib import Path
11
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
12
+
13
+ from fabricatio._rust_instances import template_manager
14
+ from fabricatio.config import configs
15
+ from fabricatio.journal import logger
16
+ from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
17
+ from fabricatio.models.usages import EmbeddingUsage
18
+ from fabricatio.models.utils import MilvusData
19
+ from more_itertools.recipes import flatten
20
+ from pydantic import Field, PrivateAttr
21
+
22
+
23
@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Build a Milvus client for the given connection parameters.

    The function is memoized with an unbounded ``lru_cache``, so repeated calls with the
    same arguments return the very same client object instead of constructing a new one.

    Args:
        uri (str): Address of the Milvus server, forwarded to ``MilvusClient``.
        token (str): Token forwarded to ``MilvusClient``; empty string by default.
        timeout (Optional[float]): Timeout forwarded to ``MilvusClient``; ``None`` by default.

    Returns:
        MilvusClient: The (possibly cached) client instance.
    """
    connection_params = {"uri": uri, "token": token, "timeout": timeout}
    return MilvusClient(**connection_params)
31
+
32
+
33
+ class RAG(EmbeddingUsage):
34
+ """A class representing the RAG (Retrieval Augmented Generation) model."""
35
+
36
+ target_collection: Optional[str] = Field(default=None)
37
+ """The name of the collection being viewed."""
38
+
39
+ _client: Optional[MilvusClient] = PrivateAttr(None)
40
+ """The Milvus client used for the RAG model."""
41
+
42
+ @property
43
+ def client(self) -> MilvusClient:
44
+ """Return the Milvus client."""
45
+ if self._client is None:
46
+ raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
47
+ return self._client
48
+
49
+ def init_client(
50
+ self,
51
+ milvus_uri: Optional[str] = None,
52
+ milvus_token: Optional[str] = None,
53
+ milvus_timeout: Optional[float] = None,
54
+ ) -> Self:
55
+ """Initialize the Milvus client."""
56
+ self._client = create_client(
57
+ uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
58
+ token=milvus_token
59
+ or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
60
+ timeout=milvus_timeout or self.milvus_timeout,
61
+ )
62
+ return self
63
+
64
+ @overload
65
+ async def pack(
66
+ self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
67
+ ) -> List[MilvusData]: ...
68
+ @overload
69
+ async def pack(
70
+ self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
71
+ ) -> MilvusData: ...
72
+
73
+ async def pack(
74
+ self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
75
+ ) -> List[MilvusData] | MilvusData:
76
+ """Asynchronously generates MilvusData objects for the given input text.
77
+
78
+ Args:
79
+ input_text (List[str] | str): A string or list of strings to generate embeddings for.
80
+ subject (Optional[str]): The subject of the input text. Defaults to None.
81
+ **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
82
+
83
+ Returns:
84
+ List[MilvusData] | MilvusData: The generated MilvusData objects.
85
+ """
86
+ if isinstance(input_text, str):
87
+ return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
88
+ vecs = await self.vectorize(input_text, **kwargs)
89
+ return [
90
+ MilvusData(
91
+ vector=vec,
92
+ text=text,
93
+ subject=subject,
94
+ )
95
+ for text, vec in zip(input_text, vecs, strict=True)
96
+ ]
97
+
98
+ def view(
99
+ self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
100
+ ) -> Self:
101
+ """View the specified collection.
102
+
103
+ Args:
104
+ collection_name (str): The name of the collection.
105
+ create (bool): Whether to create the collection if it does not exist.
106
+ **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
107
+ """
108
+ if create and collection_name and not self._client.has_collection(collection_name):
109
+ kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
110
+ self._client.create_collection(collection_name, auto_id=True, **kwargs)
111
+ logger.info(f"Creating collection {collection_name}")
112
+
113
+ self.target_collection = collection_name
114
+ return self
115
+
116
+ def quit_viewing(self) -> Self:
117
+ """Quit the current view.
118
+
119
+ Returns:
120
+ Self: The current instance, allowing for method chaining.
121
+ """
122
+ return self.view(None)
123
+
124
+ @property
125
+ def safe_target_collection(self) -> str:
126
+ """Get the name of the collection being viewed, raise an error if not viewing any collection.
127
+
128
+ Returns:
129
+ str: The name of the collection being viewed.
130
+ """
131
+ if self.target_collection is None:
132
+ raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
133
+ return self.target_collection
134
+
135
+ def add_document[D: Union[Dict[str, Any], MilvusData]](
136
+ self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
137
+ ) -> Self:
138
+ """Adds a document to the specified collection.
139
+
140
+ Args:
141
+ data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
142
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
143
+ flush (bool): Whether to flush the collection after insertion.
144
+
145
+ Returns:
146
+ Self: The current instance, allowing for method chaining.
147
+ """
148
+ if isinstance(data, MilvusData):
149
+ data = data.prepare_insertion()
150
+ if isinstance(data, list):
151
+ data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
152
+ c_name = collection_name or self.safe_target_collection
153
+ self._client.insert(c_name, data)
154
+
155
+ if flush:
156
+ logger.debug(f"Flushing collection {c_name}")
157
+ self._client.flush(c_name)
158
+ return self
159
+
160
+ async def consume_file(
161
+ self,
162
+ source: List[PathLike] | PathLike,
163
+ reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
164
+ collection_name: Optional[str] = None,
165
+ ) -> Self:
166
+ """Consume a file and add its content to the collection.
167
+
168
+ Args:
169
+ source (PathLike): The path to the file to be consumed.
170
+ reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
171
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
172
+
173
+ Returns:
174
+ Self: The current instance, allowing for method chaining.
175
+ """
176
+ if not isinstance(source, list):
177
+ source = [source]
178
+ return await self.consume_string([reader(s) for s in source], collection_name)
179
+
180
+ async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
181
+ """Consume a string and add it to the collection.
182
+
183
+ Args:
184
+ text (List[str] | str): The text to be added to the collection.
185
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
186
+
187
+ Returns:
188
+ Self: The current instance, allowing for method chaining.
189
+ """
190
+ self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
191
+ return self
192
+
193
+ async def afetch_document(
194
+ self,
195
+ vecs: List[List[float]],
196
+ desired_fields: List[str] | str,
197
+ collection_name: Optional[str] = None,
198
+ similarity_threshold: float = 0.37,
199
+ result_per_query: int = 10,
200
+ ) -> List[Dict[str, Any]] | List[Any]:
201
+ """Fetch data from the collection.
202
+
203
+ Args:
204
+ vecs (List[List[float]]): The vectors to search for.
205
+ desired_fields (List[str] | str): The fields to retrieve.
206
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
207
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
208
+ result_per_query (int): The number of results to return per query.
209
+
210
+ Returns:
211
+ List[Dict[str, Any]] | List[Any]: The retrieved data.
212
+ """
213
+ # Step 1: Search for vectors
214
+ search_results = self._client.search(
215
+ collection_name or self.safe_target_collection,
216
+ vecs,
217
+ search_params={"radius": similarity_threshold},
218
+ output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
219
+ limit=result_per_query,
220
+ )
221
+
222
+ # Step 2: Flatten the search results
223
+ flattened_results = flatten(search_results)
224
+
225
+ # Step 3: Sort by distance (descending)
226
+ sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
227
+
228
+ logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
229
+ # Step 4: Extract the entities
230
+ resp = [result["entity"] for result in sorted_results]
231
+
232
+ if isinstance(desired_fields, list):
233
+ return resp
234
+ return [r.get(desired_fields) for r in resp]
235
+
236
+ async def aretrieve(
237
+ self,
238
+ query: List[str] | str,
239
+ collection_name: Optional[str] = None,
240
+ final_limit: int = 20,
241
+ **kwargs: Unpack[FetchKwargs],
242
+ ) -> List[str]:
243
+ """Retrieve data from the collection.
244
+
245
+ Args:
246
+ query (List[str] | str): The query to be used for retrieval.
247
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
248
+ final_limit (int): The final limit on the number of results to return.
249
+ **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
250
+
251
+ Returns:
252
+ List[str]: A list of strings containing the retrieved data.
253
+ """
254
+ if isinstance(query, str):
255
+ query = [query]
256
+ return (
257
+ await self.afetch_document(
258
+ vecs=(await self.vectorize(query)),
259
+ desired_fields="text",
260
+ collection_name=collection_name,
261
+ **kwargs,
262
+ )
263
+ )[:final_limit]
264
+
265
+ async def aask_retrieved(
266
+ self,
267
+ question: str | List[str],
268
+ query: List[str] | str,
269
+ collection_name: Optional[str] = None,
270
+ extra_system_message: str = "",
271
+ result_per_query: int = 10,
272
+ final_limit: int = 20,
273
+ similarity_threshold: float = 0.37,
274
+ **kwargs: Unpack[LLMKwargs],
275
+ ) -> str:
276
+ """Asks a question by retrieving relevant documents based on the provided query.
277
+
278
+ This method performs document retrieval using the given query, then asks the
279
+ specified question using the retrieved documents as context.
280
+
281
+ Args:
282
+ question (str | List[str]): The question or list of questions to be asked.
283
+ query (List[str] | str): The query or list of queries used for document retrieval.
284
+ collection_name (Optional[str]): The name of the collection to retrieve documents from.
285
+ If not provided, the currently viewed collection is used.
286
+ extra_system_message (str): An additional system message to be included in the prompt.
287
+ result_per_query (int): The number of results to return per query. Default is 10.
288
+ final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
289
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
290
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
291
+
292
+ Returns:
293
+ str: A string response generated after asking with the context of retrieved documents.
294
+ """
295
+ docs = await self.aretrieve(
296
+ query,
297
+ collection_name,
298
+ final_limit,
299
+ result_per_query=result_per_query,
300
+ similarity_threshold=similarity_threshold,
301
+ )
302
+
303
+ rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
304
+
305
+ logger.debug(f"Retrieved Documents: \n{rendered}")
306
+ return await self.aask(
307
+ question,
308
+ f"{rendered}\n\n{extra_system_message}",
309
+ **kwargs,
310
+ )
@@ -12,7 +12,7 @@ from fabricatio.models.generic import WithBriefing
12
12
  from fabricatio.models.kwargs_types import GenerateKwargs, ValidateKwargs
13
13
  from fabricatio.models.usages import LLMUsage
14
14
  from fabricatio.parser import JsonCapture
15
- from more_itertools import flatten
15
+ from more_itertools import flatten, windowed
16
16
  from pydantic import NonNegativeInt, PositiveInt
17
17
 
18
18
 
@@ -275,3 +275,81 @@ class GiveRating(WithBriefing, LLMUsage):
275
275
  validator=_criteria_validator,
276
276
  **kwargs,
277
277
  )
278
+
279
async def drafting_rating_weights_klee(
    self,
    topic: str,
    criteria: Set[str],
    **kwargs: Unpack[ValidateKwargs],
) -> Dict[str, float]:
    """Drafts rating weights for a given topic and criteria using the Klee method.

    The criteria are put in a fixed order. For every adjacent pair ``(first, second)``, the
    LLM is asked for a positive multiplier telling how important ``second`` is compared to
    ``first``. The first criterion gets weight 1, and each later weight is the previous
    weight times its multiplier. All weights are then divided by their sum, so they add up to 1.

    Args:
        topic (str): The topic for the rating weights.
        criteria (Set[str]): A set of criteria for the rating weights (at least two).
        **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

    Returns:
        Dict[str, float]: A dictionary representing the normalized rating weight of each criterion.

    Raises:
        ValueError: If fewer than two criteria are given.
    """
    if len(criteria) < 2:  # noqa: PLR2004
        raise ValueError("At least two criteria are required to draft rating weights")

    def _validator(resp: str) -> float | None:
        cap = JsonCapture.convert_with(resp, orjson.loads)
        # orjson decodes `2` as int and `2.5` as float, so accept both (but not bools, an int
        # subclass). Require a positive value: 0 or a negative multiplier would zero out or
        # flip the sign of every subsequent weight.
        if isinstance(cap, (int, float)) and not isinstance(cap, bool) and cap > 0:
            return float(cap)
        return None

    criteria = list(criteria)  # freeze the order so prompts and the returned mapping agree
    windows = windowed(criteria, 2)

    # get the importance multiplier indicating how important is second criterion compared to the first one
    relative_weights = await self.aask_validate_batch(
        questions=[
            template_manager.render_template(
                configs.templates.draft_rating_weights_klee_template,
                {
                    "topic": topic,
                    "first": pair[0],
                    "second": pair[1],
                },
            )
            for pair in windows
        ],
        validator=_validator,
        **GenerateKwargs(system_message=f"# your personal briefing: \n{self.briefing}", **kwargs),
    )
    weights = [1.0]
    for rw in relative_weights:
        weights.append(weights[-1] * rw)
    total = sum(weights)
    return dict(zip(criteria, [w / total for w in weights], strict=True))
327
+
328
async def composite_score(
    self,
    topic: str,
    to_rate: List[str],
    reasons_count: PositiveInt = 2,
    criteria_count: PositiveInt = 5,
    **kwargs: Unpack[ValidateKwargs],
) -> List[float]:
    """Score every item as the weighted sum of its per-criterion ratings.

    The work happens in three stages:
    1. Rating criteria are drafted from the items themselves (``draft_rating_criteria_from_examples``).
    2. One weight per criterion is derived (``drafting_rating_weights_klee``).
    3. Every item is rated on every criterion (``rate``), and those ratings are combined with the weights.

    Args:
        topic (str): The topic for the rating.
        to_rate (List[str]): A list of strings to be rated.
        reasons_count (PositiveInt, optional): The number of reasons to extract from each pair of examples. Defaults to 2.
        criteria_count (PositiveInt, optional): The number of criteria to draft. Defaults to 5.
        **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments forwarded to every LLM call.

    Returns:
        List[float]: One composite score per item, in the order returned by ``rate``.
    """
    criteria = await self.draft_rating_criteria_from_examples(
        topic, to_rate, reasons_count, criteria_count, **kwargs
    )
    weights = await self.drafting_rating_weights_klee(topic, criteria, **kwargs)
    logger.info(f"Criteria: {criteria}\nWeights: {weights}")
    per_item_ratings = await self.rate(to_rate, topic, criteria, **kwargs)

    def _weighted_total(item_ratings):
        # Sum the rating of each criterion scaled by that criterion's weight.
        return sum(item_ratings[name] * weights[name] for name in criteria)

    return [_weighted_total(item_ratings) for item_ratings in per_item_ratings]
fabricatio/config.py CHANGED
@@ -11,6 +11,7 @@ from pydantic import (
11
11
  FilePath,
12
12
  HttpUrl,
13
13
  NonNegativeFloat,
14
+ PositiveFloat,
14
15
  PositiveInt,
15
16
  SecretStr,
16
17
  )
@@ -79,6 +80,30 @@ class LLMConfig(BaseModel):
79
80
  """The maximum number of tokens to generate. Set to 8192 as per request."""
80
81
 
81
82
 
83
class EmbeddingConfig(BaseModel):
    """Embedding configuration class, exposed as ``Settings.embedding``."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    model: str = Field(default="text-embedding-ada-002")
    """The embedding model name."""

    dimensions: Optional[PositiveInt] = Field(default=None)
    """The number of dimensions of the embedding vectors. Default is None."""

    timeout: Optional[PositiveInt] = Field(default=None)
    """The timeout of the embedding model in seconds. Default is None."""

    caching: bool = Field(default=False)
    """Whether to cache the embedding. Default is False."""

    api_endpoint: Optional[HttpUrl] = None
    """The OpenAI API endpoint."""

    api_key: Optional[SecretStr] = None
    """The OpenAI API key."""
105
+
106
+
82
107
  class PymitterConfig(BaseModel):
83
108
  """Pymitter configuration class.
84
109
 
@@ -172,6 +197,12 @@ class TemplateConfig(BaseModel):
172
197
  extract_criteria_from_reasons_template: str = Field(default="extract_criteria_from_reasons")
173
198
  """The name of the extract criteria from reasons template which will be used to extract criteria from reasons."""
174
199
 
200
+ draft_rating_weights_klee_template: str = Field(default="draft_rating_weights_klee")
201
+ """The name of the draft rating weights klee template which will be used to draft rating weights with Klee method."""
202
+
203
+ retrieved_display_template: str = Field(default="retrieved_display")
204
+ """The name of the retrieved display template which will be used to display retrieved documents."""
205
+
175
206
 
176
207
  class MagikaConfig(BaseModel):
177
208
  """Magika configuration class."""
@@ -204,6 +235,21 @@ class ToolBoxConfig(BaseModel):
204
235
  """The name of the module containing the data."""
205
236
 
206
237
 
238
class RagConfig(BaseModel):
    """RAG configuration class, exposed as ``Settings.rag``."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    milvus_uri: HttpUrl = Field(default=HttpUrl("http://localhost:19530"))
    """The URI of the Milvus server."""
    milvus_timeout: Optional[PositiveFloat] = Field(default=None)
    """The timeout for the Milvus client. Default is None."""
    milvus_token: Optional[SecretStr] = Field(default=None)
    """The token passed to the Milvus client. Default is None."""
    milvus_dimensions: Optional[PositiveInt] = Field(default=None)
    """The vector dimension used when creating new Milvus collections. Default is None."""
251
+
252
+
207
253
  class Settings(BaseSettings):
208
254
  """Application settings class.
209
255
 
@@ -229,6 +275,9 @@ class Settings(BaseSettings):
229
275
  llm: LLMConfig = Field(default_factory=LLMConfig)
230
276
  """LLM Configuration"""
231
277
 
278
+ embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
279
+ """Embedding Configuration"""
280
+
232
281
  debug: DebugConfig = Field(default_factory=DebugConfig)
233
282
  """Debug Configuration"""
234
283
 
@@ -247,6 +296,9 @@ class Settings(BaseSettings):
247
296
  toolbox: ToolBoxConfig = Field(default_factory=ToolBoxConfig)
248
297
  """Toolbox Configuration"""
249
298
 
299
+ rag: RagConfig = Field(default_factory=RagConfig)
300
+ """RAG Configuration"""
301
+
250
302
  @classmethod
251
303
  def settings_customise_sources(
252
304
  cls,
fabricatio/core.py CHANGED
@@ -38,11 +38,11 @@ class Env(BaseModel):
38
38
 
39
39
  @overload
40
40
  def on[**P, R](
41
- self,
42
- event: str | Event,
43
- func: Optional[Callable[P, R]] = None,
44
- /,
45
- ttl: int = -1,
41
+ self,
42
+ event: str | Event,
43
+ func: Optional[Callable[P, R]] = None,
44
+ /,
45
+ ttl: int = -1,
46
46
  ) -> Callable[[Callable[P, R]], Callable[P, R]]:
47
47
  """
48
48
  Registers an event listener with a specific function that listens indefinitely or for a specified number of times.
@@ -58,11 +58,11 @@ class Env(BaseModel):
58
58
  ...
59
59
 
60
60
  def on[**P, R](
61
- self,
62
- event: str | Event,
63
- func: Optional[Callable[P, R]] = None,
64
- /,
65
- ttl=-1,
61
+ self,
62
+ event: str | Event,
63
+ func: Optional[Callable[P, R]] = None,
64
+ /,
65
+ ttl=-1,
66
66
  ) -> Callable[[Callable[P, R]], Callable[P, R]] | Self:
67
67
  """Registers an event listener with a specific function that listens indefinitely or for a specified number of times.
68
68
 
@@ -78,14 +78,13 @@ class Env(BaseModel):
78
78
  event = event.collapse()
79
79
  if func is None:
80
80
  return self._ee.on(event, ttl=ttl)
81
-
82
81
  self._ee.on(event, func, ttl=ttl)
83
82
  return self
84
83
 
85
84
  @overload
86
85
  def once[**P, R](
87
- self,
88
- event: str | Event,
86
+ self,
87
+ event: str | Event,
89
88
  ) -> Callable[[Callable[P, R]], Callable[P, R]]:
90
89
  """
91
90
  Registers an event listener that listens only once.
@@ -100,9 +99,9 @@ class Env(BaseModel):
100
99
 
101
100
  @overload
102
101
  def once[**P, R](
103
- self,
104
- event: str | Event,
105
- func: Callable[[Callable[P, R]], Callable[P, R]],
102
+ self,
103
+ event: str | Event,
104
+ func: Callable[[Callable[P, R]], Callable[P, R]],
106
105
  ) -> Self:
107
106
  """
108
107
  Registers an event listener with a specific function that listens only once.
@@ -117,9 +116,9 @@ class Env(BaseModel):
117
116
  ...
118
117
 
119
118
  def once[**P, R](
120
- self,
121
- event: str | Event,
122
- func: Optional[Callable[P, R]] = None,
119
+ self,
120
+ event: str | Event,
121
+ func: Optional[Callable[P, R]] = None,
123
122
  ) -> Callable[[Callable[P, R]], Callable[P, R]] | Self:
124
123
  """Registers an event listener with a specific function that listens only once.
125
124
 
@@ -163,5 +162,20 @@ class Env(BaseModel):
163
162
  event = event.collapse()
164
163
  return await self._ee.emit_async(event, *args, **kwargs)
165
164
 
165
+ def emit_future[**P](self, event: str | Event, *args: P.args, **kwargs: P.kwargs) -> None:
166
+ """Emits an event to all registered listeners and returns a future object.
167
+
168
+ Args:
169
+ event (str | Event): The event to emit.
170
+ *args: Positional arguments to pass to the listeners.
171
+ **kwargs: Keyword arguments to pass to the listeners.
172
+
173
+ Returns:
174
+ None: The future object.
175
+ """
176
+ if isinstance(event, Event):
177
+ event = event.collapse()
178
+ return self._ee.emit_future(event, *args, **kwargs)
179
+
166
180
 
167
181
  env = Env()
@@ -2,7 +2,7 @@
2
2
 
3
3
  import traceback
4
4
  from abc import abstractmethod
5
- from asyncio import Queue
5
+ from asyncio import Queue, create_task
6
6
  from typing import Any, Dict, Self, Tuple, Type, Union, Unpack
7
7
 
8
8
  from fabricatio.capabilities.rating import GiveRating
@@ -108,7 +108,11 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
108
108
  try:
109
109
  for step in self._instances:
110
110
  logger.debug(f"Executing step: {step.name}")
111
- modified_ctx = await step.act(await self._context.get())
111
+ act_task = create_task(step.act(await self._context.get()))
112
+ if task.is_cancelled():
113
+ act_task.cancel(f"Cancelled by task: {task.name}")
114
+ break
115
+ modified_ctx = await act_task
112
116
  await self._context.put(modified_ctx)
113
117
  current_action = step.name
114
118
  logger.info(f"Finished executing workflow: {self.name}")