fabricatio 0.2.3.dev3__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.4.dev1__cp312-cp312-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/__init__.py CHANGED
@@ -3,11 +3,13 @@
3
3
  from importlib.util import find_spec
4
4
 
5
5
  from fabricatio._rust_instances import template_manager
6
+ from fabricatio.actions import ExtractArticleEssence
6
7
  from fabricatio.core import env
7
8
  from fabricatio.fs import magika
8
9
  from fabricatio.journal import logger
9
10
  from fabricatio.models.action import Action, WorkFlow
10
11
  from fabricatio.models.events import Event
12
+ from fabricatio.models.extra import ArticleEssence
11
13
  from fabricatio.models.role import Role
12
14
  from fabricatio.models.task import Task
13
15
  from fabricatio.models.tool import ToolBox
@@ -17,9 +19,11 @@ from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
17
19
 
18
20
  __all__ = [
19
21
  "Action",
22
+ "ArticleEssence",
20
23
  "Capture",
21
24
  "CodeBlockCapture",
22
25
  "Event",
26
+ "ExtractArticleEssence",
23
27
  "JsonCapture",
24
28
  "Message",
25
29
  "Messages",
@@ -40,6 +44,6 @@ __all__ = [
40
44
 
41
45
 
42
46
  if find_spec("pymilvus"):
43
- from fabricatio.capabilities.rag import Rag
47
+ from fabricatio.capabilities.rag import RAG
44
48
 
45
- __all__ += ["Rag"]
49
+ __all__ += ["RAG"]
@@ -1,5 +1,5 @@
1
1
  """module for actions."""
2
2
 
3
- from fabricatio.actions.transmission import PublishTask
3
+ from fabricatio.actions.article import ExtractArticleEssence
4
4
 
5
- __all__ = ["PublishTask"]
5
+ __all__ = ["ExtractArticleEssence"]
@@ -0,0 +1,44 @@
1
+ """Actions for transmitting tasks to targets."""
2
+
3
+ from os import PathLike
4
+ from pathlib import Path
5
+ from typing import Callable, List
6
+
7
+ from fabricatio.journal import logger
8
+ from fabricatio.models.action import Action
9
+ from fabricatio.models.extra import ArticleEssence
10
+ from fabricatio.models.task import Task
11
+
12
+
13
+ class ExtractArticleEssence(Action):
14
+ """Extract the essence of article(s)."""
15
+
16
+ name: str = "extract article essence"
17
+ """The name of the action."""
18
+ description: str = "Extract the essence of an article. output as json"
19
+ """The description of the action."""
20
+
21
+ output_key: str = "article_essence"
22
+ """The key of the output data."""
23
+
24
+ async def _execute[P: PathLike | str](
25
+ self,
26
+ task_input: Task,
27
+ reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
28
+ **_,
29
+ ) -> List[ArticleEssence]:
30
+ if not await self.ajudge(
31
+ f"= Task\n{task_input.briefing}\n\n\n= Role\n{self.briefing}",
32
+ affirm_case="The task does not violate the role, and could be approved since the file dependencies are specified.",
33
+ deny_case="The task does violate the role, and could not be approved.",
34
+ ):
35
+ logger.info(err := "Task not approved.")
36
+ raise RuntimeError(err)
37
+
38
+ # trim the references
39
+ contents = ["References".join(c.split("References")[:-1]) for c in map(reader, task_input.dependencies)]
40
+ return await self.propose(
41
+ ArticleEssence,
42
+ contents,
43
+ system_message=f"# your personal briefing: \n{self.briefing}",
44
+ )
@@ -0,0 +1,55 @@
1
+ """A module for the task capabilities of the Fabricatio library."""
2
+
3
+ from typing import List, Type, Unpack, overload
4
+
5
+ from fabricatio.models.generic import ProposedAble
6
+ from fabricatio.models.kwargs_types import GenerateKwargs
7
+ from fabricatio.models.usages import LLMUsage
8
+
9
+
10
+ class Propose[M: ProposedAble](LLMUsage):
11
+ """A class that proposes an Obj based on a prompt."""
12
+
13
+ @overload
14
+ async def propose(
15
+ self,
16
+ cls: Type[M],
17
+ prompt: List[str],
18
+ **kwargs: Unpack[GenerateKwargs],
19
+ ) -> List[M]: ...
20
+
21
+ @overload
22
+ async def propose(
23
+ self,
24
+ cls: Type[M],
25
+ prompt: str,
26
+ **kwargs: Unpack[GenerateKwargs],
27
+ ) -> M: ...
28
+
29
+ async def propose(
30
+ self,
31
+ cls: Type[M],
32
+ prompt: List[str] | str,
33
+ **kwargs: Unpack[GenerateKwargs],
34
+ ) -> List[M] | M:
35
+ """Asynchronously proposes a task based on a given prompt and parameters.
36
+
37
+ Parameters:
38
+ cls: The class type of the task to be proposed.
39
+ prompt: The prompt text for proposing a task, which is a string that must be provided.
40
+ **kwargs: The keyword arguments for the LLM (Large Language Model) usage.
41
+
42
+ Returns:
43
+ A Task object based on the proposal result.
44
+ """
45
+ if isinstance(prompt, str):
46
+ return await self.aask_validate(
47
+ question=cls.create_json_prompt(prompt),
48
+ validator=cls.instantiate_from_string,
49
+ **kwargs,
50
+ )
51
+ return await self.aask_validate_batch(
52
+ questions=[cls.create_json_prompt(p) for p in prompt],
53
+ validator=cls.instantiate_from_string,
54
+ **kwargs,
55
+ )
@@ -1,71 +1,114 @@
1
1
  """A module for the RAG (Retrieval Augmented Generation) model."""
2
2
 
3
+ try:
4
+ from pymilvus import MilvusClient
5
+ except ImportError as e:
6
+ raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
3
7
  from functools import lru_cache
4
8
  from operator import itemgetter
5
9
  from os import PathLike
6
10
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack
11
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
8
12
 
9
- from fabricatio import template_manager
13
+ from fabricatio._rust_instances import template_manager
10
14
  from fabricatio.config import configs
11
- from fabricatio.models.kwargs_types import LLMKwargs
12
- from fabricatio.models.usages import LLMUsage
15
+ from fabricatio.journal import logger
16
+ from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
17
+ from fabricatio.models.usages import EmbeddingUsage
13
18
  from fabricatio.models.utils import MilvusData
14
19
  from more_itertools.recipes import flatten
15
-
16
- try:
17
- from pymilvus import MilvusClient
18
- except ImportError as e:
19
- raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
20
20
  from pydantic import Field, PrivateAttr
21
21
 
22
22
 
23
23
  @lru_cache(maxsize=None)
24
- def create_client(
25
- uri: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None
26
- ) -> MilvusClient:
24
+ def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
27
25
  """Create a Milvus client."""
28
26
  return MilvusClient(
29
- uri=uri or configs.rag.milvus_uri.unicode_string(),
30
- token=token or configs.rag.milvus_token.get_secret_value() if configs.rag.milvus_token else "",
31
- timeout=timeout or configs.rag.milvus_timeout,
27
+ uri=uri,
28
+ token=token,
29
+ timeout=timeout,
32
30
  )
33
31
 
34
32
 
35
- class Rag(LLMUsage):
33
+ class RAG(EmbeddingUsage):
36
34
  """A class representing the RAG (Retrieval Augmented Generation) model."""
37
35
 
38
- milvus_uri: Optional[str] = Field(default=None, frozen=True)
39
- """The URI of the Milvus server."""
40
- milvus_token: Optional[str] = Field(default=None, frozen=True)
41
- """The token for the Milvus server."""
42
- milvus_timeout: Optional[float] = Field(default=None, frozen=True)
43
- """The timeout for the Milvus server."""
44
36
  target_collection: Optional[str] = Field(default=None)
45
37
  """The name of the collection being viewed."""
46
38
 
47
- _client: MilvusClient = PrivateAttr(None)
39
+ _client: Optional[MilvusClient] = PrivateAttr(None)
48
40
  """The Milvus client used for the RAG model."""
49
41
 
50
42
  @property
51
43
  def client(self) -> MilvusClient:
52
44
  """Return the Milvus client."""
45
+ if self._client is None:
46
+ raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
53
47
  return self._client
54
48
 
55
- def model_post_init(self, __context: Any) -> None:
56
- """Initialize the RAG model by creating the collection if it does not exist."""
57
- self._client = create_client(self.milvus_uri, self.milvus_token, self.milvus_timeout)
58
- self.view(self.target_collection, create=True)
49
+ def init_client(
50
+ self,
51
+ milvus_uri: Optional[str] = None,
52
+ milvus_token: Optional[str] = None,
53
+ milvus_timeout: Optional[float] = None,
54
+ ) -> Self:
55
+ """Initialize the Milvus client."""
56
+ self._client = create_client(
57
+ uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
58
+ token=milvus_token
59
+ or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
60
+ timeout=milvus_timeout or self.milvus_timeout,
61
+ )
62
+ return self
63
+
64
+ @overload
65
+ async def pack(
66
+ self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
67
+ ) -> List[MilvusData]: ...
68
+ @overload
69
+ async def pack(
70
+ self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
71
+ ) -> MilvusData: ...
72
+
73
+ async def pack(
74
+ self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
75
+ ) -> List[MilvusData] | MilvusData:
76
+ """Asynchronously generates MilvusData objects for the given input text.
59
77
 
60
- def view(self, collection_name: Optional[str], create: bool = False) -> Self:
78
+ Args:
79
+ input_text (List[str] | str): A string or list of strings to generate embeddings for.
80
+ subject (Optional[str]): The subject of the input text. Defaults to None.
81
+ **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
82
+
83
+ Returns:
84
+ List[MilvusData] | MilvusData: The generated MilvusData objects.
85
+ """
86
+ if isinstance(input_text, str):
87
+ return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
88
+ vecs = await self.vectorize(input_text, **kwargs)
89
+ return [
90
+ MilvusData(
91
+ vector=vec,
92
+ text=text,
93
+ subject=subject,
94
+ )
95
+ for text, vec in zip(input_text, vecs, strict=True)
96
+ ]
97
+
98
+ def view(
99
+ self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
100
+ ) -> Self:
61
101
  """View the specified collection.
62
102
 
63
103
  Args:
64
104
  collection_name (str): The name of the collection.
65
105
  create (bool): Whether to create the collection if it does not exist.
106
+ **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
66
107
  """
67
108
  if create and collection_name and not self._client.has_collection(collection_name):
68
- self._client.create_collection(collection_name)
109
+ kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
110
+ self._client.create_collection(collection_name, auto_id=True, **kwargs)
111
+ logger.info(f"Creating collection {collection_name}")
69
112
 
70
113
  self.target_collection = collection_name
71
114
  return self
@@ -90,13 +133,14 @@ class Rag(LLMUsage):
90
133
  return self.target_collection
91
134
 
92
135
  def add_document[D: Union[Dict[str, Any], MilvusData]](
93
- self, data: D | List[D], collection_name: Optional[str] = None
136
+ self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
94
137
  ) -> Self:
95
138
  """Adds a document to the specified collection.
96
139
 
97
140
  Args:
98
141
  data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
99
142
  collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
143
+ flush (bool): Whether to flush the collection after insertion.
100
144
 
101
145
  Returns:
102
146
  Self: The current instance, allowing for method chaining.
@@ -105,11 +149,19 @@ class Rag(LLMUsage):
105
149
  data = data.prepare_insertion()
106
150
  if isinstance(data, list):
107
151
  data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
108
- self._client.insert(collection_name or self.safe_target_collection, data)
152
+ c_name = collection_name or self.safe_target_collection
153
+ self._client.insert(c_name, data)
154
+
155
+ if flush:
156
+ logger.debug(f"Flushing collection {c_name}")
157
+ self._client.flush(c_name)
109
158
  return self
110
159
 
111
- def consume(
112
- self, source: PathLike, reader: Callable[[PathLike], MilvusData], collection_name: Optional[str] = None
160
+ async def consume_file(
161
+ self,
162
+ source: List[PathLike] | PathLike,
163
+ reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
164
+ collection_name: Optional[str] = None,
113
165
  ) -> Self:
114
166
  """Consume a file and add its content to the collection.
115
167
 
@@ -121,8 +173,21 @@ class Rag(LLMUsage):
121
173
  Returns:
122
174
  Self: The current instance, allowing for method chaining.
123
175
  """
124
- data = reader(Path(source))
125
- self.add_document(data, collection_name or self.safe_target_collection)
176
+ if not isinstance(source, list):
177
+ source = [source]
178
+ return await self.consume_string([reader(s) for s in source], collection_name)
179
+
180
+ async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
181
+ """Consume a string and add it to the collection.
182
+
183
+ Args:
184
+ text (List[str] | str): The text to be added to the collection.
185
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
186
+
187
+ Returns:
188
+ Self: The current instance, allowing for method chaining.
189
+ """
190
+ self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
126
191
  return self
127
192
 
128
193
  async def afetch_document(
@@ -130,6 +195,7 @@ class Rag(LLMUsage):
130
195
  vecs: List[List[float]],
131
196
  desired_fields: List[str] | str,
132
197
  collection_name: Optional[str] = None,
198
+ similarity_threshold: float = 0.37,
133
199
  result_per_query: int = 10,
134
200
  ) -> List[Dict[str, Any]] | List[Any]:
135
201
  """Fetch data from the collection.
@@ -138,6 +204,7 @@ class Rag(LLMUsage):
138
204
  vecs (List[List[float]]): The vectors to search for.
139
205
  desired_fields (List[str] | str): The fields to retrieve.
140
206
  collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
207
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
141
208
  result_per_query (int): The number of results to return per query.
142
209
 
143
210
  Returns:
@@ -147,6 +214,7 @@ class Rag(LLMUsage):
147
214
  search_results = self._client.search(
148
215
  collection_name or self.safe_target_collection,
149
216
  vecs,
217
+ search_params={"radius": similarity_threshold},
150
218
  output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
151
219
  limit=result_per_query,
152
220
  )
@@ -157,6 +225,7 @@ class Rag(LLMUsage):
157
225
  # Step 3: Sort by distance (descending)
158
226
  sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
159
227
 
228
+ logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
160
229
  # Step 4: Extract the entities
161
230
  resp = [result["entity"] for result in sorted_results]
162
231
 
@@ -168,27 +237,29 @@ class Rag(LLMUsage):
168
237
  self,
169
238
  query: List[str] | str,
170
239
  collection_name: Optional[str] = None,
171
- result_per_query: int = 10,
172
240
  final_limit: int = 20,
241
+ **kwargs: Unpack[FetchKwargs],
173
242
  ) -> List[str]:
174
243
  """Retrieve data from the collection.
175
244
 
176
245
  Args:
177
246
  query (List[str] | str): The query to be used for retrieval.
178
247
  collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
179
- result_per_query (int): The number of results to be returned per query.
180
248
  final_limit (int): The final limit on the number of results to return.
249
+ **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
181
250
 
182
251
  Returns:
183
252
  List[str]: A list of strings containing the retrieved data.
184
253
  """
185
254
  if isinstance(query, str):
186
255
  query = [query]
187
- return await self.afetch_document(
188
- vecs=(await self.vectorize(query)),
189
- desired_fields="text",
190
- collection_name=collection_name,
191
- result_per_query=result_per_query,
256
+ return (
257
+ await self.afetch_document(
258
+ vecs=(await self.vectorize(query)),
259
+ desired_fields="text",
260
+ collection_name=collection_name,
261
+ **kwargs,
262
+ )
192
263
  )[:final_limit]
193
264
 
194
265
  async def aask_retrieved(
@@ -196,8 +267,10 @@ class Rag(LLMUsage):
196
267
  question: str | List[str],
197
268
  query: List[str] | str,
198
269
  collection_name: Optional[str] = None,
270
+ extra_system_message: str = "",
199
271
  result_per_query: int = 10,
200
272
  final_limit: int = 20,
273
+ similarity_threshold: float = 0.37,
201
274
  **kwargs: Unpack[LLMKwargs],
202
275
  ) -> str:
203
276
  """Asks a question by retrieving relevant documents based on the provided query.
@@ -210,16 +283,28 @@ class Rag(LLMUsage):
210
283
  query (List[str] | str): The query or list of queries used for document retrieval.
211
284
  collection_name (Optional[str]): The name of the collection to retrieve documents from.
212
285
  If not provided, the currently viewed collection is used.
286
+ extra_system_message (str): An additional system message to be included in the prompt.
213
287
  result_per_query (int): The number of results to return per query. Default is 10.
214
288
  final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
289
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
215
290
  **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
216
291
 
217
292
  Returns:
218
293
  str: A string response generated after asking with the context of retrieved documents.
219
294
  """
220
- docs = await self.aretrieve(query, collection_name, result_per_query, final_limit)
295
+ docs = await self.aretrieve(
296
+ query,
297
+ collection_name,
298
+ final_limit,
299
+ result_per_query=result_per_query,
300
+ similarity_threshold=similarity_threshold,
301
+ )
302
+
303
+ rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
304
+
305
+ logger.debug(f"Retrieved Documents: \n{rendered}")
221
306
  return await self.aask(
222
307
  question,
223
- template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs}),
308
+ f"{rendered}\n\n{extra_system_message}",
224
309
  **kwargs,
225
310
  )
