fabricatio 0.2.10.dev0__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. fabricatio/actions/article.py +55 -10
  2. fabricatio/actions/article_rag.py +297 -12
  3. fabricatio/actions/fs.py +25 -0
  4. fabricatio/actions/output.py +17 -3
  5. fabricatio/actions/rag.py +42 -20
  6. fabricatio/actions/rules.py +14 -3
  7. fabricatio/capabilities/extract.py +70 -0
  8. fabricatio/capabilities/rag.py +5 -2
  9. fabricatio/capabilities/rating.py +5 -2
  10. fabricatio/capabilities/task.py +16 -16
  11. fabricatio/config.py +9 -2
  12. fabricatio/decorators.py +43 -26
  13. fabricatio/fs/__init__.py +9 -2
  14. fabricatio/fs/readers.py +6 -10
  15. fabricatio/models/action.py +16 -11
  16. fabricatio/models/adv_kwargs_types.py +5 -12
  17. fabricatio/models/extra/aricle_rag.py +254 -0
  18. fabricatio/models/extra/article_base.py +56 -7
  19. fabricatio/models/extra/article_essence.py +8 -7
  20. fabricatio/models/extra/article_main.py +102 -6
  21. fabricatio/models/extra/problem.py +5 -1
  22. fabricatio/models/extra/rag.py +49 -23
  23. fabricatio/models/generic.py +43 -24
  24. fabricatio/models/kwargs_types.py +12 -3
  25. fabricatio/models/task.py +13 -1
  26. fabricatio/models/usages.py +10 -27
  27. fabricatio/parser.py +16 -12
  28. fabricatio/rust.cp312-win_amd64.pyd +0 -0
  29. fabricatio/rust.pyi +177 -63
  30. fabricatio/utils.py +50 -10
  31. fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
  32. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -12
  33. fabricatio-0.2.11.dist-info/RECORD +65 -0
  34. fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
  35. fabricatio-0.2.10.dev0.dist-info/RECORD +0 -62
  36. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
  37. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,16 @@
1
1
  """A module containing the DraftRuleSet action."""
2
2
 
3
- from typing import List, Optional
3
+ from typing import Any, List, Mapping, Optional, Self, Tuple
4
4
 
5
5
  from fabricatio.capabilities.check import Check
6
6
  from fabricatio.journal import logger
7
7
  from fabricatio.models.action import Action
8
8
  from fabricatio.models.extra.rule import RuleSet
9
+ from fabricatio.models.generic import FromMapping
9
10
  from fabricatio.utils import ok
10
11
 
11
12
 
12
- class DraftRuleSet(Action, Check):
13
+ class DraftRuleSet(Action, Check, FromMapping):
13
14
  """Action to draft a ruleset based on a given requirement description."""
14
15
 
15
16
  output_key: str = "drafted_ruleset"
@@ -45,8 +46,13 @@ class DraftRuleSet(Action, Check):
45
46
  logger.warning(f"Drafting Rule Failed for:\n{ruleset_requirement}")
46
47
  return ruleset
47
48
 
49
+ @classmethod
50
+ def from_mapping(cls, mapping: Mapping[str, Tuple[int, str]], **kwargs) -> List[Self]:
51
+ """Create a list of DraftRuleSet actions from a mapping of output keys to tuples of rule counts and requirements."""
52
+ return [cls(ruleset_requirement=r, rule_count=c, output_key=k, **kwargs) for k, (c, r) in mapping.items()]
48
53
 
49
- class GatherRuleset(Action):
54
+
55
+ class GatherRuleset(Action, FromMapping):
50
56
  """Action to gather a ruleset from a given requirement description."""
51
57
 
52
58
  output_key: str = "gathered_ruleset"
@@ -55,6 +61,11 @@ class GatherRuleset(Action):
55
61
  to_gather: List[str]
56
62
  """the cxt name of RuleSet to gather"""
57
63
 
