fabricatio 0.2.3.dev3__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.4.dev1__cp312-cp312-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +6 -2
- fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
- fabricatio/actions/__init__.py +2 -2
- fabricatio/actions/article.py +44 -0
- fabricatio/capabilities/propose.py +55 -0
- fabricatio/capabilities/rag.py +129 -44
- fabricatio/capabilities/rating.py +12 -36
- fabricatio/capabilities/task.py +6 -23
- fabricatio/config.py +37 -2
- fabricatio/models/action.py +3 -3
- fabricatio/models/events.py +36 -0
- fabricatio/models/extra.py +96 -0
- fabricatio/models/generic.py +194 -7
- fabricatio/models/kwargs_types.py +14 -0
- fabricatio/models/task.py +5 -23
- fabricatio/models/usages.py +117 -184
- fabricatio/models/utils.py +19 -0
- fabricatio/parser.py +35 -8
- fabricatio-0.2.4.dev1.data/scripts/tdown +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/METADATA +65 -177
- fabricatio-0.2.4.dev1.dist-info/RECORD +38 -0
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio-0.2.3.dev3.data/scripts/tdown +0 -0
- fabricatio-0.2.3.dev3.dist-info/RECORD +0 -37
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -3,11 +3,13 @@
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
5
5
|
from fabricatio._rust_instances import template_manager
|
6
|
+
from fabricatio.actions import ExtractArticleEssence
|
6
7
|
from fabricatio.core import env
|
7
8
|
from fabricatio.fs import magika
|
8
9
|
from fabricatio.journal import logger
|
9
10
|
from fabricatio.models.action import Action, WorkFlow
|
10
11
|
from fabricatio.models.events import Event
|
12
|
+
from fabricatio.models.extra import ArticleEssence
|
11
13
|
from fabricatio.models.role import Role
|
12
14
|
from fabricatio.models.task import Task
|
13
15
|
from fabricatio.models.tool import ToolBox
|
@@ -17,9 +19,11 @@ from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
|
17
19
|
|
18
20
|
__all__ = [
|
19
21
|
"Action",
|
22
|
+
"ArticleEssence",
|
20
23
|
"Capture",
|
21
24
|
"CodeBlockCapture",
|
22
25
|
"Event",
|
26
|
+
"ExtractArticleEssence",
|
23
27
|
"JsonCapture",
|
24
28
|
"Message",
|
25
29
|
"Messages",
|
@@ -40,6 +44,6 @@ __all__ = [
|
|
40
44
|
|
41
45
|
|
42
46
|
if find_spec("pymilvus"):
|
43
|
-
from fabricatio.capabilities.rag import
|
47
|
+
from fabricatio.capabilities.rag import RAG
|
44
48
|
|
45
|
-
__all__ += ["
|
49
|
+
__all__ += ["RAG"]
|
Binary file
|
fabricatio/actions/__init__.py
CHANGED
@@ -0,0 +1,44 @@
|
|
1
|
+
"""Actions for transmitting tasks to targets."""
|
2
|
+
|
3
|
+
from os import PathLike
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Callable, List
|
6
|
+
|
7
|
+
from fabricatio.journal import logger
|
8
|
+
from fabricatio.models.action import Action
|
9
|
+
from fabricatio.models.extra import ArticleEssence
|
10
|
+
from fabricatio.models.task import Task
|
11
|
+
|
12
|
+
|
13
|
+
class ExtractArticleEssence(Action):
|
14
|
+
"""Extract the essence of article(s)."""
|
15
|
+
|
16
|
+
name: str = "extract article essence"
|
17
|
+
"""The name of the action."""
|
18
|
+
description: str = "Extract the essence of an article. output as json"
|
19
|
+
"""The description of the action."""
|
20
|
+
|
21
|
+
output_key: str = "article_essence"
|
22
|
+
"""The key of the output data."""
|
23
|
+
|
24
|
+
async def _execute[P: PathLike | str](
|
25
|
+
self,
|
26
|
+
task_input: Task,
|
27
|
+
reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
28
|
+
**_,
|
29
|
+
) -> List[ArticleEssence]:
|
30
|
+
if not await self.ajudge(
|
31
|
+
f"= Task\n{task_input.briefing}\n\n\n= Role\n{self.briefing}",
|
32
|
+
affirm_case="The task does not violate the role, and could be approved since the file dependencies are specified.",
|
33
|
+
deny_case="The task does violate the role, and could not be approved.",
|
34
|
+
):
|
35
|
+
logger.info(err := "Task not approved.")
|
36
|
+
raise RuntimeError(err)
|
37
|
+
|
38
|
+
# trim the references
|
39
|
+
contents = ["References".join(c.split("References")[:-1]) for c in map(reader, task_input.dependencies)]
|
40
|
+
return await self.propose(
|
41
|
+
ArticleEssence,
|
42
|
+
contents,
|
43
|
+
system_message=f"# your personal briefing: \n{self.briefing}",
|
44
|
+
)
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""A module for the task capabilities of the Fabricatio library."""
|
2
|
+
|
3
|
+
from typing import List, Type, Unpack, overload
|
4
|
+
|
5
|
+
from fabricatio.models.generic import ProposedAble
|
6
|
+
from fabricatio.models.kwargs_types import GenerateKwargs
|
7
|
+
from fabricatio.models.usages import LLMUsage
|
8
|
+
|
9
|
+
|
10
|
+
class Propose[M: ProposedAble](LLMUsage):
|
11
|
+
"""A class that proposes an Obj based on a prompt."""
|
12
|
+
|
13
|
+
@overload
|
14
|
+
async def propose(
|
15
|
+
self,
|
16
|
+
cls: Type[M],
|
17
|
+
prompt: List[str],
|
18
|
+
**kwargs: Unpack[GenerateKwargs],
|
19
|
+
) -> List[M]: ...
|
20
|
+
|
21
|
+
@overload
|
22
|
+
async def propose(
|
23
|
+
self,
|
24
|
+
cls: Type[M],
|
25
|
+
prompt: str,
|
26
|
+
**kwargs: Unpack[GenerateKwargs],
|
27
|
+
) -> M: ...
|
28
|
+
|
29
|
+
async def propose(
|
30
|
+
self,
|
31
|
+
cls: Type[M],
|
32
|
+
prompt: List[str] | str,
|
33
|
+
**kwargs: Unpack[GenerateKwargs],
|
34
|
+
) -> List[M] | M:
|
35
|
+
"""Asynchronously proposes a task based on a given prompt and parameters.
|
36
|
+
|
37
|
+
Parameters:
|
38
|
+
cls: The class type of the task to be proposed.
|
39
|
+
prompt: The prompt text for proposing a task, which is a string that must be provided.
|
40
|
+
**kwargs: The keyword arguments for the LLM (Large Language Model) usage.
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
A Task object based on the proposal result.
|
44
|
+
"""
|
45
|
+
if isinstance(prompt, str):
|
46
|
+
return await self.aask_validate(
|
47
|
+
question=cls.create_json_prompt(prompt),
|
48
|
+
validator=cls.instantiate_from_string,
|
49
|
+
**kwargs,
|
50
|
+
)
|
51
|
+
return await self.aask_validate_batch(
|
52
|
+
questions=[cls.create_json_prompt(p) for p in prompt],
|
53
|
+
validator=cls.instantiate_from_string,
|
54
|
+
**kwargs,
|
55
|
+
)
|
fabricatio/capabilities/rag.py
CHANGED
@@ -1,71 +1,114 @@
|
|
1
1
|
"""A module for the RAG (Retrieval Augmented Generation) model."""
|
2
2
|
|
3
|
+
try:
|
4
|
+
from pymilvus import MilvusClient
|
5
|
+
except ImportError as e:
|
6
|
+
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
3
7
|
from functools import lru_cache
|
4
8
|
from operator import itemgetter
|
5
9
|
from os import PathLike
|
6
10
|
from pathlib import Path
|
7
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
|
8
12
|
|
9
|
-
from fabricatio import template_manager
|
13
|
+
from fabricatio._rust_instances import template_manager
|
10
14
|
from fabricatio.config import configs
|
11
|
-
from fabricatio.
|
12
|
-
from fabricatio.models.
|
15
|
+
from fabricatio.journal import logger
|
16
|
+
from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
|
17
|
+
from fabricatio.models.usages import EmbeddingUsage
|
13
18
|
from fabricatio.models.utils import MilvusData
|
14
19
|
from more_itertools.recipes import flatten
|
15
|
-
|
16
|
-
try:
|
17
|
-
from pymilvus import MilvusClient
|
18
|
-
except ImportError as e:
|
19
|
-
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
20
20
|
from pydantic import Field, PrivateAttr
|
21
21
|
|
22
22
|
|
23
23
|
@lru_cache(maxsize=None)
|
24
|
-
def create_client(
|
25
|
-
uri: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None
|
26
|
-
) -> MilvusClient:
|
24
|
+
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
|
27
25
|
"""Create a Milvus client."""
|
28
26
|
return MilvusClient(
|
29
|
-
uri=uri
|
30
|
-
token=token
|
31
|
-
timeout=timeout
|
27
|
+
uri=uri,
|
28
|
+
token=token,
|
29
|
+
timeout=timeout,
|
32
30
|
)
|
33
31
|
|
34
32
|
|
35
|
-
class
|
33
|
+
class RAG(EmbeddingUsage):
|
36
34
|
"""A class representing the RAG (Retrieval Augmented Generation) model."""
|
37
35
|
|
38
|
-
milvus_uri: Optional[str] = Field(default=None, frozen=True)
|
39
|
-
"""The URI of the Milvus server."""
|
40
|
-
milvus_token: Optional[str] = Field(default=None, frozen=True)
|
41
|
-
"""The token for the Milvus server."""
|
42
|
-
milvus_timeout: Optional[float] = Field(default=None, frozen=True)
|
43
|
-
"""The timeout for the Milvus server."""
|
44
36
|
target_collection: Optional[str] = Field(default=None)
|
45
37
|
"""The name of the collection being viewed."""
|
46
38
|
|
47
|
-
_client: MilvusClient = PrivateAttr(None)
|
39
|
+
_client: Optional[MilvusClient] = PrivateAttr(None)
|
48
40
|
"""The Milvus client used for the RAG model."""
|
49
41
|
|
50
42
|
@property
|
51
43
|
def client(self) -> MilvusClient:
|
52
44
|
"""Return the Milvus client."""
|
45
|
+
if self._client is None:
|
46
|
+
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
53
47
|
return self._client
|
54
48
|
|
55
|
-
def
|
56
|
-
|
57
|
-
|
58
|
-
|
49
|
+
def init_client(
|
50
|
+
self,
|
51
|
+
milvus_uri: Optional[str] = None,
|
52
|
+
milvus_token: Optional[str] = None,
|
53
|
+
milvus_timeout: Optional[float] = None,
|
54
|
+
) -> Self:
|
55
|
+
"""Initialize the Milvus client."""
|
56
|
+
self._client = create_client(
|
57
|
+
uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
|
58
|
+
token=milvus_token
|
59
|
+
or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
|
60
|
+
timeout=milvus_timeout or self.milvus_timeout,
|
61
|
+
)
|
62
|
+
return self
|
63
|
+
|
64
|
+
@overload
|
65
|
+
async def pack(
|
66
|
+
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
67
|
+
) -> List[MilvusData]: ...
|
68
|
+
@overload
|
69
|
+
async def pack(
|
70
|
+
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
71
|
+
) -> MilvusData: ...
|
72
|
+
|
73
|
+
async def pack(
|
74
|
+
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
75
|
+
) -> List[MilvusData] | MilvusData:
|
76
|
+
"""Asynchronously generates MilvusData objects for the given input text.
|
59
77
|
|
60
|
-
|
78
|
+
Args:
|
79
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
80
|
+
subject (Optional[str]): The subject of the input text. Defaults to None.
|
81
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
85
|
+
"""
|
86
|
+
if isinstance(input_text, str):
|
87
|
+
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
88
|
+
vecs = await self.vectorize(input_text, **kwargs)
|
89
|
+
return [
|
90
|
+
MilvusData(
|
91
|
+
vector=vec,
|
92
|
+
text=text,
|
93
|
+
subject=subject,
|
94
|
+
)
|
95
|
+
for text, vec in zip(input_text, vecs, strict=True)
|
96
|
+
]
|
97
|
+
|
98
|
+
def view(
|
99
|
+
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
|
100
|
+
) -> Self:
|
61
101
|
"""View the specified collection.
|
62
102
|
|
63
103
|
Args:
|
64
104
|
collection_name (str): The name of the collection.
|
65
105
|
create (bool): Whether to create the collection if it does not exist.
|
106
|
+
**kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
|
66
107
|
"""
|
67
108
|
if create and collection_name and not self._client.has_collection(collection_name):
|
68
|
-
self.
|
109
|
+
kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
|
110
|
+
self._client.create_collection(collection_name, auto_id=True, **kwargs)
|
111
|
+
logger.info(f"Creating collection {collection_name}")
|
69
112
|
|
70
113
|
self.target_collection = collection_name
|
71
114
|
return self
|
@@ -90,13 +133,14 @@ class Rag(LLMUsage):
|
|
90
133
|
return self.target_collection
|
91
134
|
|
92
135
|
def add_document[D: Union[Dict[str, Any], MilvusData]](
|
93
|
-
self, data: D | List[D], collection_name: Optional[str] = None
|
136
|
+
self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
|
94
137
|
) -> Self:
|
95
138
|
"""Adds a document to the specified collection.
|
96
139
|
|
97
140
|
Args:
|
98
141
|
data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
|
99
142
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
143
|
+
flush (bool): Whether to flush the collection after insertion.
|
100
144
|
|
101
145
|
Returns:
|
102
146
|
Self: The current instance, allowing for method chaining.
|
@@ -105,11 +149,19 @@ class Rag(LLMUsage):
|
|
105
149
|
data = data.prepare_insertion()
|
106
150
|
if isinstance(data, list):
|
107
151
|
data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
|
108
|
-
|
152
|
+
c_name = collection_name or self.safe_target_collection
|
153
|
+
self._client.insert(c_name, data)
|
154
|
+
|
155
|
+
if flush:
|
156
|
+
logger.debug(f"Flushing collection {c_name}")
|
157
|
+
self._client.flush(c_name)
|
109
158
|
return self
|
110
159
|
|
111
|
-
def
|
112
|
-
self,
|
160
|
+
async def consume_file(
|
161
|
+
self,
|
162
|
+
source: List[PathLike] | PathLike,
|
163
|
+
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
164
|
+
collection_name: Optional[str] = None,
|
113
165
|
) -> Self:
|
114
166
|
"""Consume a file and add its content to the collection.
|
115
167
|
|
@@ -121,8 +173,21 @@ class Rag(LLMUsage):
|
|
121
173
|
Returns:
|
122
174
|
Self: The current instance, allowing for method chaining.
|
123
175
|
"""
|
124
|
-
|
125
|
-
|
176
|
+
if not isinstance(source, list):
|
177
|
+
source = [source]
|
178
|
+
return await self.consume_string([reader(s) for s in source], collection_name)
|
179
|
+
|
180
|
+
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
|
181
|
+
"""Consume a string and add it to the collection.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
text (List[str] | str): The text to be added to the collection.
|
185
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
Self: The current instance, allowing for method chaining.
|
189
|
+
"""
|
190
|
+
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
126
191
|
return self
|
127
192
|
|
128
193
|
async def afetch_document(
|
@@ -130,6 +195,7 @@ class Rag(LLMUsage):
|
|
130
195
|
vecs: List[List[float]],
|
131
196
|
desired_fields: List[str] | str,
|
132
197
|
collection_name: Optional[str] = None,
|
198
|
+
similarity_threshold: float = 0.37,
|
133
199
|
result_per_query: int = 10,
|
134
200
|
) -> List[Dict[str, Any]] | List[Any]:
|
135
201
|
"""Fetch data from the collection.
|
@@ -138,6 +204,7 @@ class Rag(LLMUsage):
|
|
138
204
|
vecs (List[List[float]]): The vectors to search for.
|
139
205
|
desired_fields (List[str] | str): The fields to retrieve.
|
140
206
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
207
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
141
208
|
result_per_query (int): The number of results to return per query.
|
142
209
|
|
143
210
|
Returns:
|
@@ -147,6 +214,7 @@ class Rag(LLMUsage):
|
|
147
214
|
search_results = self._client.search(
|
148
215
|
collection_name or self.safe_target_collection,
|
149
216
|
vecs,
|
217
|
+
search_params={"radius": similarity_threshold},
|
150
218
|
output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
|
151
219
|
limit=result_per_query,
|
152
220
|
)
|
@@ -157,6 +225,7 @@ class Rag(LLMUsage):
|
|
157
225
|
# Step 3: Sort by distance (descending)
|
158
226
|
sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
|
159
227
|
|
228
|
+
logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
|
160
229
|
# Step 4: Extract the entities
|
161
230
|
resp = [result["entity"] for result in sorted_results]
|
162
231
|
|
@@ -168,27 +237,29 @@ class Rag(LLMUsage):
|
|
168
237
|
self,
|
169
238
|
query: List[str] | str,
|
170
239
|
collection_name: Optional[str] = None,
|
171
|
-
result_per_query: int = 10,
|
172
240
|
final_limit: int = 20,
|
241
|
+
**kwargs: Unpack[FetchKwargs],
|
173
242
|
) -> List[str]:
|
174
243
|
"""Retrieve data from the collection.
|
175
244
|
|
176
245
|
Args:
|
177
246
|
query (List[str] | str): The query to be used for retrieval.
|
178
247
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
179
|
-
result_per_query (int): The number of results to be returned per query.
|
180
248
|
final_limit (int): The final limit on the number of results to return.
|
249
|
+
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
181
250
|
|
182
251
|
Returns:
|
183
252
|
List[str]: A list of strings containing the retrieved data.
|
184
253
|
"""
|
185
254
|
if isinstance(query, str):
|
186
255
|
query = [query]
|
187
|
-
return
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
256
|
+
return (
|
257
|
+
await self.afetch_document(
|
258
|
+
vecs=(await self.vectorize(query)),
|
259
|
+
desired_fields="text",
|
260
|
+
collection_name=collection_name,
|
261
|
+
**kwargs,
|
262
|
+
)
|
192
263
|
)[:final_limit]
|
193
264
|
|
194
265
|
async def aask_retrieved(
|
@@ -196,8 +267,10 @@ class Rag(LLMUsage):
|
|
196
267
|
question: str | List[str],
|
197
268
|
query: List[str] | str,
|
198
269
|
collection_name: Optional[str] = None,
|
270
|
+
extra_system_message: str = "",
|
199
271
|
result_per_query: int = 10,
|
200
272
|
final_limit: int = 20,
|
273
|
+
similarity_threshold: float = 0.37,
|
201
274
|
**kwargs: Unpack[LLMKwargs],
|
202
275
|
) -> str:
|
203
276
|
"""Asks a question by retrieving relevant documents based on the provided query.
|
@@ -210,16 +283,28 @@ class Rag(LLMUsage):
|
|
210
283
|
query (List[str] | str): The query or list of queries used for document retrieval.
|
211
284
|
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
212
285
|
If not provided, the currently viewed collection is used.
|
286
|
+
extra_system_message (str): An additional system message to be included in the prompt.
|
213
287
|
result_per_query (int): The number of results to return per query. Default is 10.
|
214
288
|
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
289
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
215
290
|
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
216
291
|
|
217
292
|
Returns:
|
218
293
|
str: A string response generated after asking with the context of retrieved documents.
|
219
294
|
"""
|
220
|
-
docs = await self.aretrieve(
|
295
|
+
docs = await self.aretrieve(
|
296
|
+
query,
|
297
|
+
collection_name,
|
298
|
+
final_limit,
|
299
|
+
result_per_query=result_per_query,
|
300
|
+
similarity_threshold=similarity_threshold,
|
301
|
+
)
|
302
|
+
|
303
|
+
rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
|
304
|
+
|
305
|
+
logger.debug(f"Retrieved Documents: \n{rendered}")
|
221
306
|
return await self.aask(
|
222
307
|
question,
|
223
|
-
|
308
|
+
f"{rendered}\n\n{extra_system_message}",
|
224
309
|
**kwargs,
|
225
310
|
)
|
@@ -131,8 +131,7 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
131
131
|
|
132
132
|
def _validator(response: str) -> Dict[str, str] | None:
|
133
133
|
if (
|
134
|
-
(json_data := JsonCapture.
|
135
|
-
and isinstance(json_data, dict)
|
134
|
+
(json_data := JsonCapture.validate_with(response, target_type=dict, elements_type=str)) is not None
|
136
135
|
and json_data.keys() == criteria
|
137
136
|
and all(isinstance(v, str) for v in json_data.values())
|
138
137
|
):
|
@@ -173,11 +172,10 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
173
172
|
|
174
173
|
def _validator(response: str) -> Set[str] | None:
|
175
174
|
if (
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
):
|
175
|
+
json_data := JsonCapture.validate_with(
|
176
|
+
response, target_type=list, elements_type=str, length=criteria_count
|
177
|
+
)
|
178
|
+
) is not None:
|
181
179
|
return set(json_data)
|
182
180
|
return None
|
183
181
|
|
@@ -219,27 +217,6 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
219
217
|
Returns:
|
220
218
|
Set[str]: A set of drafted rating criteria.
|
221
219
|
"""
|
222
|
-
|
223
|
-
def _reasons_validator(response: str) -> List[str] | None:
|
224
|
-
if (
|
225
|
-
(json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
|
226
|
-
and isinstance(json_data, list)
|
227
|
-
and all(isinstance(v, str) for v in json_data)
|
228
|
-
and len(json_data) == reasons_count
|
229
|
-
):
|
230
|
-
return json_data
|
231
|
-
return None
|
232
|
-
|
233
|
-
def _criteria_validator(response: str) -> Set[str] | None:
|
234
|
-
if (
|
235
|
-
(json_data := JsonCapture.convert_with(response, orjson.loads)) is not None
|
236
|
-
and isinstance(json_data, list)
|
237
|
-
and all(isinstance(v, str) for v in json_data)
|
238
|
-
and len(json_data) == criteria_count
|
239
|
-
):
|
240
|
-
return set(json_data)
|
241
|
-
return None
|
242
|
-
|
243
220
|
kwargs = GenerateKwargs(system_message=f"# your personal briefing: \n{self.briefing}", **kwargs)
|
244
221
|
# extract reasons from the comparison of ordered pairs of extracted from examples
|
245
222
|
reasons = flatten(
|
@@ -256,7 +233,9 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
256
233
|
)
|
257
234
|
for pair in (permutations(examples, 2))
|
258
235
|
],
|
259
|
-
validator=
|
236
|
+
validator=lambda resp: JsonCapture.validate_with(
|
237
|
+
resp, target_type=list, elements_type=str, length=reasons_count
|
238
|
+
),
|
260
239
|
**kwargs,
|
261
240
|
)
|
262
241
|
)
|
@@ -272,7 +251,9 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
272
251
|
},
|
273
252
|
)
|
274
253
|
),
|
275
|
-
validator=
|
254
|
+
validator=lambda resp: set(out)
|
255
|
+
if (out := JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=criteria_count))
|
256
|
+
else None,
|
276
257
|
**kwargs,
|
277
258
|
)
|
278
259
|
|
@@ -295,11 +276,6 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
295
276
|
if len(criteria) < 2: # noqa: PLR2004
|
296
277
|
raise ValueError("At least two criteria are required to draft rating weights")
|
297
278
|
|
298
|
-
def _validator(resp: str) -> float | None:
|
299
|
-
if (cap := JsonCapture.convert_with(resp, orjson.loads)) is not None and isinstance(cap, float):
|
300
|
-
return cap
|
301
|
-
return None
|
302
|
-
|
303
279
|
criteria = list(criteria) # freeze the order
|
304
280
|
windows = windowed(criteria, 2)
|
305
281
|
|
@@ -316,7 +292,7 @@ class GiveRating(WithBriefing, LLMUsage):
|
|
316
292
|
)
|
317
293
|
for pair in windows
|
318
294
|
],
|
319
|
-
validator=
|
295
|
+
validator=lambda resp: JsonCapture.validate_with(resp, target_type=float),
|
320
296
|
**GenerateKwargs(system_message=f"# your personal briefing: \n{self.briefing}", **kwargs),
|
321
297
|
)
|
322
298
|
weights = [1]
|
fabricatio/capabilities/task.py
CHANGED
@@ -5,21 +5,21 @@ from typing import Any, Dict, List, Optional, Tuple, Unpack
|
|
5
5
|
|
6
6
|
import orjson
|
7
7
|
from fabricatio._rust_instances import template_manager
|
8
|
+
from fabricatio.capabilities.propose import Propose
|
8
9
|
from fabricatio.config import configs
|
9
10
|
from fabricatio.models.generic import WithBriefing
|
10
11
|
from fabricatio.models.kwargs_types import ChooseKwargs, ValidateKwargs
|
11
12
|
from fabricatio.models.task import Task
|
12
13
|
from fabricatio.models.tool import Tool, ToolExecutor
|
13
|
-
from fabricatio.models.usages import
|
14
|
+
from fabricatio.models.usages import ToolBoxUsage
|
14
15
|
from fabricatio.parser import JsonCapture, PythonCapture
|
15
16
|
from loguru import logger
|
16
|
-
from pydantic import ValidationError
|
17
17
|
|
18
18
|
|
19
|
-
class ProposeTask(WithBriefing,
|
19
|
+
class ProposeTask(WithBriefing, Propose):
|
20
20
|
"""A class that proposes a task based on a prompt."""
|
21
21
|
|
22
|
-
async def
|
22
|
+
async def propose_task[T](
|
23
23
|
self,
|
24
24
|
prompt: str,
|
25
25
|
**kwargs: Unpack[ValidateKwargs],
|
@@ -34,27 +34,10 @@ class ProposeTask(WithBriefing, LLMUsage):
|
|
34
34
|
A Task object based on the proposal result.
|
35
35
|
"""
|
36
36
|
if not prompt:
|
37
|
-
err
|
38
|
-
logger.error(err)
|
37
|
+
logger.error(err := f"{self.name}: Prompt must be provided.")
|
39
38
|
raise ValueError(err)
|
40
39
|
|
41
|
-
|
42
|
-
try:
|
43
|
-
cap = JsonCapture.capture(response)
|
44
|
-
logger.debug(f"Response: \n{response}")
|
45
|
-
logger.info(f"Captured JSON: \n{cap}")
|
46
|
-
return Task.model_validate_json(cap)
|
47
|
-
except ValidationError as e:
|
48
|
-
logger.error(f"Failed to parse task from JSON: {e}")
|
49
|
-
return None
|
50
|
-
|
51
|
-
template_data = {"prompt": prompt, "json_example": Task.json_example()}
|
52
|
-
return await self.aask_validate(
|
53
|
-
question=template_manager.render_template(configs.templates.propose_task_template, template_data),
|
54
|
-
validator=_validate_json,
|
55
|
-
system_message=f"# your personal briefing: \n{self.briefing}",
|
56
|
-
**kwargs,
|
57
|
-
)
|
40
|
+
return await self.propose(Task, prompt, system_message=f"# your personal briefing: \n{self.briefing}", **kwargs)
|
58
41
|
|
59
42
|
|
60
43
|
class HandleTask(WithBriefing, ToolBoxUsage):
|