@@ -131,8 +131,7 @@ class GiveRating(WithBriefing, LLMUsage):
131
131
 
132
132
  def _validator(response: str) -> Dict[str, str] | None:
133
133
  if (
134
- (json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
135
- and isinstance(json_data, dict)
134
+ (json_data := JsonCapture.validate_with(response, target_type=dict, elements_type=str)) is not None
136
135
  and json_data.keys() == criteria
137
136
  and all(isinstance(v, str) for v in json_data.values())
138
137
  ):
@@ -173,11 +172,10 @@ class GiveRating(WithBriefing, LLMUsage):
173
172
 
174
173
  def _validator(response: str) -> Set[str] | None:
175
174
  if (
176
- (json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
177
- and isinstance(json_data, list)
178
- and all(isinstance(v, str) for v in json_data)
179
- and (criteria_count == 0 or len(json_data) == criteria_count)
180
- ):
175
+ json_data := JsonCapture.validate_with(
176
+ response, target_type=list, elements_type=str, length=criteria_count
177
+ )
178
+ ) is not None:
181
179
  return set(json_data)
182
180
  return None
183
181
 
@@ -219,27 +217,6 @@ class GiveRating(WithBriefing, LLMUsage):
219
217
  Returns:
220
218
  Set[str]: A set of drafted rating criteria.
221
219
  """
222
-
223
- def _reasons_validator(response: str) -> List[str] | None:
224
- if (
225
- (json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
226
- and isinstance(json_data, list)
227
- and all(isinstance(v, str) for v in json_data)
228
- and len(json_data) == reasons_count
229
- ):
230
- return json_data
231
- return None
232
-
233
- def _criteria_validator(response: str) -> Set[str] | None:
234
- if (
235
- (json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
236
- and isinstance(json_data, list)
237
- and all(isinstance(v, str) for v in json_data)
238
- and len(json_data) == criteria_count
239
- ):
240
- return set(json_data)
241
- return None
242
-
243
220
  kwargs = GenerateKwargs(system_message=f"# your personal briefing: \n{self.briefing}", **kwargs)
244
221
  # extract reasons from the comparison of ordered pairs of extracted from examples
245
222
  reasons = flatten(
@@ -256,7 +233,9 @@ class GiveRating(WithBriefing, LLMUsage):
256
233
  )
257
234
  for pair in (permutations(examples, 2))
258
235
  ],
259
- validator=_reasons_validator,
236
+ validator=lambda resp: JsonCapture.validate_with(
237
+ resp, target_type=list, elements_type=str, length=reasons_count
238
+ ),
260
239
  **kwargs,
261
240
  )
262
241
  )
@@ -272,7 +251,9 @@ class GiveRating(WithBriefing, LLMUsage):
272
251
  },
273
252
  )
274
253
  ),
275
- validator=_criteria_validator,
254
+ validator=lambda resp: set(out)
255
+ if (out := JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=criteria_count))
256
+ else None,
276
257
  **kwargs,
277
258
  )
278
259
 
@@ -295,11 +276,6 @@ class GiveRating(WithBriefing, LLMUsage):
295
276
  if len(criteria) < 2: # noqa: PLR2004
296
277
  raise ValueError("At least two criteria are required to draft rating weights")
297
278
 
298
- def _validator(resp: str) -> float | None:
299
- if (cap := JsonCapture.convert_with(resp, orjson.loads)) is not None and isinstance(cap, float):
300
- return cap
301
- return None
302
-
303
279
  criteria = list(criteria) # freeze the order
304
280
  windows = windowed(criteria, 2)
305
281
 
@@ -316,7 +292,7 @@ class GiveRating(WithBriefing, LLMUsage):
316
292
  )
317
293
  for pair in windows
318
294
  ],
319
- validator=_validator,
295
+ validator=lambda resp: JsonCapture.validate_with(resp, target_type=float),
320
296
  **GenerateKwargs(system_message=f"# your personal briefing: \n{self.briefing}", **kwargs),
321
297
  )
322
298
  weights = [1]
@@ -5,21 +5,21 @@ from typing import Any, Dict, List, Optional, Tuple, Unpack
5
5
 
6
6
  import orjson
7
7
  from fabricatio._rust_instances import template_manager
8
+ from fabricatio.capabilities.propose import Propose
8
9
  from fabricatio.config import configs
9
10
  from fabricatio.models.generic import WithBriefing
10
11
  from fabricatio.models.kwargs_types import ChooseKwargs, ValidateKwargs
11
12
  from fabricatio.models.task import Task
12
13
  from fabricatio.models.tool import Tool, ToolExecutor
13
- from fabricatio.models.usages import LLMUsage, ToolBoxUsage
14
+ from fabricatio.models.usages import ToolBoxUsage
14
15
  from fabricatio.parser import JsonCapture, PythonCapture
15
16
  from loguru import logger
16
- from pydantic import ValidationError
17
17
 
18
18
 
19
- class ProposeTask(WithBriefing, LLMUsage):
19
+ class ProposeTask(WithBriefing, Propose):
20
20
  """A class that proposes a task based on a prompt."""
21
21
 
22
- async def propose[T](
22
+ async def propose_task[T](
23
23
  self,
24
24
  prompt: str,
25
25
  **kwargs: Unpack[ValidateKwargs],
@@ -34,27 +34,10 @@ class ProposeTask(WithBriefing, LLMUsage):
34
34
  A Task object based on the proposal result.
35
35
  """
36
36
  if not prompt:
37
- err = f"{self.name}: Prompt must be provided."
38
- logger.error(err)
37
+ logger.error(err := f"{self.name}: Prompt must be provided.")
39
38
  raise ValueError(err)
40
39
 
41
- def _validate_json(response: str) -> None | Task:
42
- try:
43
- cap = JsonCapture.capture(response)
44
- logger.debug(f"Response: \n{response}")
45
- logger.info(f"Captured JSON: \n{cap}")
46
- return Task.model_validate_json(cap)
47
- except ValidationError as e:
48
- logger.error(f"Failed to parse task from JSON: {e}")
49
- return None
50
-
51
- template_data = {"prompt": prompt, "json_example": Task.json_example()}
52
- return await self.aask_validate(
53
- question=template_manager.render_template(configs.templates.propose_task_template, template_data),
54
- validator=_validate_json,
55
- system_message=f"# your personal briefing: \n{self.briefing}",
56
- **kwargs,
57
- )
40
+ return await self.propose(Task, prompt, system_message=f"# your personal briefing: \n{self.briefing}", **kwargs)
58
41
 
59
42
 
60
43
  class HandleTask(WithBriefing, ToolBoxUsage):