64
+ @classmethod
65
+ def from_mapping(cls, mapping: Mapping[str, List[str]], **kwargs: Any) -> List[Self]:
66
+ """Create a list of GatherRuleset actions from a mapping of output keys to tuples of rule counts and requirements."""
67
+ return [cls(to_gather=t, output_key=k, **kwargs) for k, t in mapping.items()]
68
+
58
69
  async def _execute(self, **cxt) -> RuleSet:
59
70
  logger.info(f"Gathering Ruleset from {self.to_gather}")
60
71
  # Fix for not_found
@@ -0,0 +1,70 @@
1
+ """A module that provide capabilities for extracting information from a given source to a model."""
2
+
3
+ from typing import List, Optional, Type, Unpack, overload
4
+
5
+ from fabricatio import TEMPLATE_MANAGER
6
+ from fabricatio.capabilities.propose import Propose
7
+ from fabricatio.config import configs
8
+ from fabricatio.models.generic import ProposedAble
9
+ from fabricatio.models.kwargs_types import ValidateKwargs
10
+
11
+
12
+ class Extract(Propose):
13
+ """A class that extract information from a given source to a model."""
14
+
15
+ @overload
16
+ async def extract[M: ProposedAble](
17
+ self,
18
+ cls: Type[M],
19
+ source: str,
20
+ extract_requirement: Optional[str] = None,
21
+ align_language: bool = True,
22
+ **kwargs: Unpack[ValidateKwargs[M]],
23
+ ) -> M: ...
24
+ @overload
25
+ async def extract[M: ProposedAble](
26
+ self,
27
+ cls: Type[M],
28
+ source: str,
29
+ extract_requirement: Optional[str] = None,
30
+ align_language: bool = True,
31
+ **kwargs: Unpack[ValidateKwargs[None]],
32
+ ) -> Optional[M]: ...
33
+
34
+ @overload
35
+ async def extract[M: ProposedAble](
36
+ self,
37
+ cls: Type[M],
38
+ source: List[str],
39
+ extract_requirement: Optional[str] = None,
40
+ align_language: bool = True,
41
+ **kwargs: Unpack[ValidateKwargs[M]],
42
+ ) -> List[M]: ...
43
+ @overload
44
+ async def extract[M: ProposedAble](
45
+ self,
46
+ cls: Type[M],
47
+ source: List[str],
48
+ extract_requirement: Optional[str] = None,
49
+ align_language: bool = True,
50
+ **kwargs: Unpack[ValidateKwargs[None]],
51
+ ) -> List[Optional[M]]: ...
52
+ async def extract[M: ProposedAble](
53
+ self,
54
+ cls: Type[M],
55
+ source: List[str] | str,
56
+ extract_requirement: Optional[str] = None,
57
+ align_language: bool = True,
58
+ **kwargs: Unpack[ValidateKwargs[Optional[M]]],
59
+ ) -> M | List[M] | Optional[M] | List[Optional[M]]:
60
+ """Extract information from a given source to a model."""
61
+ return await self.propose(
62
+ cls,
63
+ prompt=TEMPLATE_MANAGER.render_template(
64
+ configs.templates.extract_template,
65
+ [{"source": s, "extract_requirement": extract_requirement} for s in source]
66
+ if isinstance(source, list)
67
+ else {"source": source, "extract_requirement": extract_requirement, "align_language": align_language},
68
+ ),
69
+ **kwargs,
70
+ )
@@ -130,7 +130,7 @@ class RAG(EmbeddingUsage):
130
130
  if isinstance(data, MilvusDataBase):
131
131
  data = [data]
132
132
 
133
- data_vec = await self.vectorize([d.to_vectorize for d in data])
133
+ data_vec = await self.vectorize([d.prepare_vectorization() for d in data])
134
134
  prepared_data = [d.prepare_insertion(vec) for d, vec in zip(data, data_vec, strict=True)]
135
135
 
136
136
  c_name = collection_name or self.safe_target_collection
