fabricatio 0.2.9.dev3__cp312-cp312-win_amd64.whl → 0.2.10__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +24 -114
- fabricatio/actions/article_rag.py +156 -18
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +40 -18
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/check.py +15 -9
- fabricatio/capabilities/correct.py +5 -6
- fabricatio/capabilities/rag.py +41 -231
- fabricatio/capabilities/rating.py +46 -40
- fabricatio/config.py +6 -4
- fabricatio/constants.py +20 -0
- fabricatio/decorators.py +23 -0
- fabricatio/fs/readers.py +20 -1
- fabricatio/models/adv_kwargs_types.py +35 -0
- fabricatio/models/events.py +6 -6
- fabricatio/models/extra/advanced_judge.py +4 -4
- fabricatio/models/extra/aricle_rag.py +170 -0
- fabricatio/models/extra/article_base.py +25 -211
- fabricatio/models/extra/article_essence.py +8 -7
- fabricatio/models/extra/article_main.py +98 -97
- fabricatio/models/extra/article_proposal.py +15 -14
- fabricatio/models/extra/patches.py +6 -6
- fabricatio/models/extra/problem.py +12 -17
- fabricatio/models/extra/rag.py +98 -0
- fabricatio/models/extra/rule.py +1 -2
- fabricatio/models/generic.py +53 -13
- fabricatio/models/kwargs_types.py +8 -36
- fabricatio/models/task.py +3 -3
- fabricatio/models/usages.py +85 -9
- fabricatio/parser.py +5 -5
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +137 -10
- fabricatio/utils.py +62 -4
- fabricatio-0.2.10.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/METADATA +1 -4
- fabricatio-0.2.10.dist-info/RECORD +64 -0
- fabricatio/models/utils.py +0 -148
- fabricatio-0.2.9.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.9.dev3.dist-info/RECORD +0 -61
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/licenses/LICENSE +0 -0
@@ -57,7 +57,7 @@ class Correct(Rating, Propose):
|
|
57
57
|
self.decide_solution(
|
58
58
|
ps,
|
59
59
|
**fallback_kwargs(
|
60
|
-
kwargs, topic=f"which solution is better to deal this problem {ps.problem.
|
60
|
+
kwargs, topic=f"which solution is better to deal this problem {ps.problem.description}\n\n"
|
61
61
|
),
|
62
62
|
)
|
63
63
|
for ps in improvement.problem_solutions
|
@@ -167,13 +167,12 @@ class Correct(Rating, Propose):
|
|
167
167
|
logger.info(f"Improvement {improvement.focused_on} not decided, start deciding...")
|
168
168
|
improvement = await self.decide_improvement(improvement, **override_kwargs(kwargs, default=None))
|
169
169
|
|
170
|
-
|
171
|
-
|
170
|
+
total = len(improvement.problem_solutions)
|
171
|
+
for idx, ps in enumerate(improvement.problem_solutions):
|
172
|
+
logger.info(f"[{idx + 1}/{total}] Fixing {obj.__class__.__name__} for problem `{ps.problem.name}`")
|
172
173
|
fixed_obj = await self.fix_troubled_obj(obj, ps, reference, **kwargs)
|
173
174
|
if fixed_obj is None:
|
174
|
-
logger.error(
|
175
|
-
f"Failed to fix troubling obj {obj.__class__.__name__} when deal with problem: {ps.problem.name}",
|
176
|
-
)
|
175
|
+
logger.error(f"[{idx + 1}/{total}] Failed to fix problem `{ps.problem.name}`")
|
177
176
|
return None
|
178
177
|
obj = fixed_obj
|
179
178
|
return obj
|
fabricatio/capabilities/rag.py
CHANGED
@@ -3,28 +3,22 @@
|
|
3
3
|
try:
|
4
4
|
from pymilvus import MilvusClient
|
5
5
|
except ImportError as e:
|
6
|
-
raise RuntimeError(
|
6
|
+
raise RuntimeError(
|
7
|
+
"pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
|
8
|
+
) from e
|
7
9
|
from functools import lru_cache
|
8
10
|
from operator import itemgetter
|
9
|
-
from
|
10
|
-
from pathlib import Path
|
11
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
|
11
|
+
from typing import List, Optional, Self, Type, Unpack
|
12
12
|
|
13
13
|
from more_itertools.recipes import flatten, unique
|
14
14
|
from pydantic import Field, PrivateAttr
|
15
15
|
|
16
16
|
from fabricatio.config import configs
|
17
17
|
from fabricatio.journal import logger
|
18
|
-
from fabricatio.models.
|
19
|
-
|
20
|
-
|
21
|
-
EmbeddingKwargs,
|
22
|
-
FetchKwargs,
|
23
|
-
LLMKwargs,
|
24
|
-
RetrievalKwargs,
|
25
|
-
)
|
18
|
+
from fabricatio.models.adv_kwargs_types import CollectionConfigKwargs, FetchKwargs
|
19
|
+
from fabricatio.models.extra.rag import MilvusDataBase
|
20
|
+
from fabricatio.models.kwargs_types import ChooseKwargs
|
26
21
|
from fabricatio.models.usages import EmbeddingUsage
|
27
|
-
from fabricatio.models.utils import MilvusData
|
28
22
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
29
23
|
from fabricatio.utils import ok
|
30
24
|
|
@@ -78,40 +72,6 @@ class RAG(EmbeddingUsage):
|
|
78
72
|
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
79
73
|
return self
|
80
74
|
|
81
|
-
@overload
|
82
|
-
async def pack(
|
83
|
-
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
84
|
-
) -> List[MilvusData]: ...
|
85
|
-
@overload
|
86
|
-
async def pack(
|
87
|
-
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
88
|
-
) -> MilvusData: ...
|
89
|
-
|
90
|
-
async def pack(
|
91
|
-
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
92
|
-
) -> List[MilvusData] | MilvusData:
|
93
|
-
"""Asynchronously generates MilvusData objects for the given input text.
|
94
|
-
|
95
|
-
Args:
|
96
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
97
|
-
subject (Optional[str]): The subject of the input text. Defaults to None.
|
98
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
99
|
-
|
100
|
-
Returns:
|
101
|
-
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
102
|
-
"""
|
103
|
-
if isinstance(input_text, str):
|
104
|
-
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
105
|
-
vecs = await self.vectorize(input_text, **kwargs)
|
106
|
-
return [
|
107
|
-
MilvusData(
|
108
|
-
vector=vec,
|
109
|
-
text=text,
|
110
|
-
subject=subject,
|
111
|
-
)
|
112
|
-
for text, vec in zip(input_text, vecs, strict=True)
|
113
|
-
]
|
114
|
-
|
115
75
|
def view(
|
116
76
|
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
|
117
77
|
) -> Self:
|
@@ -152,29 +112,27 @@ class RAG(EmbeddingUsage):
|
|
152
112
|
Returns:
|
153
113
|
str: The name of the collection being viewed.
|
154
114
|
"""
|
155
|
-
|
156
|
-
raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
|
157
|
-
return self.target_collection
|
115
|
+
return ok(self.target_collection, "No collection is being viewed. Have you called `self.view()`?")
|
158
116
|
|
159
|
-
def add_document[D:
|
160
|
-
self, data: D |
|
117
|
+
async def add_document[D: MilvusDataBase](
|
118
|
+
self, data: List[D] | D, collection_name: Optional[str] = None, flush: bool = False
|
161
119
|
) -> Self:
|
162
120
|
"""Adds a document to the specified collection.
|
163
121
|
|
164
122
|
Args:
|
165
|
-
data (Union[Dict[str, Any],
|
123
|
+
data (Union[Dict[str, Any], MilvusDataBase] | List[Union[Dict[str, Any], MilvusDataBase]]): The data to be added to the collection.
|
166
124
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
167
125
|
flush (bool): Whether to flush the collection after insertion.
|
168
126
|
|
169
127
|
Returns:
|
170
128
|
Self: The current instance, allowing for method chaining.
|
171
129
|
"""
|
172
|
-
if isinstance(data,
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
130
|
+
if isinstance(data, MilvusDataBase):
|
131
|
+
data = [data]
|
132
|
+
|
133
|
+
data_vec = await self.vectorize([d.prepare_vectorization() for d in data])
|
134
|
+
prepared_data = [d.prepare_insertion(vec) for d, vec in zip(data, data_vec, strict=True)]
|
135
|
+
|
178
136
|
c_name = collection_name or self.safe_target_collection
|
179
137
|
self.check_client().client.insert(c_name, prepared_data)
|
180
138
|
|
@@ -183,84 +141,33 @@ class RAG(EmbeddingUsage):
|
|
183
141
|
self.client.flush(c_name)
|
184
142
|
return self
|
185
143
|
|
186
|
-
async def
|
187
|
-
self,
|
188
|
-
source: List[PathLike] | PathLike,
|
189
|
-
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
190
|
-
collection_name: Optional[str] = None,
|
191
|
-
) -> Self:
|
192
|
-
"""Consume a file and add its content to the collection.
|
193
|
-
|
194
|
-
Args:
|
195
|
-
source (PathLike): The path to the file to be consumed.
|
196
|
-
reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
|
197
|
-
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
198
|
-
|
199
|
-
Returns:
|
200
|
-
Self: The current instance, allowing for method chaining.
|
201
|
-
"""
|
202
|
-
if not isinstance(source, list):
|
203
|
-
source = [source]
|
204
|
-
return await self.consume_string([reader(s) for s in source], collection_name)
|
205
|
-
|
206
|
-
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
|
207
|
-
"""Consume a string and add it to the collection.
|
208
|
-
|
209
|
-
Args:
|
210
|
-
text (List[str] | str): The text to be added to the collection.
|
211
|
-
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
212
|
-
|
213
|
-
Returns:
|
214
|
-
Self: The current instance, allowing for method chaining.
|
215
|
-
"""
|
216
|
-
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
217
|
-
return self
|
218
|
-
|
219
|
-
@overload
|
220
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
144
|
+
async def afetch_document[D: MilvusDataBase](
|
221
145
|
self,
|
222
146
|
vecs: List[List[float]],
|
223
|
-
|
147
|
+
document_model: Type[D],
|
224
148
|
collection_name: Optional[str] = None,
|
225
149
|
similarity_threshold: float = 0.37,
|
226
150
|
result_per_query: int = 10,
|
227
|
-
) -> List[
|
228
|
-
|
229
|
-
@overload
|
230
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
231
|
-
self,
|
232
|
-
vecs: List[List[float]],
|
233
|
-
desired_fields: str,
|
234
|
-
collection_name: Optional[str] = None,
|
235
|
-
similarity_threshold: float = 0.37,
|
236
|
-
result_per_query: int = 10,
|
237
|
-
) -> List[V]: ...
|
238
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
239
|
-
self,
|
240
|
-
vecs: List[List[float]],
|
241
|
-
desired_fields: List[str] | str,
|
242
|
-
collection_name: Optional[str] = None,
|
243
|
-
similarity_threshold: float = 0.37,
|
244
|
-
result_per_query: int = 10,
|
245
|
-
) -> List[Dict[str, Any]] | List[V]:
|
246
|
-
"""Fetch data from the collection.
|
151
|
+
) -> List[D]:
|
152
|
+
"""Asynchronously fetches documents from a Milvus database based on input vectors.
|
247
153
|
|
248
154
|
Args:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
155
|
+
vecs (List[List[float]]): A list of vectors to search for in the database.
|
156
|
+
document_model (Type[D]): The model class used to convert fetched data into document objects.
|
157
|
+
collection_name (Optional[str]): The name of the collection to search within.
|
158
|
+
If None, the currently viewed collection is used.
|
159
|
+
similarity_threshold (float): The similarity threshold for vector search. Defaults to 0.37.
|
160
|
+
result_per_query (int): The maximum number of results to return per query. Defaults to 10.
|
254
161
|
|
255
162
|
Returns:
|
256
|
-
|
163
|
+
List[D]: A list of document objects created from the fetched data.
|
257
164
|
"""
|
258
165
|
# Step 1: Search for vectors
|
259
166
|
search_results = self.check_client().client.search(
|
260
167
|
collection_name or self.safe_target_collection,
|
261
168
|
vecs,
|
262
169
|
search_params={"radius": similarity_threshold},
|
263
|
-
output_fields=
|
170
|
+
output_fields=list(document_model.model_fields),
|
264
171
|
limit=result_per_query,
|
265
172
|
)
|
266
173
|
|
@@ -270,104 +177,42 @@ class RAG(EmbeddingUsage):
|
|
270
177
|
# Step 3: Sort by distance (descending)
|
271
178
|
sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
|
272
179
|
|
273
|
-
logger.debug(
|
180
|
+
logger.debug(
|
181
|
+
f"Fetched {len(sorted_results)} document,searched similarities: {[t['distance'] for t in sorted_results]}"
|
182
|
+
)
|
274
183
|
# Step 4: Extract the entities
|
275
184
|
resp = [result["entity"] for result in sorted_results]
|
276
185
|
|
277
|
-
|
278
|
-
return resp
|
279
|
-
return [r.get(desired_fields) for r in resp] # extract the single field as list
|
186
|
+
return document_model.from_sequence(resp)
|
280
187
|
|
281
|
-
async def aretrieve(
|
188
|
+
async def aretrieve[D: MilvusDataBase](
|
282
189
|
self,
|
283
190
|
query: List[str] | str,
|
191
|
+
document_model: Type[D],
|
284
192
|
final_limit: int = 20,
|
285
193
|
**kwargs: Unpack[FetchKwargs],
|
286
|
-
) -> List[
|
194
|
+
) -> List[D]:
|
287
195
|
"""Retrieve data from the collection.
|
288
196
|
|
289
197
|
Args:
|
290
198
|
query (List[str] | str): The query to be used for retrieval.
|
199
|
+
document_model (Type[D]): The model class used to convert retrieved data into document objects.
|
291
200
|
final_limit (int): The final limit on the number of results to return.
|
292
201
|
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
293
202
|
|
294
203
|
Returns:
|
295
|
-
List[
|
204
|
+
List[D]: A list of document objects created from the retrieved data.
|
296
205
|
"""
|
297
206
|
if isinstance(query, str):
|
298
207
|
query = [query]
|
299
|
-
return
|
300
|
-
"List[str]",
|
208
|
+
return (
|
301
209
|
await self.afetch_document(
|
302
210
|
vecs=(await self.vectorize(query)),
|
303
|
-
|
211
|
+
document_model=document_model,
|
304
212
|
**kwargs,
|
305
|
-
)
|
213
|
+
)
|
306
214
|
)[:final_limit]
|
307
215
|
|
308
|
-
async def aretrieve_compact(
|
309
|
-
self,
|
310
|
-
query: List[str] | str,
|
311
|
-
**kwargs: Unpack[RetrievalKwargs],
|
312
|
-
) -> str:
|
313
|
-
"""Retrieve data from the collection and format it for display.
|
314
|
-
|
315
|
-
Args:
|
316
|
-
query (List[str] | str): The query to be used for retrieval.
|
317
|
-
**kwargs (Unpack[RetrievalKwargs]): Additional keyword arguments for retrieval.
|
318
|
-
|
319
|
-
Returns:
|
320
|
-
str: A formatted string containing the retrieved data.
|
321
|
-
"""
|
322
|
-
return TEMPLATE_MANAGER.render_template(
|
323
|
-
configs.templates.retrieved_display_template, {"docs": (await self.aretrieve(query, **kwargs))}
|
324
|
-
)
|
325
|
-
|
326
|
-
async def aask_retrieved(
|
327
|
-
self,
|
328
|
-
question: str,
|
329
|
-
query: Optional[List[str] | str] = None,
|
330
|
-
collection_name: Optional[str] = None,
|
331
|
-
extra_system_message: str = "",
|
332
|
-
result_per_query: int = 10,
|
333
|
-
final_limit: int = 20,
|
334
|
-
similarity_threshold: float = 0.37,
|
335
|
-
**kwargs: Unpack[LLMKwargs],
|
336
|
-
) -> str:
|
337
|
-
"""Asks a question by retrieving relevant documents based on the provided query.
|
338
|
-
|
339
|
-
This method performs document retrieval using the given query, then asks the
|
340
|
-
specified question using the retrieved documents as context.
|
341
|
-
|
342
|
-
Args:
|
343
|
-
question (str): The question to be asked.
|
344
|
-
query (List[str] | str): The query or list of queries used for document retrieval.
|
345
|
-
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
346
|
-
If not provided, the currently viewed collection is used.
|
347
|
-
extra_system_message (str): An additional system message to be included in the prompt.
|
348
|
-
result_per_query (int): The number of results to return per query. Default is 10.
|
349
|
-
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
350
|
-
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
351
|
-
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
352
|
-
|
353
|
-
Returns:
|
354
|
-
str: A string response generated after asking with the context of retrieved documents.
|
355
|
-
"""
|
356
|
-
rendered = await self.aretrieve_compact(
|
357
|
-
query or question,
|
358
|
-
final_limit=final_limit,
|
359
|
-
collection_name=collection_name,
|
360
|
-
result_per_query=result_per_query,
|
361
|
-
similarity_threshold=similarity_threshold,
|
362
|
-
)
|
363
|
-
|
364
|
-
logger.debug(f"Retrieved Documents: \n{rendered}")
|
365
|
-
return await self.aask(
|
366
|
-
question,
|
367
|
-
f"{rendered}\n\n{extra_system_message}",
|
368
|
-
**kwargs,
|
369
|
-
)
|
370
|
-
|
371
216
|
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> Optional[List[str]]:
|
372
217
|
"""Refines the given question using a template.
|
373
218
|
|
@@ -385,38 +230,3 @@ class RAG(EmbeddingUsage):
|
|
385
230
|
),
|
386
231
|
**kwargs,
|
387
232
|
)
|
388
|
-
|
389
|
-
async def aask_refined(
|
390
|
-
self,
|
391
|
-
question: str,
|
392
|
-
collection_name: Optional[str] = None,
|
393
|
-
extra_system_message: str = "",
|
394
|
-
result_per_query: int = 10,
|
395
|
-
final_limit: int = 20,
|
396
|
-
similarity_threshold: float = 0.37,
|
397
|
-
**kwargs: Unpack[LLMKwargs],
|
398
|
-
) -> str:
|
399
|
-
"""Asks a question using a refined query based on the provided question.
|
400
|
-
|
401
|
-
Args:
|
402
|
-
question (str): The question to be asked.
|
403
|
-
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
404
|
-
extra_system_message (str): An additional system message to be included in the prompt.
|
405
|
-
result_per_query (int): The number of results to return per query. Default is 10.
|
406
|
-
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
407
|
-
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
408
|
-
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
409
|
-
|
410
|
-
Returns:
|
411
|
-
str: A string response generated after asking with the refined question.
|
412
|
-
"""
|
413
|
-
return await self.aask_retrieved(
|
414
|
-
question,
|
415
|
-
await self.arefined_query(question, **kwargs),
|
416
|
-
collection_name=collection_name,
|
417
|
-
extra_system_message=extra_system_message,
|
418
|
-
result_per_query=result_per_query,
|
419
|
-
final_limit=final_limit,
|
420
|
-
similarity_threshold=similarity_threshold,
|
421
|
-
**kwargs,
|
422
|
-
)
|
@@ -5,19 +5,19 @@ from random import sample
|
|
5
5
|
from typing import Dict, List, Optional, Set, Tuple, Union, Unpack, overload
|
6
6
|
|
7
7
|
from more_itertools import flatten, windowed
|
8
|
-
from pydantic import NonNegativeInt, PositiveInt
|
8
|
+
from pydantic import Field, NonNegativeInt, PositiveInt, create_model
|
9
9
|
|
10
|
+
from fabricatio.capabilities.propose import Propose
|
10
11
|
from fabricatio.config import configs
|
11
12
|
from fabricatio.journal import logger
|
12
|
-
from fabricatio.models.generic import Display
|
13
|
+
from fabricatio.models.generic import Display, ProposedAble
|
13
14
|
from fabricatio.models.kwargs_types import CompositeScoreKwargs, ValidateKwargs
|
14
|
-
from fabricatio.models.usages import LLMUsage
|
15
15
|
from fabricatio.parser import JsonCapture
|
16
16
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
17
|
-
from fabricatio.utils import ok, override_kwargs
|
17
|
+
from fabricatio.utils import fallback_kwargs, ok, override_kwargs
|
18
18
|
|
19
19
|
|
20
|
-
class Rating(
|
20
|
+
class Rating(Propose):
|
21
21
|
"""A class that provides functionality to rate tasks based on a rating manual and score range.
|
22
22
|
|
23
23
|
References:
|
@@ -30,7 +30,7 @@ class Rating(LLMUsage):
|
|
30
30
|
rating_manual: Dict[str, str],
|
31
31
|
score_range: Tuple[float, float],
|
32
32
|
**kwargs: Unpack[ValidateKwargs[Dict[str, float]]],
|
33
|
-
) ->
|
33
|
+
) -> Dict[str, float] | List[Dict[str, float]] | List[Optional[Dict[str, float]]] | None:
|
34
34
|
"""Rate a given string based on a rating manual and score range.
|
35
35
|
|
36
36
|
Args:
|
@@ -42,45 +42,49 @@ class Rating(LLMUsage):
|
|
42
42
|
Returns:
|
43
43
|
Dict[str, float]: A dictionary with the ratings for each dimension.
|
44
44
|
"""
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
"to_rate": to_rate,
|
62
|
-
"min_score": score_range[0],
|
63
|
-
"max_score": score_range[1],
|
64
|
-
"rating_manual": rating_manual,
|
65
|
-
},
|
45
|
+
min_score, max_score = score_range
|
46
|
+
tip = (max_score - min_score) / 9
|
47
|
+
|
48
|
+
model = create_model( # pyright: ignore [reportCallIssue]
|
49
|
+
"RatingResult",
|
50
|
+
__base__=ProposedAble,
|
51
|
+
__doc__=f"The rating result contains the scores against each criterion, with min_score={min_score} and max_score={max_score}.",
|
52
|
+
**{ # pyright: ignore [reportArgumentType]
|
53
|
+
criterion: (
|
54
|
+
float,
|
55
|
+
Field(
|
56
|
+
ge=min_score,
|
57
|
+
le=max_score,
|
58
|
+
description=desc,
|
59
|
+
examples=[round(min_score + tip * i, 2) for i in range(10)],
|
60
|
+
),
|
66
61
|
)
|
62
|
+
for criterion, desc in rating_manual.items()
|
63
|
+
},
|
64
|
+
)
|
65
|
+
|
66
|
+
res = await self.propose(
|
67
|
+
model,
|
68
|
+
TEMPLATE_MANAGER.render_template(
|
69
|
+
configs.templates.rate_fine_grind_template,
|
70
|
+
{"to_rate": to_rate, "min_score": min_score, "max_score": max_score},
|
67
71
|
)
|
68
72
|
if isinstance(to_rate, str)
|
69
73
|
else [
|
70
74
|
TEMPLATE_MANAGER.render_template(
|
71
75
|
configs.templates.rate_fine_grind_template,
|
72
|
-
{
|
73
|
-
"to_rate": item,
|
74
|
-
"min_score": score_range[0],
|
75
|
-
"max_score": score_range[1],
|
76
|
-
"rating_manual": rating_manual,
|
77
|
-
},
|
76
|
+
{"to_rate": t, "min_score": min_score, "max_score": max_score},
|
78
77
|
)
|
79
|
-
for
|
78
|
+
for t in to_rate
|
80
79
|
],
|
81
|
-
|
82
|
-
**kwargs,
|
80
|
+
**override_kwargs(kwargs, default=None),
|
83
81
|
)
|
82
|
+
default = kwargs.get("default")
|
83
|
+
if isinstance(res, list):
|
84
|
+
return [r.model_dump() if r else default for r in res]
|
85
|
+
if res is None:
|
86
|
+
return default
|
87
|
+
return res.model_dump()
|
84
88
|
|
85
89
|
@overload
|
86
90
|
async def rate(
|
@@ -112,7 +116,7 @@ class Rating(LLMUsage):
|
|
112
116
|
manual: Optional[Dict[str, str]] = None,
|
113
117
|
score_range: Tuple[float, float] = (0.0, 1.0),
|
114
118
|
**kwargs: Unpack[ValidateKwargs],
|
115
|
-
) ->
|
119
|
+
) -> Dict[str, float] | List[Dict[str, float]] | List[Optional[Dict[str, float]]] | None:
|
116
120
|
"""Rate a given string or a sequence of strings based on a topic, criteria, and score range.
|
117
121
|
|
118
122
|
Args:
|
@@ -133,7 +137,7 @@ class Rating(LLMUsage):
|
|
133
137
|
or dict(zip(criteria, criteria, strict=True))
|
134
138
|
)
|
135
139
|
|
136
|
-
return await self.rate_fine_grind(to_rate, manual, score_range, **kwargs)
|
140
|
+
return await self.rate_fine_grind(to_rate, manual, score_range, **fallback_kwargs(kwargs, co_extractor={}))
|
137
141
|
|
138
142
|
async def draft_rating_manual(
|
139
143
|
self, topic: str, criteria: Optional[Set[str]] = None, **kwargs: Unpack[ValidateKwargs[Dict[str, str]]]
|
@@ -244,7 +248,7 @@ class Rating(LLMUsage):
|
|
244
248
|
|
245
249
|
# extract reasons from the comparison of ordered pairs of extracted from examples
|
246
250
|
reasons = flatten(
|
247
|
-
await self.aask_validate(
|
251
|
+
await self.aask_validate( # pyright: ignore [reportArgumentType]
|
248
252
|
question=[
|
249
253
|
TEMPLATE_MANAGER.render_template(
|
250
254
|
configs.templates.extract_reasons_from_examples_template,
|
@@ -319,9 +323,11 @@ class Rating(LLMUsage):
|
|
319
323
|
validator=lambda resp: JsonCapture.validate_with(resp, target_type=float),
|
320
324
|
**kwargs,
|
321
325
|
)
|
326
|
+
if not all(relative_weights):
|
327
|
+
raise ValueError(f"found illegal weight: {relative_weights}")
|
322
328
|
weights = [1.0]
|
323
329
|
for rw in relative_weights:
|
324
|
-
weights.append(weights[-1] * rw)
|
330
|
+
weights.append(weights[-1] * rw) # pyright: ignore [reportOperatorIssue]
|
325
331
|
total = sum(weights)
|
326
332
|
return dict(zip(criteria_seq, [w / total for w in weights], strict=True))
|
327
333
|
|
fabricatio/config.py
CHANGED
@@ -44,7 +44,7 @@ class LLMConfig(BaseModel):
|
|
44
44
|
top_p (NonNegativeFloat): The top p of the LLM model. Controls diversity via nucleus sampling. Set to 0.35 as per request.
|
45
45
|
generation_count (PositiveInt): The number of generations to generate. Default is 1.
|
46
46
|
stream (bool): Whether to stream the LLM model's response. Default is False.
|
47
|
-
max_tokens (PositiveInt): The maximum number of tokens to generate.
|
47
|
+
max_tokens (PositiveInt): The maximum number of tokens to generate.
|
48
48
|
"""
|
49
49
|
|
50
50
|
model_config = ConfigDict(use_attribute_docstrings=True)
|
@@ -79,15 +79,17 @@ class LLMConfig(BaseModel):
|
|
79
79
|
"""Whether to stream the LLM model's response. Default is False."""
|
80
80
|
|
81
81
|
max_tokens: Optional[PositiveInt] = Field(default=None)
|
82
|
-
"""The maximum number of tokens to generate.
|
82
|
+
"""The maximum number of tokens to generate."""
|
83
83
|
|
84
84
|
rpm: Optional[PositiveInt] = Field(default=100)
|
85
85
|
"""The rate limit of the LLM model in requests per minute. None means not checked."""
|
86
86
|
|
87
87
|
tpm: Optional[PositiveInt] = Field(default=1000000)
|
88
88
|
"""The rate limit of the LLM model in tokens per minute. None means not checked."""
|
89
|
-
|
90
|
-
|
89
|
+
presence_penalty:Optional[PositiveFloat]=None
|
90
|
+
"""The presence penalty of the LLM model."""
|
91
|
+
frequency_penalty:Optional[PositiveFloat]=None
|
92
|
+
"""The frequency penalty of the LLM model."""
|
91
93
|
class EmbeddingConfig(BaseModel):
|
92
94
|
"""Embedding configuration class."""
|
93
95
|
|
fabricatio/constants.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
"""A module containing constants used throughout the library."""
|
2
|
+
from enum import StrEnum
|
3
|
+
|
4
|
+
|
5
|
+
class TaskStatus(StrEnum):
|
6
|
+
"""An enumeration representing the status of a task.
|
7
|
+
|
8
|
+
Attributes:
|
9
|
+
Pending: The task is pending.
|
10
|
+
Running: The task is currently running.
|
11
|
+
Finished: The task has been successfully completed.
|
12
|
+
Failed: The task has failed.
|
13
|
+
Cancelled: The task has been cancelled.
|
14
|
+
"""
|
15
|
+
|
16
|
+
Pending = "pending"
|
17
|
+
Running = "running"
|
18
|
+
Finished = "finished"
|
19
|
+
Failed = "failed"
|
20
|
+
Cancelled = "cancelled"
|
fabricatio/decorators.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from asyncio import iscoroutinefunction
|
4
4
|
from functools import wraps
|
5
|
+
from importlib.util import find_spec
|
5
6
|
from inspect import signature
|
6
7
|
from shutil import which
|
7
8
|
from types import ModuleType
|
@@ -209,3 +210,25 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
209
210
|
return result
|
210
211
|
|
211
212
|
return _wrapper
|
213
|
+
|
214
|
+
|
215
|
+
def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
216
|
+
"""Check if a package exists in the current environment.
|
217
|
+
|
218
|
+
Args:
|
219
|
+
package_name (str): The name of the package to check.
|
220
|
+
msg (str): The message to display if the package is not found.
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
bool: True if the package exists, False otherwise.
|
224
|
+
"""
|
225
|
+
|
226
|
+
def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
|
227
|
+
def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
228
|
+
if find_spec(package_name):
|
229
|
+
return func(*args, **kwargs)
|
230
|
+
raise RuntimeError(msg)
|
231
|
+
|
232
|
+
return _inner
|
233
|
+
|
234
|
+
return _wrapper
|
fabricatio/fs/readers.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
"""Filesystem readers for Fabricatio."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Dict
|
4
|
+
from typing import Dict, List, Tuple
|
5
5
|
|
6
6
|
import orjson
|
7
|
+
import regex
|
7
8
|
from magika import Magika
|
8
9
|
|
9
10
|
from fabricatio.config import configs
|
@@ -44,3 +45,21 @@ def safe_json_read(path: Path | str) -> Dict:
|
|
44
45
|
except (orjson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
|
45
46
|
logger.error(f"Failed to read file {path}: {e!s}")
|
46
47
|
return {}
|
48
|
+
|
49
|
+
|
50
|
+
def extract_sections(string: str, level: int, section_char: str = "#") -> List[Tuple[str, str]]:
|
51
|
+
"""Extract sections from markdown-style text by header level.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
string (str): Input text to parse
|
55
|
+
level (int): Header level (e.g., 1 for '#', 2 for '##')
|
56
|
+
section_char (str, optional): The character used for headers (default: '#')
|
57
|
+
|
58
|
+
Returns:
|
59
|
+
List[Tuple[str, str]]: List of (header_text, section_content) tuples
|
60
|
+
"""
|
61
|
+
return regex.findall(
|
62
|
+
r"^%s{%d}\s+(.+?)\n((?:(?!^%s{%d}\s).|\n)*)" % (section_char, level, section_char, level),
|
63
|
+
string,
|
64
|
+
regex.MULTILINE,
|
65
|
+
)
|