fabricatio-0.2.1.dev0-cp313-cp313-win_amd64.whl → fabricatio-0.3.14.dev5-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +12 -20
- fabricatio/actions/__init__.py +1 -5
- fabricatio/actions/article.py +319 -0
- fabricatio/actions/article_rag.py +416 -0
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +248 -0
- fabricatio/actions/rag.py +96 -0
- fabricatio/actions/rules.py +83 -0
- fabricatio/capabilities/__init__.py +1 -0
- fabricatio/capabilities/advanced_judge.py +20 -0
- fabricatio/capabilities/advanced_rag.py +61 -0
- fabricatio/capabilities/censor.py +105 -0
- fabricatio/capabilities/check.py +212 -0
- fabricatio/capabilities/correct.py +228 -0
- fabricatio/capabilities/extract.py +74 -0
- fabricatio/capabilities/persist.py +103 -0
- fabricatio/capabilities/propose.py +65 -0
- fabricatio/capabilities/rag.py +263 -0
- fabricatio/capabilities/rating.py +404 -0
- fabricatio/capabilities/review.py +114 -0
- fabricatio/capabilities/task.py +113 -0
- fabricatio/decorators.py +251 -179
- fabricatio/{core.py → emitter.py} +31 -21
- fabricatio/fs/__init__.py +32 -2
- fabricatio/fs/curd.py +32 -9
- fabricatio/fs/readers.py +44 -7
- fabricatio/journal.py +3 -19
- fabricatio/models/action.py +185 -61
- fabricatio/models/adv_kwargs_types.py +63 -0
- fabricatio/models/extra/__init__.py +1 -0
- fabricatio/models/extra/advanced_judge.py +32 -0
- fabricatio/models/extra/aricle_rag.py +284 -0
- fabricatio/models/extra/article_base.py +422 -0
- fabricatio/models/extra/article_essence.py +101 -0
- fabricatio/models/extra/article_main.py +284 -0
- fabricatio/models/extra/article_outline.py +46 -0
- fabricatio/models/extra/article_proposal.py +52 -0
- fabricatio/models/extra/patches.py +20 -0
- fabricatio/models/extra/problem.py +165 -0
- fabricatio/models/extra/rag.py +98 -0
- fabricatio/models/extra/rule.py +52 -0
- fabricatio/models/generic.py +704 -36
- fabricatio/models/kwargs_types.py +112 -17
- fabricatio/models/role.py +74 -27
- fabricatio/models/task.py +94 -60
- fabricatio/models/tool.py +328 -188
- fabricatio/models/usages.py +791 -515
- fabricatio/parser.py +81 -60
- fabricatio/rust.cp313-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +886 -0
- fabricatio/toolboxes/__init__.py +1 -3
- fabricatio/toolboxes/fs.py +17 -1
- fabricatio/utils.py +156 -0
- fabricatio/workflows/__init__.py +1 -0
- fabricatio/workflows/articles.py +24 -0
- fabricatio/workflows/rag.py +11 -0
- fabricatio-0.3.14.dev5.data/scripts/tdown.exe +0 -0
- fabricatio-0.3.14.dev5.data/scripts/ttm.exe +0 -0
- fabricatio-0.3.14.dev5.dist-info/METADATA +188 -0
- fabricatio-0.3.14.dev5.dist-info/RECORD +64 -0
- {fabricatio-0.2.1.dev0.dist-info → fabricatio-0.3.14.dev5.dist-info}/WHEEL +1 -1
- fabricatio/_rust.cp313-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +0 -53
- fabricatio/_rust_instances.py +0 -8
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio/config.py +0 -263
- fabricatio/models/advanced.py +0 -128
- fabricatio/models/events.py +0 -82
- fabricatio/models/utils.py +0 -78
- fabricatio/toolboxes/task.py +0 -6
- fabricatio-0.2.1.dev0.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.1.dev0.dist-info/METADATA +0 -420
- fabricatio-0.2.1.dev0.dist-info/RECORD +0 -35
- {fabricatio-0.2.1.dev0.dist-info → fabricatio-0.3.14.dev5.dist-info}/licenses/LICENSE +0 -0
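The most structural change in this listing is that runtime configuration and templating moved out of Python and into the Rust extension: `fabricatio/config.py` (263 lines) is deleted, the small `fabricatio/_rust.pyi` is replaced by a much larger `fabricatio/rust.pyi`, and both hunks below import `CONFIG` and `TEMPLATE_MANAGER` from `fabricatio.rust`. A hedged before/after sketch of what that means for callers; the old import path is an assumption inferred from the deleted module, not shown in this diff:

# Before (0.2.1.dev0): configuration lived in a plain Python module.
# The exact old symbol names are an assumption inferred from the deleted file.
# from fabricatio.config import configs

# After (0.3.14.dev5): config and templates come from the compiled extension,
# exactly as the rag.py and rating.py hunks below use them.
from fabricatio.rust import CONFIG, TEMPLATE_MANAGER

milvus_uri = CONFIG.rag.milvus_uri
prompt = TEMPLATE_MANAGER.render_template(
    CONFIG.templates.refined_query_template,
    {"question": ["example question"]},
)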
fabricatio/capabilities/rag.py
@@ -0,0 +1,263 @@
"""A module for the RAG (Retrieval Augmented Generation) model."""

from abc import ABC

try:
    from pymilvus import MilvusClient
except ImportError as e:
    raise RuntimeError(
        "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
    ) from e
from functools import lru_cache
from operator import itemgetter
from typing import List, Optional, Self, Type, Unpack

from more_itertools.recipes import flatten, unique
from pydantic import Field, PrivateAttr

from fabricatio.journal import logger
from fabricatio.models.adv_kwargs_types import CollectionConfigKwargs, FetchKwargs
from fabricatio.models.extra.rag import MilvusDataBase
from fabricatio.models.kwargs_types import ChooseKwargs
from fabricatio.models.usages import EmbeddingUsage
from fabricatio.rust import CONFIG, TEMPLATE_MANAGER
from fabricatio.utils import ok


@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Create a Milvus client."""
    return MilvusClient(
        uri=uri,
        token=token,
        timeout=timeout,
    )


class RAG(EmbeddingUsage, ABC):
    """A class representing the RAG (Retrieval Augmented Generation) model."""

    target_collection: Optional[str] = Field(default=None)
    """The name of the collection being viewed."""

    _client: Optional[MilvusClient] = PrivateAttr(None)
    """The Milvus client used for the RAG model."""

    @property
    def client(self) -> MilvusClient:
        """Return the Milvus client."""
        if self._client is None:
            raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
        return self._client

    def init_client(
        self,
        milvus_uri: Optional[str] = None,
        milvus_token: Optional[str] = None,
        milvus_timeout: Optional[float] = None,
    ) -> Self:
        """Initialize the Milvus client."""
        self._client = create_client(
            uri=milvus_uri or ok(self.milvus_uri or CONFIG.rag.milvus_uri),
            token=milvus_token
            or (token.get_secret_value() if (token := (self.milvus_token or CONFIG.rag.milvus_token)) else ""),
            timeout=milvus_timeout or self.milvus_timeout or CONFIG.rag.milvus_timeout,
        )
        return self

    def check_client(self, init: bool = True) -> Self:
        """Check if the client is initialized, and if not, initialize it."""
        if self._client is None and init:
            return self.init_client()
        if self._client is None and not init:
            raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
        return self

    def view(
        self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
    ) -> Self:
        """View the specified collection.

        Args:
            collection_name (Optional[str]): The name of the collection; pass None to deselect the current one.
            create (bool): Whether to create the collection if it does not exist.
            **kwargs (Unpack[CollectionConfigKwargs]): Additional keyword arguments for collection configuration.
        """
        if create and collection_name and not self.check_client().client.has_collection(collection_name):
            kwargs["dimension"] = ok(
                kwargs.get("dimension")
                or self.milvus_dimensions
                or CONFIG.rag.milvus_dimensions
                or self.embedding_dimensions
                or CONFIG.embedding.dimensions,
                "`dimension` is not set at any level.",
            )
            self.client.create_collection(collection_name, auto_id=True, **kwargs)
            logger.info(f"Creating collection {collection_name}")

        self.target_collection = collection_name
        return self

    def quit_viewing(self) -> Self:
        """Quit the current view.

        Returns:
            Self: The current instance, allowing for method chaining.
        """
        return self.view(None)

    @property
    def safe_target_collection(self) -> str:
        """Get the name of the collection being viewed, raising an error if no collection is selected.

        Returns:
            str: The name of the collection being viewed.
        """
        return ok(self.target_collection, "No collection is being viewed. Have you called `self.view()`?")

    async def add_document[D: MilvusDataBase](
        self, data: List[D] | D, collection_name: Optional[str] = None, flush: bool = False
    ) -> Self:
        """Adds a document to the specified collection.

        Args:
            data (List[D] | D): The data to be added to the collection.
            collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
            flush (bool): Whether to flush the collection after insertion.

        Returns:
            Self: The current instance, allowing for method chaining.
        """
        if isinstance(data, MilvusDataBase):
            data = [data]

        data_vec = await self.vectorize([d.prepare_vectorization() for d in data])
        prepared_data = [d.prepare_insertion(vec) for d, vec in zip(data, data_vec, strict=True)]

        c_name = collection_name or self.safe_target_collection
        self.check_client().client.insert(c_name, prepared_data)

        if flush:
            logger.debug(f"Flushing collection {c_name}")
            self.client.flush(c_name)
        return self

    async def afetch_document[D: MilvusDataBase](
        self,
        query: List[str],
        document_model: Type[D],
        collection_name: Optional[str] = None,
        similarity_threshold: float = 0.37,
        result_per_query: int = 10,
        tei_endpoint: Optional[str] = None,
        reranker_threshold: float = 0.7,
        filter_expr: str = "",
    ) -> List[D]:
        """Asynchronously fetches documents from a Milvus database based on input queries.

        Args:
            query (List[str]): A list of query strings to search for in the database.
            document_model (Type[D]): The model class used to convert fetched data into document objects.
            collection_name (Optional[str]): The name of the collection to search within.
                If None, the currently viewed collection is used.
            similarity_threshold (float): The similarity threshold for vector search. Defaults to 0.37.
            result_per_query (int): The maximum number of results to return per query. Defaults to 10.
            tei_endpoint (Optional[str]): The endpoint of the TEI API; if provided, results are reranked through it.
            reranker_threshold (float): The threshold used to filter out low-relevance documents during reranking.
            filter_expr (str): The filter expression used to filter out unwanted documents.

        Returns:
            List[D]: A list of document objects created from the fetched data.
        """
        # Step 1: Search for vectors
        search_results = self.check_client().client.search(
            collection_name or self.safe_target_collection,
            await self.vectorize(query),
            search_params={"radius": similarity_threshold},
            output_fields=list(document_model.model_fields),
            filter=filter_expr,
            limit=result_per_query,
        )
        if tei_endpoint is not None:
            from fabricatio.rust import TEIClient

            reranker = TEIClient(base_url=tei_endpoint)

            retrieved_id = set()
            raw_result = []

            for q, g in zip(query, search_results, strict=True):
                models = document_model.from_sequence([res["entity"] for res in g if res["id"] not in retrieved_id])
                logger.debug(f"Retrieved {len(g)} raw documents; {len(models)} kept after deduplication.")
                retrieved_id.update(res["id"] for res in g)
                if not models:
                    continue
                rank_scores = await reranker.arerank(q, [m.prepare_vectorization() for m in models], truncate=True)
                raw_result.extend((models[idx], scr) for (idx, scr) in rank_scores if scr > reranker_threshold)

            raw_result_sorted = sorted(raw_result, key=lambda x: x[1], reverse=True)
            return [r[0] for r in raw_result_sorted]

        # Step 2: Flatten the search results
        flattened_results = flatten(search_results)
        unique_results = unique(flattened_results, key=itemgetter("id"))

        # Step 3: Sort by distance (descending)
        sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)

        logger.debug(
            f"Fetched {len(sorted_results)} documents, searched similarities: {[t['distance'] for t in sorted_results]}"
        )
        # Step 4: Extract the entities
        resp = [result["entity"] for result in sorted_results]

        return document_model.from_sequence(resp)

    async def aretrieve[D: MilvusDataBase](
        self,
        query: List[str] | str,
        document_model: Type[D],
        max_accepted: int = 20,
        **kwargs: Unpack[FetchKwargs],
    ) -> List[D]:
        """Retrieve data from the collection.

        Args:
            query (List[str] | str): The query to be used for retrieval.
            document_model (Type[D]): The model class used to convert retrieved data into document objects.
            max_accepted (int): The final limit on the number of results to return.
            **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.

        Returns:
            List[D]: A list of document objects created from the retrieved data.
        """
        if isinstance(query, str):
            query = [query]

        return (
            await self.afetch_document(
                query=query,
                document_model=document_model,
                **kwargs,
            )
        )[:max_accepted]

    async def arefined_query(
        self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs[Optional[List[str]]]]
    ) -> Optional[List[str]]:
        """Refines the given question using a template.

        Args:
            question (List[str] | str): The question to be refined.
            **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the refinement process.

        Returns:
            Optional[List[str]]: A list of refined questions, or None if refinement fails.
        """
        return await self.alist_str(
            TEMPLATE_MANAGER.render_template(
                CONFIG.templates.refined_query_template,
                {"question": [question] if isinstance(question, str) else question},
            ),
            **kwargs,
        )
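The hunk above gives the capability a small lifecycle: `init_client()` wires up Milvus, `view()` selects (and optionally creates) a collection, `add_document()` embeds and inserts `MilvusDataBase` records, and `aretrieve()` searches, with an optional TEI reranking pass. A hedged usage sketch follows; `Retriever`, `Chunk`, and the URI are hypothetical, and `MilvusDataBase` subclasses may need hooks beyond `prepare_vectorization` that this diff does not show (`prepare_insertion` and `from_sequence` are called internally):

# Hypothetical sketch only; names and endpoint are placeholders.
import asyncio

from fabricatio.capabilities.rag import RAG
from fabricatio.models.extra.rag import MilvusDataBase


class Chunk(MilvusDataBase):  # hypothetical document model
    text: str

    def prepare_vectorization(self) -> str:
        return self.text  # the text handed to the embedding model


class Retriever(RAG):  # concrete host for the capability; RAG itself is abstract
    pass


async def main() -> None:
    rag = Retriever().init_client(milvus_uri="http://localhost:19530")
    rag.view("chunks", create=True)  # creates the collection, inferring `dimension`
    await rag.add_document(Chunk(text="Milvus stores embedding vectors."), flush=True)
    hits = await rag.aretrieve("vector database", Chunk, max_accepted=5)
    print([h.text for h in hits])


asyncio.run(main())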
fabricatio/capabilities/rating.py
@@ -0,0 +1,404 @@
"""A module that provides functionality to rate tasks based on a rating manual and score range."""

from abc import ABC
from itertools import permutations
from random import sample
from typing import Dict, List, Optional, Set, Tuple, Union, Unpack, overload

from more_itertools import flatten, windowed
from pydantic import Field, NonNegativeInt, PositiveInt, create_model

from fabricatio.capabilities.propose import Propose
from fabricatio.journal import logger
from fabricatio.models.generic import Display, ProposedAble
from fabricatio.models.kwargs_types import CompositeScoreKwargs, ValidateKwargs
from fabricatio.parser import JsonCapture
from fabricatio.rust import CONFIG, TEMPLATE_MANAGER
from fabricatio.utils import ok, override_kwargs


class Rating(Propose, ABC):
    """A class that provides functionality to rate tasks based on a rating manual and score range.

    References:
        Lu X, Li J, Takeuchi K, et al. AHP-powered LLM reasoning for multi-criteria evaluation of open-ended responses[A/OL]. arXiv, 2024. DOI: 10.48550/arXiv.2410.01246.
    """

    async def rate_fine_grind(
        self,
        to_rate: str | List[str],
        rating_manual: Dict[str, str],
        score_range: Tuple[float, float],
        **kwargs: Unpack[ValidateKwargs[Dict[str, float]]],
    ) -> Dict[str, float] | List[Dict[str, float]] | List[Optional[Dict[str, float]]] | None:
        """Rate a given string or list of strings based on a rating manual and score range.

        Args:
            to_rate (str | List[str]): The string or strings to be rated.
            rating_manual (Dict[str, str]): A dictionary containing the rating criteria.
            score_range (Tuple[float, float]): A tuple representing the valid score range.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            Dict[str, float] | List[Dict[str, float]]: The ratings for each criterion, or a list thereof when a list of strings is rated.
        """
        min_score, max_score = score_range
        tip = (max_score - min_score) / 9

        model = create_model(  # pyright: ignore [reportCallIssue]
            "RatingResult",
            __base__=ProposedAble,
            __doc__=f"The rating result contains the scores against each criterion, with min_score={min_score} and max_score={max_score}.",
            **{  # pyright: ignore [reportArgumentType]
                criterion: (
                    float,
                    Field(
                        ge=min_score,
                        le=max_score,
                        description=desc,
                        examples=[round(min_score + tip * i, 2) for i in range(10)],
                    ),
                )
                for criterion, desc in rating_manual.items()
            },
        )

        res = await self.propose(
            model,
            TEMPLATE_MANAGER.render_template(
                CONFIG.templates.rate_fine_grind_template,
                {"to_rate": to_rate, "min_score": min_score, "max_score": max_score},
            )
            if isinstance(to_rate, str)
            else [
                TEMPLATE_MANAGER.render_template(
                    CONFIG.templates.rate_fine_grind_template,
                    {"to_rate": t, "min_score": min_score, "max_score": max_score},
                )
                for t in to_rate
            ],
            **override_kwargs(kwargs, default=None),
        )
        default = kwargs.get("default")
        if isinstance(res, list):
            return [r.model_dump() if r else default for r in res]
        if res is None:
            return default
        return res.model_dump()

    @overload
    async def rate(
        self,
        to_rate: str,
        topic: str,
        criteria: Set[str],
        manual: Optional[Dict[str, str]] = None,
        score_range: Tuple[float, float] = (0.0, 1.0),
        **kwargs: Unpack[ValidateKwargs],
    ) -> Dict[str, float]: ...

    @overload
    async def rate(
        self,
        to_rate: List[str],
        topic: str,
        criteria: Set[str],
        manual: Optional[Dict[str, str]] = None,
        score_range: Tuple[float, float] = (0.0, 1.0),
        **kwargs: Unpack[ValidateKwargs],
    ) -> List[Dict[str, float]]: ...

    async def rate(
        self,
        to_rate: Union[str, List[str]],
        topic: str,
        criteria: Set[str],
        manual: Optional[Dict[str, str]] = None,
        score_range: Tuple[float, float] = (0.0, 1.0),
        **kwargs: Unpack[ValidateKwargs],
    ) -> Dict[str, float] | List[Dict[str, float]] | List[Optional[Dict[str, float]]] | None:
        """Rate a given string or a sequence of strings based on a topic, criteria, and score range.

        Args:
            to_rate (Union[str, List[str]]): The string or sequence of strings to be rated.
            topic (str): The topic related to the task.
            criteria (Set[str]): A set of criteria for rating.
            manual (Optional[Dict[str, str]]): A rating manual mapping each criterion to its description. If not provided, this method will draft one automatically.
            score_range (Tuple[float, float], optional): A tuple representing the valid score range. Defaults to (0.0, 1.0).
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            Union[Dict[str, float], List[Dict[str, float]]]: A dictionary with the ratings for each criterion if a single string is provided,
                or a list of dictionaries with the ratings for each criterion if a sequence of strings is provided.
        """
        manual = (
            manual
            or await self.draft_rating_manual(topic, criteria, **override_kwargs(kwargs, default=None))
            or dict(zip(criteria, criteria, strict=True))
        )

        return await self.rate_fine_grind(to_rate, manual, score_range, **kwargs)

    async def draft_rating_manual(
        self, topic: str, criteria: Optional[Set[str]] = None, **kwargs: Unpack[ValidateKwargs[Dict[str, str]]]
    ) -> Optional[Dict[str, str]]:
        """Drafts a rating manual based on a topic and dimensions.

        Args:
            topic (str): The topic for the rating manual.
            criteria (Optional[Set[str]], optional): A set of criteria for the rating manual. If not specified, this method will draft the criteria automatically.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            Dict[str, str]: A dictionary representing the drafted rating manual.
        """

        def _validator(response: str) -> Dict[str, str] | None:
            if (
                (json_data := JsonCapture.validate_with(response, target_type=dict, elements_type=str)) is not None
                and json_data.keys() == criteria
                and all(isinstance(v, str) for v in json_data.values())
            ):
                return json_data
            return None

        criteria = criteria or await self.draft_rating_criteria(topic, **override_kwargs(dict(kwargs), default=None))

        if criteria is None:
            logger.error(f"Failed to draft rating criteria for topic {topic}")
            return None

        return await self.aask_validate(
            question=(
                TEMPLATE_MANAGER.render_template(
                    CONFIG.templates.draft_rating_manual_template,
                    {
                        "topic": topic,
                        "criteria": list(criteria),
                    },
                )
            ),
            validator=_validator,
            **kwargs,
        )

    async def draft_rating_criteria(
        self,
        topic: str,
        criteria_count: NonNegativeInt = 0,
        **kwargs: Unpack[ValidateKwargs[Set[str]]],
    ) -> Optional[Set[str]]:
        """Drafts rating dimensions based on a topic.

        Args:
            topic (str): The topic for the rating dimensions.
            criteria_count (NonNegativeInt, optional): The number of dimensions to draft; 0 means no limit. Defaults to 0.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            Set[str]: A set of rating dimensions.
        """
        return await self.aask_validate(
            question=(
                TEMPLATE_MANAGER.render_template(
                    CONFIG.templates.draft_rating_criteria_template,
                    {
                        "topic": topic,
                        "criteria_count": criteria_count,
                    },
                )
            ),
            validator=lambda resp: set(out)
            if (out := JsonCapture.validate_with(resp, list, str, criteria_count)) is not None
            else out,
            **kwargs,
        )

    async def draft_rating_criteria_from_examples(
        self,
        topic: str,
        examples: List[str],
        m: NonNegativeInt = 0,
        reasons_count: PositiveInt = 2,
        criteria_count: PositiveInt = 5,
        **kwargs: Unpack[ValidateKwargs],
    ) -> Optional[Set[str]]:
        """Asynchronously drafts a set of rating criteria based on provided examples.

        This function generates rating criteria by analyzing examples and extracting reasons for comparison,
        then further condensing these reasons into a specified number of criteria.

        Parameters:
            topic (str): The subject topic for the rating criteria.
            examples (List[str]): A list of example texts to analyze.
            m (NonNegativeInt, optional): The number of examples to sample from the provided list. Defaults to 0 (no sampling).
            reasons_count (PositiveInt, optional): The number of reasons to extract from each pair of examples. Defaults to 2.
            criteria_count (PositiveInt, optional): The final number of rating criteria to draft. Defaults to 5.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for validation.

        Returns:
            Set[str]: A set of drafted rating criteria.

        Warnings:
            Since this function uses pairwise comparisons, it may not be suitable for large lists of examples.
            For that reason, consider using a smaller list of examples or setting `m` to a non-zero value smaller than the length of the examples.
        """
        if m:
            examples = sample(examples, m)

        # extract reasons from the comparison of ordered pairs drawn from the examples
        reasons = flatten(
            await self.aask_validate(  # pyright: ignore [reportArgumentType]
                question=[
                    TEMPLATE_MANAGER.render_template(
                        CONFIG.templates.extract_reasons_from_examples_template,
                        {
                            "topic": topic,
                            "first": pair[0],
                            "second": pair[1],
                            "reasons_count": reasons_count,
                        },
                    )
                    for pair in (permutations(examples, 2))
                ],
                validator=lambda resp: JsonCapture.validate_with(
                    resp, target_type=list, elements_type=str, length=reasons_count
                ),
                **kwargs,
            )
        )
        # condense the reasons into the requested number of criteria according to their importance and frequency
        return await self.aask_validate(
            question=(
                TEMPLATE_MANAGER.render_template(
                    CONFIG.templates.extract_criteria_from_reasons_template,
                    {
                        "topic": topic,
                        "reasons": list(reasons),
                        "criteria_count": criteria_count,
                    },
                )
            ),
            validator=lambda resp: set(out)
            if (out := JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=criteria_count))
            else None,
            **kwargs,
        )

    async def drafting_rating_weights_klee(
        self,
        topic: str,
        criteria: Set[str],
        **kwargs: Unpack[ValidateKwargs[float]],
    ) -> Dict[str, float]:
        """Drafts rating weights for a given topic and criteria using the Klee method.

        Args:
            topic (str): The topic for the rating weights.
            criteria (Set[str]): A set of criteria for the rating weights.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            Dict[str, float]: A dictionary representing the drafted rating weights for each criterion.
        """
        if len(criteria) < 2:  # noqa: PLR2004
            raise ValueError("At least two criteria are required to draft rating weights")

        criteria_seq = list(criteria)  # freeze the order
        windows = windowed(criteria_seq, 2)

        # get the importance multiplier indicating how important the second criterion is compared to the first one
        relative_weights = await self.aask_validate(
            question=[
                TEMPLATE_MANAGER.render_template(
                    CONFIG.templates.draft_rating_weights_klee_template,
                    {
                        "topic": topic,
                        "first": pair[0],
                        "second": pair[1],
                    },
                )
                for pair in windows
            ],
            validator=lambda resp: JsonCapture.validate_with(resp, target_type=float),
            **kwargs,
        )
        if not all(relative_weights):
            raise ValueError(f"found illegal weight: {relative_weights}")
        weights = [1.0]
        for rw in relative_weights:
            weights.append(weights[-1] * rw)  # pyright: ignore [reportOperatorIssue]
        total = sum(weights)
        return dict(zip(criteria_seq, [w / total for w in weights], strict=True))

    async def composite_score(
        self,
        topic: str,
        to_rate: List[str],
        criteria: Optional[Set[str]] = None,
        weights: Optional[Dict[str, float]] = None,
        manual: Optional[Dict[str, str]] = None,
        approx: bool = False,
        **kwargs: Unpack[ValidateKwargs[List[Dict[str, float]]]],
    ) -> List[float]:
        """Calculates the composite scores for a list of items based on a given topic and criteria.

        Args:
            topic (str): The topic for the rating.
            to_rate (List[str]): A list of strings to be rated.
            criteria (Optional[Set[str]]): A set of criteria for the rating. Defaults to None.
            weights (Optional[Dict[str, float]]): A dictionary of rating weights for each criterion. Defaults to None.
            manual (Optional[Dict[str, str]]): A rating manual mapping each criterion to its description. Defaults to None.
            approx (bool): Whether to draft criteria from the topic alone instead of from the examples. Defaults to False.
            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

        Returns:
            List[float]: A list of composite scores for the items.
        """
        criteria = ok(
            criteria
            or (await self.draft_rating_criteria(topic, **override_kwargs(kwargs, default=None)) if approx else None)
            or await self.draft_rating_criteria_from_examples(topic, to_rate, **override_kwargs(kwargs, default=None))
        )
        weights = ok(
            weights or await self.drafting_rating_weights_klee(topic, criteria, **override_kwargs(kwargs, default=None))
        )
        logger.info(f"Criteria: {criteria}\nWeights: {weights}")
        ratings_seq = await self.rate(to_rate, topic, criteria, manual, **kwargs)

        return [sum(ratings[c] * weights[c] for c in criteria) for ratings in ratings_seq]

    @overload
    async def best(self, candidates: List[str], k: int = 1, **kwargs: Unpack[CompositeScoreKwargs]) -> List[str]: ...

    @overload
    async def best[T: Display](
        self, candidates: List[T], k: int = 1, **kwargs: Unpack[CompositeScoreKwargs]
    ) -> List[T]: ...

    async def best[T: Display](
        self, candidates: List[str] | List[T], k: int = 1, **kwargs: Unpack[CompositeScoreKwargs]
    ) -> Optional[List[str] | List[T]]:
        """Choose the best candidates from the list of candidates based on the composite score.

        Args:
            candidates (List[str]): A list of candidates to choose from.
            k (int): The number of best candidates to choose.
            **kwargs (CompositeScoreKwargs): Additional keyword arguments for the composite score calculation.

        Returns:
            List[str]: The best candidates.
        """
        if (leng := len(candidates)) == 0:
            logger.warning(f"No candidates, got {leng}, returning None.")
            return None

        if leng == 1:
            logger.warning(f"Only one candidate, got {leng}, returning it.")
            return candidates
        logger.info(f"Choosing the best {k} from {leng} candidates.")

        rating_seq = await self.composite_score(
            to_rate=[c.display() if isinstance(c, Display) else c for c in candidates], **kwargs
        )
        return [a[0] for a in sorted(zip(candidates, rating_seq, strict=True), key=lambda x: x[1], reverse=True)[:k]]  # pyright: ignore [reportReturnType]
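To make `drafting_rating_weights_klee` and the reduction in `composite_score` concrete: the chain of pairwise multipliers is anchored at 1.0, each LLM-supplied multiplier scales the previous criterion's weight, the chain is normalized, and each item's composite score is the weighted sum of its per-criterion ratings. A self-contained sketch of just that arithmetic, with made-up criteria and multipliers:

def klee_weights(criteria: list[str], relative: list[float]) -> dict[str, float]:
    """Chain pairwise importance multipliers into normalized weights."""
    weights = [1.0]  # the first criterion anchors the chain
    for r in relative:  # r: importance of the next criterion vs. the previous one
        weights.append(weights[-1] * r)
    total = sum(weights)
    return {c: w / total for c, w in zip(criteria, weights, strict=True)}


# Made-up multipliers: "clarity" is 2x as important as "novelty",
# "rigor" 0.5x as important as "clarity".
weights = klee_weights(["novelty", "clarity", "rigor"], [2.0, 0.5])
assert weights == {"novelty": 0.25, "clarity": 0.5, "rigor": 0.25}

# composite_score then collapses per-criterion ratings into one number per item:
ratings = {"novelty": 0.8, "clarity": 0.6, "rigor": 0.9}
score = sum(ratings[c] * weights[c] for c in weights)  # 0.2 + 0.3 + 0.225
assert abs(score - 0.725) < 1e-9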