@@ -188,13 +188,15 @@ class RAG(EmbeddingUsage):
188
188
  async def aretrieve[D: MilvusDataBase](
189
189
  self,
190
190
  query: List[str] | str,
191
+ document_model: Type[D],
191
192
  final_limit: int = 20,
192
- **kwargs: Unpack[FetchKwargs[D]],
193
+ **kwargs: Unpack[FetchKwargs],
193
194
  ) -> List[D]:
194
195
  """Retrieve data from the collection.
195
196
 
196
197
  Args:
197
198
  query (List[str] | str): The query to be used for retrieval.
199
+ document_model (Type[D]): The model class used to convert retrieved data into document objects.
198
200
  final_limit (int): The final limit on the number of results to return.
199
201
  **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
200
202
 
@@ -206,6 +208,7 @@ class RAG(EmbeddingUsage):
206
208
  return (
207
209
  await self.afetch_document(
208
210
  vecs=(await self.vectorize(query)),
211
+ document_model=document_model,
209
212
  **kwargs,
210
213
  )
211
214
  )[:final_limit]
@@ -14,7 +14,7 @@ from fabricatio.models.generic import Display, ProposedAble
14
14
  from fabricatio.models.kwargs_types import CompositeScoreKwargs, ValidateKwargs
15
15
  from fabricatio.parser import JsonCapture
16
16
  from fabricatio.rust_instances import TEMPLATE_MANAGER
17
- from fabricatio.utils import fallback_kwargs, ok, override_kwargs
17
+ from fabricatio.utils import ok, override_kwargs
18
18
 
19
19
 
20
20
  class Rating(Propose):
@@ -137,7 +137,7 @@ class Rating(Propose):
137
137
  or dict(zip(criteria, criteria, strict=True))
138
138
  )
139
139
 
140
- return await self.rate_fine_grind(to_rate, manual, score_range, **fallback_kwargs(kwargs, co_extractor={}))
140
+ return await self.rate_fine_grind(to_rate, manual, score_range, **kwargs)
141
141
 
142
142
  async def draft_rating_manual(
143
143
  self, topic: str, criteria: Optional[Set[str]] = None, **kwargs: Unpack[ValidateKwargs[Dict[str, str]]]
@@ -338,6 +338,7 @@ class Rating(Propose):
338
338
  criteria: Optional[Set[str]] = None,
339
339
  weights: Optional[Dict[str, float]] = None,
340
340
  manual: Optional[Dict[str, str]] = None,
341
+ approx: bool = False,
341
342
  **kwargs: Unpack[ValidateKwargs[List[Dict[str, float]]]],
342
343
  ) -> List[float]:
343
344
  """Calculates the composite scores for a list of items based on a given topic and criteria.
@@ -348,6 +349,7 @@ class Rating(Propose):
348
349
  criteria (Optional[Set[str]]): A set of criteria for the rating. Defaults to None.
349
350
  weights (Optional[Dict[str, float]]): A dictionary of rating weights for each criterion. Defaults to None.
350
351
  manual (Optional[Dict[str, str]]): A dictionary of manual ratings for each item. Defaults to None.
352
+ approx (bool): Whether to use approximate rating criteria. Defaults to False.
351
353
  **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
352
354
 
353
355
  Returns:
@@ -355,6 +357,7 @@ class Rating(Propose):
355
357
  """
356
358
  criteria = ok(
357
359
  criteria
360
+ or (await self.draft_rating_criteria(topic, **override_kwargs(kwargs, default=None)) if approx else None)
358
361
  or await self.draft_rating_criteria_from_examples(topic, to_rate, **override_kwargs(kwargs, default=None))
359
362
  )
360
363
  weights = ok(
@@ -3,7 +3,7 @@
3
3
  from types import CodeType
4
4
  from typing import Any, Dict, List, Optional, Tuple, Unpack
5
5
 
6
- import orjson
6
+ import ujson
7
7
 
8
8
  from fabricatio.capabilities.propose import Propose
9
9
  from fabricatio.config import configs
@@ -20,9 +20,9 @@ class ProposeTask(Propose):
20
20
  """A class that proposes a task based on a prompt."""
21
21
 
22
22
  async def propose_task[T](
23
- self,
24
- prompt: str,
25
- **kwargs: Unpack[ValidateKwargs[Task[T]]],
23
+ self,
24
+ prompt: str,
25
+ **kwargs: Unpack[ValidateKwargs[Task[T]]],
26
26
  ) -> Optional[Task[T]]:
27
27
  """Asynchronously proposes a task based on a given prompt and parameters.
