fabricatio 0.2.3.dev2__cp312-cp312-win_amd64.whl → 0.2.4.dev0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +10 -0
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/actions/__init__.py +2 -2
- fabricatio/actions/article.py +127 -0
- fabricatio/capabilities/propose.py +55 -0
- fabricatio/capabilities/rag.py +181 -50
- fabricatio/capabilities/task.py +6 -23
- fabricatio/config.py +40 -2
- fabricatio/models/action.py +1 -0
- fabricatio/models/events.py +36 -0
- fabricatio/models/generic.py +158 -7
- fabricatio/models/kwargs_types.py +14 -0
- fabricatio/models/task.py +12 -30
- fabricatio/models/usages.py +103 -162
- fabricatio/models/utils.py +19 -0
- fabricatio/parser.py +34 -3
- fabricatio-0.2.4.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dev0.dist-info}/METADATA +40 -148
- fabricatio-0.2.4.dev0.dist-info/RECORD +37 -0
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio-0.2.3.dev2.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.3.dev2.dist-info/RECORD +0 -37
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
"""Fabricatio is a Python library for building llm app using event-based agent structure."""
|
2
2
|
|
3
|
+
from importlib.util import find_spec
|
4
|
+
|
3
5
|
from fabricatio._rust_instances import template_manager
|
6
|
+
from fabricatio.actions import ExtractArticleEssence
|
4
7
|
from fabricatio.core import env
|
5
8
|
from fabricatio.fs import magika
|
6
9
|
from fabricatio.journal import logger
|
@@ -18,6 +21,7 @@ __all__ = [
|
|
18
21
|
"Capture",
|
19
22
|
"CodeBlockCapture",
|
20
23
|
"Event",
|
24
|
+
"ExtractArticleEssence",
|
21
25
|
"JsonCapture",
|
22
26
|
"Message",
|
23
27
|
"Messages",
|
@@ -35,3 +39,9 @@ __all__ = [
|
|
35
39
|
"task_toolbox",
|
36
40
|
"template_manager",
|
37
41
|
]
|
42
|
+
|
43
|
+
|
44
|
+
if find_spec("pymilvus"):
|
45
|
+
from fabricatio.capabilities.rag import RAG
|
46
|
+
|
47
|
+
__all__ += ["RAG"]
|
Binary file
|
fabricatio/actions/__init__.py
CHANGED
fabricatio/actions/article.py
ADDED
@@ -0,0 +1,127 @@
|
|
1
|
+
"""Actions for extracting the essence of academic articles."""
|
2
|
+
|
3
|
+
from os import PathLike
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Callable, List
|
6
|
+
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from pydantic.config import ConfigDict
|
9
|
+
|
10
|
+
from fabricatio.journal import logger
|
11
|
+
from fabricatio.models.action import Action
|
12
|
+
from fabricatio.models.generic import ProposedAble
|
13
|
+
from fabricatio.models.task import Task
|
14
|
+
|
15
|
+
|
16
|
+
class Equation(BaseModel):
    """Structured representation of mathematical equations (including their physical or conceptual meanings)."""

    # With `use_attribute_docstrings=True`, pydantic takes the bare string
    # placed right under each field as that field's description, so those
    # strings are part of the model and are kept verbatim.
    model_config = ConfigDict(use_attribute_docstrings=True)

    # Both fields are required: an annotation with no default is the same
    # declaration as `Field(...)`.
    description: str
    """A concise explanation of the equation's meaning, purpose, and relevance in the context of the research."""

    latex_code: str
    """The LaTeX code used to represent the equation in a publication-ready format."""
|
26
|
+
|
27
|
+
|
28
|
+
class Figure(BaseModel):
    """Structured representation of figures (including their academic significance and explanatory captions)."""

    # With `use_attribute_docstrings=True`, pydantic takes the bare string
    # placed right under each field as that field's description, so those
    # strings are part of the model and are kept verbatim.
    model_config = ConfigDict(use_attribute_docstrings=True)

    # Both fields are required: an annotation with no default is the same
    # declaration as `Field(...)`.
    description: str
    """A detailed explanation of the figure's content and its role in conveying key insights."""

    figure_caption: str
    """The caption accompanying the figure, summarizing its main points and academic value."""
|
38
|
+
|
39
|
+
|
40
|
+
class ArticleEssence(ProposedAble):
    """Structured representation of the core elements of an academic paper(providing a comprehensive digital profile of the paper's essential information)."""

    # NOTE(review): the bare strings under each field follow the attribute-docstring
    # convention used by the sibling models in this file; whether pydantic picks them
    # up as field descriptions here depends on the config inherited from
    # `ProposedAble` - confirm before rewording any of them.

    # Basic Metadata
    title: str = Field(...)
    """The full title of the paper, including any subtitles if applicable."""

    authors: List[str] = Field(default_factory=list)
    """A list of the paper's authors, typically in the order of contribution."""

    keywords: List[str] = Field(default_factory=list)
    """A list of keywords that summarize the paper's focus and facilitate indexing."""

    # The default is None, so the annotation must admit None as well: with a plain
    # `int` an explicit `null` in the input is rejected, and an instance that kept
    # the default cannot be dumped and validated again.
    publication_year: int | None = Field(None)
    """The year in which the paper was published."""

    # Core Content Elements
    domain: List[str] = Field(default_factory=list)
    """The research domains or fields addressed by the paper (e.g., ['Natural Language Processing', 'Computer Vision'])."""

    abstract: str = Field(...)
    """A structured abstract that outlines the research problem, methodology, and conclusions in three distinct sections."""

    core_contributions: List[str] = Field(default_factory=list)
    """Key academic contributions that distinguish the paper from prior work in the field."""

    technical_novelty: List[str] = Field(default_factory=list)
    """Specific technical innovations introduced by the research, listed as individual points."""

    # Academic Achievements Showcase
    highlighted_equations: List[Equation] = Field(default_factory=list)
    """Core mathematical equations that represent breakthroughs in the field, accompanied by explanations of their physical or conceptual significance."""

    highlighted_algorithms: List[str] = Field(default_factory=list)
    """Pseudocode for key algorithms, annotated to highlight innovative components."""

    highlighted_figures: List[Figure] = Field(default_factory=list)
    """Critical diagrams or illustrations, each accompanied by a caption explaining their academic importance."""

    highlighted_tables: List[str] = Field(default_factory=list)
    """Important data tables, annotated to indicate statistical significance or other notable findings."""

    # Academic Discussion Dimensions
    research_problem: str = Field("")
    """A clearly defined research question or problem addressed by the study."""

    limitations: List[str] = Field(default_factory=list)
    """An analysis of the methodological or experimental limitations of the research."""

    future_work: List[str] = Field(default_factory=list)
    """Suggestions for potential directions or topics for follow-up studies."""

    impact_analysis: str = Field("")
    """An assessment of the paper's potential influence on the development of the field."""
|
94
|
+
|
95
|
+
|
96
|
+
class ExtractArticleEssence(Action):
|
97
|
+
"""Extract the essence of article(s)."""
|
98
|
+
|
99
|
+
name: str = "extract article essence"
|
100
|
+
"""The name of the action."""
|
101
|
+
description: str = "Extract the essence of an article. output as json"
|
102
|
+
"""The description of the action."""
|
103
|
+
|
104
|
+
output_key: str = "article_essence"
|
105
|
+
"""The key of the output data."""
|
106
|
+
|
107
|
+
async def _execute[P: PathLike | str](
|
108
|
+
self,
|
109
|
+
task_input: Task,
|
110
|
+
reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
111
|
+
**_,
|
112
|
+
) -> List[ArticleEssence]:
|
113
|
+
if not await self.ajudge(
|
114
|
+
f"= Task\n{task_input.briefing}\n\n\n= Role\n{self.briefing}",
|
115
|
+
affirm_case="The task does not violate the role, and could be approved since the file dependencies are specified.",
|
116
|
+
deny_case="The task does violate the role, and could not be approved.",
|
117
|
+
):
|
118
|
+
logger.info(err := "Task not approved.")
|
119
|
+
raise RuntimeError(err)
|
120
|
+
|
121
|
+
# trim the references
|
122
|
+
contents = ["References".join(c.split("References")[:-1]) for c in map(reader, task_input.dependencies)]
|
123
|
+
return await self.propose(
|
124
|
+
ArticleEssence,
|
125
|
+
contents,
|
126
|
+
system_message=f"# your personal briefing: \n{self.briefing}",
|
127
|
+
)
|
fabricatio/capabilities/propose.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
"""A module for the propose capability of the Fabricatio library."""
|
2
|
+
|
3
|
+
from typing import List, Type, Unpack, overload
|
4
|
+
|
5
|
+
from fabricatio.models.generic import ProposedAble
|
6
|
+
from fabricatio.models.kwargs_types import GenerateKwargs
|
7
|
+
from fabricatio.models.usages import LLMUsage
|
8
|
+
|
9
|
+
|
10
|
+
class Propose[M: ProposedAble](LLMUsage):
|
11
|
+
"""A class that proposes an Obj based on a prompt."""
|
12
|
+
|
13
|
+
@overload
|
14
|
+
async def propose(
|
15
|
+
self,
|
16
|
+
cls: Type[M],
|
17
|
+
prompt: List[str],
|
18
|
+
**kwargs: Unpack[GenerateKwargs],
|
19
|
+
) -> List[M]: ...
|
20
|
+
|
21
|
+
@overload
|
22
|
+
async def propose(
|
23
|
+
self,
|
24
|
+
cls: Type[M],
|
25
|
+
prompt: str,
|
26
|
+
**kwargs: Unpack[GenerateKwargs],
|
27
|
+
) -> M: ...
|
28
|
+
|
29
|
+
async def propose(
|
30
|
+
self,
|
31
|
+
cls: Type[M],
|
32
|
+
prompt: List[str] | str,
|
33
|
+
**kwargs: Unpack[GenerateKwargs],
|
34
|
+
) -> List[M] | M:
|
35
|
+
"""Asynchronously proposes a task based on a given prompt and parameters.
|
36
|
+
|
37
|
+
Parameters:
|
38
|
+
cls: The class type of the task to be proposed.
|
39
|
+
prompt: The prompt text for proposing a task, which is a string that must be provided.
|
40
|
+
**kwargs: The keyword arguments for the LLM (Large Language Model) usage.
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
A Task object based on the proposal result.
|
44
|
+
"""
|
45
|
+
if isinstance(prompt, str):
|
46
|
+
return await self.aask_validate(
|
47
|
+
question=cls.create_json_prompt(prompt),
|
48
|
+
validator=cls.instantiate_from_string,
|
49
|
+
**kwargs,
|
50
|
+
)
|
51
|
+
return await self.aask_validate_batch(
|
52
|
+
questions=[cls.create_json_prompt(p) for p in prompt],
|
53
|
+
validator=cls.instantiate_from_string,
|
54
|
+
**kwargs,
|
55
|
+
)
|
fabricatio/capabilities/rag.py
CHANGED
@@ -1,89 +1,146 @@
|
|
1
1
|
"""A module for the RAG (Retrieval Augmented Generation) model."""
|
2
2
|
|
3
|
+
try:
|
4
|
+
from pymilvus import MilvusClient
|
5
|
+
except ImportError as e:
|
6
|
+
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
7
|
+
from functools import lru_cache
|
3
8
|
from operator import itemgetter
|
4
9
|
from os import PathLike
|
5
10
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
|
7
12
|
|
13
|
+
from fabricatio._rust_instances import template_manager
|
8
14
|
from fabricatio.config import configs
|
9
|
-
from fabricatio.
|
15
|
+
from fabricatio.journal import logger
|
16
|
+
from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
|
17
|
+
from fabricatio.models.usages import EmbeddingUsage
|
10
18
|
from fabricatio.models.utils import MilvusData
|
11
19
|
from more_itertools.recipes import flatten
|
20
|
+
from pydantic import Field, PrivateAttr
|
12
21
|
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
22
|
+
|
23
|
+
@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Build a `MilvusClient` from the given connection settings.

    The function is wrapped in an unbounded `lru_cache`, so a call repeated with
    identical arguments gets back the client created the first time.

    Args:
        uri: Passed to `MilvusClient` as `uri`.
        token: Passed to `MilvusClient` as `token`. Defaults to an empty string.
        timeout: Passed to `MilvusClient` as `timeout`. Defaults to None.

    Returns:
        The constructed (or cached) `MilvusClient`.
    """
    connection_options = {"uri": uri, "token": token, "timeout": timeout}
    return MilvusClient(**connection_options)
|
18
31
|
|
19
32
|
|
20
|
-
class
|
33
|
+
class RAG(EmbeddingUsage):
|
21
34
|
"""A class representing the RAG (Retrieval Augmented Generation) model."""
|
22
35
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
),
|
29
|
-
)
|
30
|
-
_target_collection: Optional[str] = PrivateAttr(default=None)
|
36
|
+
target_collection: Optional[str] = Field(default=None)
|
37
|
+
"""The name of the collection being viewed."""
|
38
|
+
|
39
|
+
_client: Optional[MilvusClient] = PrivateAttr(None)
|
40
|
+
"""The Milvus client used for the RAG model."""
|
31
41
|
|
32
42
|
@property
|
33
43
|
def client(self) -> MilvusClient:
|
34
|
-
"""
|
44
|
+
"""Return the Milvus client."""
|
45
|
+
if self._client is None:
|
46
|
+
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
35
47
|
return self._client
|
36
48
|
|
37
|
-
def
|
49
|
+
def init_client(
|
50
|
+
self,
|
51
|
+
milvus_uri: Optional[str] = None,
|
52
|
+
milvus_token: Optional[str] = None,
|
53
|
+
milvus_timeout: Optional[float] = None,
|
54
|
+
) -> Self:
|
55
|
+
"""Initialize the Milvus client."""
|
56
|
+
self._client = create_client(
|
57
|
+
uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
|
58
|
+
token=milvus_token
|
59
|
+
or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
|
60
|
+
timeout=milvus_timeout or self.milvus_timeout,
|
61
|
+
)
|
62
|
+
return self
|
63
|
+
|
64
|
+
@overload
|
65
|
+
async def pack(
|
66
|
+
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
67
|
+
) -> List[MilvusData]: ...
|
68
|
+
@overload
|
69
|
+
async def pack(
|
70
|
+
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
71
|
+
) -> MilvusData: ...
|
72
|
+
|
73
|
+
async def pack(
|
74
|
+
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
75
|
+
) -> List[MilvusData] | MilvusData:
|
76
|
+
"""Asynchronously generates MilvusData objects for the given input text.
|
77
|
+
|
78
|
+
Args:
|
79
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
80
|
+
subject (Optional[str]): The subject of the input text. Defaults to None.
|
81
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
85
|
+
"""
|
86
|
+
if isinstance(input_text, str):
|
87
|
+
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
88
|
+
vecs = await self.vectorize(input_text, **kwargs)
|
89
|
+
return [
|
90
|
+
MilvusData(
|
91
|
+
vector=vec,
|
92
|
+
text=text,
|
93
|
+
subject=subject,
|
94
|
+
)
|
95
|
+
for text, vec in zip(input_text, vecs, strict=True)
|
96
|
+
]
|
97
|
+
|
98
|
+
def view(
|
99
|
+
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
|
100
|
+
) -> Self:
|
38
101
|
"""View the specified collection.
|
39
102
|
|
40
103
|
Args:
|
41
104
|
collection_name (str): The name of the collection.
|
42
105
|
create (bool): Whether to create the collection if it does not exist.
|
106
|
+
**kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
|
43
107
|
"""
|
44
|
-
if create and self._client.has_collection(collection_name):
|
45
|
-
self.
|
108
|
+
if create and collection_name and not self._client.has_collection(collection_name):
|
109
|
+
kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
|
110
|
+
self._client.create_collection(collection_name, auto_id=True, **kwargs)
|
111
|
+
logger.info(f"Creating collection {collection_name}")
|
46
112
|
|
47
|
-
self.
|
113
|
+
self.target_collection = collection_name
|
48
114
|
return self
|
49
115
|
|
50
|
-
def
|
116
|
+
def quit_viewing(self) -> Self:
|
51
117
|
"""Quit the current view.
|
52
118
|
|
53
119
|
Returns:
|
54
120
|
Self: The current instance, allowing for method chaining.
|
55
121
|
"""
|
56
|
-
self.
|
57
|
-
return self
|
58
|
-
|
59
|
-
@property
|
60
|
-
def viewing_collection(self) -> Optional[str]:
|
61
|
-
"""Get the name of the collection being viewed.
|
62
|
-
|
63
|
-
Returns:
|
64
|
-
Optional[str]: The name of the collection being viewed.
|
65
|
-
"""
|
66
|
-
return self._target_collection
|
122
|
+
return self.view(None)
|
67
123
|
|
68
124
|
@property
|
69
|
-
def
|
125
|
+
def safe_target_collection(self) -> str:
|
70
126
|
"""Get the name of the collection being viewed, raise an error if not viewing any collection.
|
71
127
|
|
72
128
|
Returns:
|
73
129
|
str: The name of the collection being viewed.
|
74
130
|
"""
|
75
|
-
if self.
|
131
|
+
if self.target_collection is None:
|
76
132
|
raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
|
77
|
-
return self.
|
133
|
+
return self.target_collection
|
78
134
|
|
79
135
|
def add_document[D: Union[Dict[str, Any], MilvusData]](
|
80
|
-
self, data: D | List[D], collection_name: Optional[str] = None
|
136
|
+
self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
|
81
137
|
) -> Self:
|
82
138
|
"""Adds a document to the specified collection.
|
83
139
|
|
84
140
|
Args:
|
85
141
|
data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
|
86
142
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
143
|
+
flush (bool): Whether to flush the collection after insertion.
|
87
144
|
|
88
145
|
Returns:
|
89
146
|
Self: The current instance, allowing for method chaining.
|
@@ -92,11 +149,19 @@ class Rag(LLMUsage):
|
|
92
149
|
data = data.prepare_insertion()
|
93
150
|
if isinstance(data, list):
|
94
151
|
data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
|
95
|
-
|
152
|
+
c_name = collection_name or self.safe_target_collection
|
153
|
+
self._client.insert(c_name, data)
|
154
|
+
|
155
|
+
if flush:
|
156
|
+
logger.debug(f"Flushing collection {c_name}")
|
157
|
+
self._client.flush(c_name)
|
96
158
|
return self
|
97
159
|
|
98
|
-
def
|
99
|
-
self,
|
160
|
+
async def consume_file(
|
161
|
+
self,
|
162
|
+
source: List[PathLike] | PathLike,
|
163
|
+
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
164
|
+
collection_name: Optional[str] = None,
|
100
165
|
) -> Self:
|
101
166
|
"""Consume a file and add its content to the collection.
|
102
167
|
|
@@ -108,8 +173,21 @@ class Rag(LLMUsage):
|
|
108
173
|
Returns:
|
109
174
|
Self: The current instance, allowing for method chaining.
|
110
175
|
"""
|
111
|
-
|
112
|
-
|
176
|
+
if not isinstance(source, list):
|
177
|
+
source = [source]
|
178
|
+
return await self.consume_string([reader(s) for s in source], collection_name)
|
179
|
+
|
180
|
+
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
|
181
|
+
"""Consume a string and add it to the collection.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
text (List[str] | str): The text to be added to the collection.
|
185
|
+
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
Self: The current instance, allowing for method chaining.
|
189
|
+
"""
|
190
|
+
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
113
191
|
return self
|
114
192
|
|
115
193
|
async def afetch_document(
|
@@ -117,6 +195,7 @@ class Rag(LLMUsage):
|
|
117
195
|
vecs: List[List[float]],
|
118
196
|
desired_fields: List[str] | str,
|
119
197
|
collection_name: Optional[str] = None,
|
198
|
+
similarity_threshold: float = 0.37,
|
120
199
|
result_per_query: int = 10,
|
121
200
|
) -> List[Dict[str, Any]] | List[Any]:
|
122
201
|
"""Fetch data from the collection.
|
@@ -125,6 +204,7 @@ class Rag(LLMUsage):
|
|
125
204
|
vecs (List[List[float]]): The vectors to search for.
|
126
205
|
desired_fields (List[str] | str): The fields to retrieve.
|
127
206
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
207
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
128
208
|
result_per_query (int): The number of results to return per query.
|
129
209
|
|
130
210
|
Returns:
|
@@ -132,8 +212,9 @@ class Rag(LLMUsage):
|
|
132
212
|
"""
|
133
213
|
# Step 1: Search for vectors
|
134
214
|
search_results = self._client.search(
|
135
|
-
collection_name or self.
|
215
|
+
collection_name or self.safe_target_collection,
|
136
216
|
vecs,
|
217
|
+
search_params={"radius": similarity_threshold},
|
137
218
|
output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
|
138
219
|
limit=result_per_query,
|
139
220
|
)
|
@@ -144,6 +225,7 @@ class Rag(LLMUsage):
|
|
144
225
|
# Step 3: Sort by distance (descending)
|
145
226
|
sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
|
146
227
|
|
228
|
+
logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
|
147
229
|
# Step 4: Extract the entities
|
148
230
|
resp = [result["entity"] for result in sorted_results]
|
149
231
|
|
@@ -155,25 +237,74 @@ class Rag(LLMUsage):
|
|
155
237
|
self,
|
156
238
|
query: List[str] | str,
|
157
239
|
collection_name: Optional[str] = None,
|
158
|
-
result_per_query: int = 10,
|
159
240
|
final_limit: int = 20,
|
241
|
+
**kwargs: Unpack[FetchKwargs],
|
160
242
|
) -> List[str]:
|
161
243
|
"""Retrieve data from the collection.
|
162
244
|
|
163
245
|
Args:
|
164
246
|
query (List[str] | str): The query to be used for retrieval.
|
165
247
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
166
|
-
result_per_query (int): The number of results to be returned per query.
|
167
248
|
final_limit (int): The final limit on the number of results to return.
|
249
|
+
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
168
250
|
|
169
251
|
Returns:
|
170
252
|
List[str]: A list of strings containing the retrieved data.
|
171
253
|
"""
|
172
254
|
if isinstance(query, str):
|
173
255
|
query = [query]
|
174
|
-
return
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
256
|
+
return (
|
257
|
+
await self.afetch_document(
|
258
|
+
vecs=(await self.vectorize(query)),
|
259
|
+
desired_fields="text",
|
260
|
+
collection_name=collection_name,
|
261
|
+
**kwargs,
|
262
|
+
)
|
179
263
|
)[:final_limit]
|
264
|
+
|
265
|
+
async def aask_retrieved(
|
266
|
+
self,
|
267
|
+
question: str | List[str],
|
268
|
+
query: List[str] | str,
|
269
|
+
collection_name: Optional[str] = None,
|
270
|
+
extra_system_message: str = "",
|
271
|
+
result_per_query: int = 10,
|
272
|
+
final_limit: int = 20,
|
273
|
+
similarity_threshold: float = 0.37,
|
274
|
+
**kwargs: Unpack[LLMKwargs],
|
275
|
+
) -> str:
|
276
|
+
"""Asks a question by retrieving relevant documents based on the provided query.
|
277
|
+
|
278
|
+
This method performs document retrieval using the given query, then asks the
|
279
|
+
specified question using the retrieved documents as context.
|
280
|
+
|
281
|
+
Args:
|
282
|
+
question (str | List[str]): The question or list of questions to be asked.
|
283
|
+
query (List[str] | str): The query or list of queries used for document retrieval.
|
284
|
+
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
285
|
+
If not provided, the currently viewed collection is used.
|
286
|
+
extra_system_message (str): An additional system message to be included in the prompt.
|
287
|
+
result_per_query (int): The number of results to return per query. Default is 10.
|
288
|
+
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
289
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
290
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
291
|
+
|
292
|
+
Returns:
|
293
|
+
str: A string response generated after asking with the context of retrieved documents.
|
294
|
+
"""
|
295
|
+
docs = await self.aretrieve(
|
296
|
+
query,
|
297
|
+
collection_name,
|
298
|
+
final_limit,
|
299
|
+
result_per_query=result_per_query,
|
300
|
+
similarity_threshold=similarity_threshold,
|
301
|
+
)
|
302
|
+
|
303
|
+
rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
|
304
|
+
|
305
|
+
logger.debug(f"Retrieved Documents: \n{rendered}")
|
306
|
+
return await self.aask(
|
307
|
+
question,
|
308
|
+
f"{rendered}\n\n{extra_system_message}",
|
309
|
+
**kwargs,
|
310
|
+
)
|
fabricatio/capabilities/task.py
CHANGED
@@ -5,21 +5,21 @@ from typing import Any, Dict, List, Optional, Tuple, Unpack
|
|
5
5
|
|
6
6
|
import orjson
|
7
7
|
from fabricatio._rust_instances import template_manager
|
8
|
+
from fabricatio.capabilities.propose import Propose
|
8
9
|
from fabricatio.config import configs
|
9
10
|
from fabricatio.models.generic import WithBriefing
|
10
11
|
from fabricatio.models.kwargs_types import ChooseKwargs, ValidateKwargs
|
11
12
|
from fabricatio.models.task import Task
|
12
13
|
from fabricatio.models.tool import Tool, ToolExecutor
|
13
|
-
from fabricatio.models.usages import
|
14
|
+
from fabricatio.models.usages import ToolBoxUsage
|
14
15
|
from fabricatio.parser import JsonCapture, PythonCapture
|
15
16
|
from loguru import logger
|
16
|
-
from pydantic import ValidationError
|
17
17
|
|
18
18
|
|
19
|
-
class ProposeTask(WithBriefing,
|
19
|
+
class ProposeTask(WithBriefing, Propose):
|
20
20
|
"""A class that proposes a task based on a prompt."""
|
21
21
|
|
22
|
-
async def
|
22
|
+
async def propose_task[T](
|
23
23
|
self,
|
24
24
|
prompt: str,
|
25
25
|
**kwargs: Unpack[ValidateKwargs],
|
@@ -34,27 +34,10 @@ class ProposeTask(WithBriefing, LLMUsage):
|
|
34
34
|
A Task object based on the proposal result.
|
35
35
|
"""
|
36
36
|
if not prompt:
|
37
|
-
err
|
38
|
-
logger.error(err)
|
37
|
+
logger.error(err := f"{self.name}: Prompt must be provided.")
|
39
38
|
raise ValueError(err)
|
40
39
|
|
41
|
-
|
42
|
-
try:
|
43
|
-
cap = JsonCapture.capture(response)
|
44
|
-
logger.debug(f"Response: \n{response}")
|
45
|
-
logger.info(f"Captured JSON: \n{cap}")
|
46
|
-
return Task.model_validate_json(cap)
|
47
|
-
except ValidationError as e:
|
48
|
-
logger.error(f"Failed to parse task from JSON: {e}")
|
49
|
-
return None
|
50
|
-
|
51
|
-
template_data = {"prompt": prompt, "json_example": Task.json_example()}
|
52
|
-
return await self.aask_validate(
|
53
|
-
question=template_manager.render_template(configs.templates.propose_task_template, template_data),
|
54
|
-
validator=_validate_json,
|
55
|
-
system_message=f"# your personal briefing: \n{self.briefing}",
|
56
|
-
**kwargs,
|
57
|
-
)
|
40
|
+
return await self.propose(Task, prompt, system_message=f"# your personal briefing: \n{self.briefing}", **kwargs)
|
58
41
|
|
59
42
|
|
60
43
|
class HandleTask(WithBriefing, ToolBoxUsage):
|