28
28
 
@@ -44,11 +44,11 @@ class HandleTask(ToolBoxUsage):
44
44
  """A class that handles a task based on a task object."""
45
45
 
46
46
  async def draft_tool_usage_code(
47
- self,
48
- task: Task,
49
- tools: List[Tool],
50
- data: Dict[str, Any],
51
- **kwargs: Unpack[ValidateKwargs],
47
+ self,
48
+ task: Task,
49
+ tools: List[Tool],
50
+ data: Dict[str, Any],
51
+ **kwargs: Unpack[ValidateKwargs],
52
52
  ) -> Optional[Tuple[CodeType, List[str]]]:
53
53
  """Asynchronously drafts the tool usage code for a task based on a given task object and tools."""
54
54
  logger.info(f"Drafting tool usage code for task: {task.briefing}")
@@ -60,7 +60,7 @@ class HandleTask(ToolBoxUsage):
60
60
 
61
61
  def _validator(response: str) -> Tuple[CodeType, List[str]] | None:
62
62
  if (source := PythonCapture.convert_with(response, lambda resp: compile(resp, "<string>", "exec"))) and (
63
- to_extract := JsonCapture.convert_with(response, orjson.loads)
63
+ to_extract := JsonCapture.convert_with(response, ujson.loads)
64
64
  ):
65
65
  return source, to_extract
66
66
 
@@ -85,12 +85,12 @@ class HandleTask(ToolBoxUsage):
85
85
  )
86
86
 
87
87
  async def handle_fine_grind(
88
- self,
89
- task: Task,
90
- data: Dict[str, Any],
91
- box_choose_kwargs: Optional[ChooseKwargs] = None,
92
- tool_choose_kwargs: Optional[ChooseKwargs] = None,
93
- **kwargs: Unpack[ValidateKwargs],
88
+ self,
89
+ task: Task,
90
+ data: Dict[str, Any],
91
+ box_choose_kwargs: Optional[ChooseKwargs] = None,
92
+ tool_choose_kwargs: Optional[ChooseKwargs] = None,
93
+ **kwargs: Unpack[ValidateKwargs],
94
94
  ) -> Optional[Tuple]:
95
95
  """Asynchronously handles a task based on a given task object and parameters."""
96
96
  logger.info(f"Handling task: \n{task.briefing}")
fabricatio/config.py CHANGED
@@ -86,8 +86,10 @@ class LLMConfig(BaseModel):
86
86
 
87
87
  tpm: Optional[PositiveInt] = Field(default=1000000)
88
88
  """The rate limit of the LLM model in tokens per minute. None means not checked."""
89
-
90
-
89
+ presence_penalty:Optional[PositiveFloat]=None
90
+ """The presence penalty of the LLM model."""
91
+ frequency_penalty:Optional[PositiveFloat]=None
92
+ """The frequency penalty of the LLM model."""
91
93
  class EmbeddingConfig(BaseModel):
92
94
  """Embedding configuration class."""
93
95
 
@@ -249,6 +251,11 @@ class TemplateConfig(BaseModel):
249
251
 
250
252
  rule_requirement_template: str = Field(default="rule_requirement")
251
253
  """The name of the rule requirement template which will be used to generate a rule requirement."""
254
+
255
+
256
+ extract_template: str = Field(default="extract")
257
+ """The name of the extract template which will be used to extract model from string."""
258
+
252
259
  class MagikaConfig(BaseModel):
253
260
  """Magika configuration class."""
254
261
 
fabricatio/decorators.py CHANGED
@@ -6,14 +6,47 @@ from importlib.util import find_spec
6
6
  from inspect import signature
7
7
  from shutil import which
8
8
  from types import ModuleType
9
- from typing import Callable, List, Optional
10
-
11
- from questionary import confirm
9
+ from typing import Callable, Coroutine, List, Optional
12
10
 
13
11
  from fabricatio.config import configs
14
12
  from fabricatio.journal import logger
15
13
 
16
14
 
15
+ def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
16
+ """Check if a package exists in the current environment.
17
+
18
+ Args:
19
+ package_name (str): The name of the package to check.
20
+ msg (str): The message to display if the package is not found.
21
+
22
+ Returns:
23
+ bool: True if the package exists, False otherwise.
24
+ """
25
+
26
+ def _wrapper(
27
+ func: Callable[P, R] | Callable[P, Coroutine[None, None, R]],
28
+ ) -> Callable[P, R] | Callable[P, Coroutine[None, None, R]]:
29
+ if iscoroutinefunction(func):
30
+
31
+ @wraps(func)
32
+ async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> R:
33
+ if find_spec(package_name):
34
+ return await func(*args, **kwargs)
35
+ raise RuntimeError(msg)
36
+
37
+ return _async_inner
38
+
39
+ @wraps(func)
40
+ def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
41
+ if find_spec(package_name):
42
+ return func(*args, **kwargs)
43
+ raise RuntimeError(msg)
44
+
45
+ return _inner
46
+
47
+ return _wrapper
48
+
49
+
17
50
  def depend_on_external_cmd[**P, R](
18
51
  bin_name: str, install_tip: Optional[str], homepage: Optional[str] = None
19
52
  ) -> Callable[[Callable[P, R]], Callable[P, R]]:
@@ -68,6 +101,9 @@ def logging_execution_info[**P, R](func: Callable[P, R]) -> Callable[P, R]:
68
101
  return _wrapper
69
102
 
70
103
 
104
+ @precheck_package(
105
+ "questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
106
+ )
71
107
  def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]] | Callable[P, R]:
72
108
  """Decorator to confirm before executing a function.
73
109
 
@@ -80,6 +116,7 @@ def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]]
80
116
  if not configs.general.confirm_on_ops:
81
117
  # Skip confirmation if the configuration is set to False
82
118
  return func
119
+ from questionary import confirm
83
120
 
84
121
  if iscoroutinefunction(func):
85
122
 
@@ -180,7 +217,9 @@ def use_temp_module[**P, R](modules: ModuleType | List[ModuleType]) -> Callable[
180
217
  return _decorator
181
218
 
182
219
 
183
- def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
220
+ def logging_exec_time[**P, R](
221
+ func: Callable[P, R] | Callable[P, Coroutine[None, None, R]],
222
+ ) -> Callable[P, R] | Callable[P, Coroutine[None, None, R]]:
184
223
  """Decorator to log the execution time of a function.
185
224
 
186
225
  Args:
@@ -210,25 +249,3 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
210
249
  return result
211
250
 
212
251
  return _wrapper
213
-
214
-
215
- def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
216
- """Check if a package exists in the current environment.
217
-
218
- Args:
219
- package_name (str): The name of the package to check.
220
- msg (str): The message to display if the package is not found.
221
-
222
- Returns:
223
- bool: True if the package exists, False otherwise.
224
- """
225
-
226
- def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
227
- def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
228
- if find_spec(package_name):
229
- return func(*args, **kwargs)
230
- raise RuntimeError(msg)
231
-
232
- return _inner
233
-
234
- return _wrapper
fabricatio/fs/__init__.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """FileSystem manipulation module for Fabricatio."""
2
+ from importlib.util import find_spec
2
3
 
4
+ from fabricatio.config import configs
3
5
  from fabricatio.fs.curd import (
4
6
  absolute_path,
5
7
  copy_file,
@@ -11,10 +13,9 @@ from fabricatio.fs.curd import (
11
13
  move_file,
12
14
  tree,
13
15
  )
14
- from fabricatio.fs.readers import MAGIKA, safe_json_read, safe_text_read
16
+ from fabricatio.fs.readers import safe_json_read, safe_text_read
15
17
 
16
18
  __all__ = [
17
- "MAGIKA",
18
19
  "absolute_path",
19
20
  "copy_file",
20
21
  "create_directory",
@@ -27,3 +28,9 @@ __all__ = [
27
28
  "safe_text_read",
28
29
  "tree",
29
30
  ]
31
+
32
+ if find_spec("magika"):
33
+ from magika import Magika
34
+
35
+ MAGIKA = Magika(model_dir=configs.magika.model_dir)
36
+ __all__ += ["MAGIKA"]
fabricatio/fs/readers.py CHANGED
@@ -1,17 +1,13 @@
1
1
  """Filesystem readers for Fabricatio."""
2
2
 
3
+ import re
3
4
  from pathlib import Path
4
5
  from typing import Dict, List, Tuple
5
6
 
6
- import orjson
7
- import regex
8
- from magika import Magika
7
+ import ujson
9
8
 
10
- from fabricatio.config import configs
11
9
  from fabricatio.journal import logger
12
10
 
13
- MAGIKA = Magika(model_dir=configs.magika.model_dir)
14
-
15
11
 
16
12
  def safe_text_read(path: Path | str) -> str:
17
13
  """Safely read the text from a file.
@@ -41,8 +37,8 @@ def safe_json_read(path: Path | str) -> Dict:
41
37
  """
42
38
  path = Path(path)
43
39
  try:
44
- return orjson.loads(path.read_text(encoding="utf-8"))
45
- except (orjson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
40
+ return ujson.loads(path.read_text(encoding="utf-8"))
41
+ except (ujson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
46
42
  logger.error(f"Failed to read file {path}: {e!s}")
47
43
  return {}
48
44
 
@@ -58,8 +54,8 @@ def extract_sections(string: str, level: int, section_char: str = "#") -> List[T
58
54
  Returns:
59
55
  List[Tuple[str, str]]: List of (header_text, section_content) tuples
60
56
  """
61
- return regex.findall(
57
+ return re.findall(
62
58
  r"^%s{%d}\s+(.+?)\n((?:(?!^%s{%d}\s).|\n)*)" % (section_char, level, section_char, level),
63
59
  string,
64
- regex.MULTILINE,
60
+ re.MULTILINE,
65
61
  )
@@ -12,12 +12,13 @@ Classes:
12
12
  import traceback
13
13
  from abc import abstractmethod
14
14
  from asyncio import Queue, create_task
15
- from typing import Any, Dict, Self, Tuple, Type, Union, final
15
+ from typing import Any, Dict, Self, Sequence, Tuple, Type, Union, final
16
16
 
17
17
  from fabricatio.journal import logger
18
18
  from fabricatio.models.generic import WithBriefing
19
19
  from fabricatio.models.task import Task
20
20
  from fabricatio.models.usages import LLMUsage, ToolBoxUsage
21
+ from fabricatio.utils import override_kwargs
21
22
  from pydantic import Field, PrivateAttr
22
23
 
23
24
  OUTPUT_KEY = "task_output"
@@ -55,7 +56,7 @@ class Action(WithBriefing, LLMUsage):
55
56
  self.description = self.description or self.__class__.__doc__ or ""
56
57
 
57
58
  @abstractmethod
58
- async def _execute(self, *_:Any, **cxt) -> Any:
59
+ async def _execute(self, *_: Any, **cxt) -> Any:
59
60
  """Implement the core logic of the action.
60
61
 
61
62
  Args:
@@ -95,11 +96,12 @@ class Action(WithBriefing, LLMUsage):
95
96
  return f"## Your personality: \n{self.personality}\n# The action you are going to perform: \n{super().briefing}"
96
97
  return f"# The action you are going to perform: \n{super().briefing}"
97
98
 
98
- def to_task_output(self)->Self:
99
+ def to_task_output(self, task_output_key: str = OUTPUT_KEY) -> Self:
99
100
  """Set the output key to OUTPUT_KEY and return the action instance."""
100
- self.output_key=OUTPUT_KEY
101
+ self.output_key = task_output_key
101
102
  return self
102
103
 
104
+
103
105
  class WorkFlow(WithBriefing, ToolBoxUsage):
104
106
  """Manages sequences of actions to fulfill tasks.
105
107
 
@@ -121,9 +123,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
121
123
  _instances: Tuple[Action, ...] = PrivateAttr(default_factory=tuple)
122
124
  """Instantiated action objects to be executed in this workflow."""
123
125
 
124
- steps: Tuple[Union[Type[Action], Action], ...] = Field(
125
- frozen=True,
126
- )
126
+ steps: Sequence[Union[Type[Action], Action]] = Field(frozen=True)
127
127
  """The sequence of actions to be executed, can be action classes or instances."""
128
128
 
129
129
  task_input_key: str = Field(default=INPUT_KEY)
@@ -177,7 +177,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
177
177
  current_action = None
178
178
  try:
179
179
  # Process each action in sequence
180
- for i,step in enumerate(self._instances):
180
+ for i, step in enumerate(self._instances):
181
181
  current_action = step.name
182
182
  logger.info(f"Executing step [{i}] >> {current_action}")
183
183
 
@@ -227,8 +227,13 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
227
227
  - Any extra_init_context values
228
228
  """
229
229
  logger.debug(f"Initializing context for workflow: {self.name}")
230
- initial_context = {self.task_input_key: task, **dict(self.extra_init_context)}
231
- await self._context.put(initial_context)
230
+ ctx = override_kwargs(self.extra_init_context, **task.extra_init_context)
231
+ if self.task_input_key in ctx:
232
+ raise ValueError(
233
+ f"Task input key: `{self.task_input_key}`, which is reserved, is already set in the init context"
234
+ )
235
+
236
+ await self._context.put({self.task_input_key: task, **ctx})
232
237
 
233
238
  def steps_fallback_to_self(self) -> Self:
234
239
  """Configure all steps to use this workflow's configuration as fallback.
@@ -245,7 +250,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
245
250
  Returns:
246
251
  Self: The workflow instance for method chaining.
247
252
  """
248
- self.provide_tools_to(i for i in self._instances if isinstance(i,ToolBoxUsage))
253
+ self.provide_tools_to(i for i in self._instances if isinstance(i, ToolBoxUsage))
249
254
  return self
250
255
 
251
256
  def update_init_context(self, /, **kwargs) -> Self:
@@ -1,10 +1,9 @@
1
1
  """A module containing kwargs types for content correction and checking operations."""
2
2
 
3
3
  from importlib.util import find_spec
4
- from typing import Required, Type, TypedDict
4
+ from typing import NotRequired, TypedDict
5
5
 
6
6
  from fabricatio.models.extra.problem import Improvement
7
- from fabricatio.models.extra.rag import MilvusDataBase
8
7
  from fabricatio.models.extra.rule import RuleSet
9
8
  from fabricatio.models.generic import SketchedAble
10
9
  from fabricatio.models.kwargs_types import ReferencedKwargs
@@ -49,19 +48,13 @@ if find_spec("pymilvus"):
49
48
  schema: CollectionSchema | None
50
49
  index_params: IndexParams | None
51
50
 
52
- class FetchKwargs[D: MilvusDataBase](TypedDict, total=False):
51
+ class FetchKwargs(TypedDict):
53
52
  """Arguments for fetching data from vector collections.
54
53
 
55
54
  Controls how data is retrieved from vector databases, including filtering
56
55
  and result limiting parameters.
57
56
  """
58
57
 
59
- document_model: Required[Type[D]]
60
- collection_name: str | None
61
- similarity_threshold: float
62
- result_per_query: int
63
-
64
- class RetrievalKwargs(FetchKwargs, total=False):
65
- """Arguments for retrieval operations."""
66
-
67
- final_limit: int
58
+ collection_name: NotRequired[str | None]
59
+ similarity_threshold: NotRequired[float]
60
+ result_per_query: NotRequired[int]