fabricatio 0.2.6.dev5__cp39-cp39-win_amd64.whl → 0.2.6.dev7__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/_rust.cp39-win_amd64.pyd +0 -0
- fabricatio/actions/article.py +7 -4
- fabricatio/capabilities/rag.py +6 -6
- fabricatio/config.py +5 -5
- fabricatio/models/action.py +1 -2
- fabricatio/models/extra.py +551 -64
- fabricatio/models/kwargs_types.py +17 -6
- fabricatio/models/role.py +1 -2
- fabricatio/models/usages.py +58 -27
- fabricatio/models/utils.py +21 -0
- fabricatio/parser.py +1 -0
- {fabricatio-0.2.6.dev5.data → fabricatio-0.2.6.dev7.data}/scripts/tdown.exe +0 -0
- {fabricatio-0.2.6.dev5.dist-info → fabricatio-0.2.6.dev7.dist-info}/METADATA +1 -1
- {fabricatio-0.2.6.dev5.dist-info → fabricatio-0.2.6.dev7.dist-info}/RECORD +16 -17
- {fabricatio-0.2.6.dev5.dist-info → fabricatio-0.2.6.dev7.dist-info}/WHEEL +1 -1
- fabricatio/capabilities/covalidate.py +0 -160
- {fabricatio-0.2.6.dev5.dist-info → fabricatio-0.2.6.dev7.dist-info}/licenses/LICENSE +0 -0
@@ -1,19 +1,29 @@
|
|
1
1
|
"""This module contains the types for the keyword arguments of the methods in the models module."""
|
2
2
|
|
3
|
+
from importlib.util import find_spec
|
3
4
|
from typing import Any, Required, TypedDict
|
4
5
|
|
5
6
|
from litellm.caching.caching import CacheMode
|
6
7
|
from litellm.types.caching import CachingSupportedCallTypes
|
7
8
|
|
9
|
+
if find_spec("pymilvus"):
|
10
|
+
from pymilvus import CollectionSchema
|
11
|
+
from pymilvus.milvus_client import IndexParams
|
8
12
|
|
9
|
-
class
|
10
|
-
|
13
|
+
class CollectionConfigKwargs(TypedDict, total=False):
|
14
|
+
"""Configuration parameters for a vector collection.
|
11
15
|
|
12
|
-
|
13
|
-
|
16
|
+
These arguments are typically used when configuring connections to vector databases.
|
17
|
+
"""
|
14
18
|
|
15
|
-
|
16
|
-
|
19
|
+
dimension: int | None
|
20
|
+
primary_field_name: str
|
21
|
+
id_type: str
|
22
|
+
vector_field_name: str
|
23
|
+
metric_type: str
|
24
|
+
timeout: float | None
|
25
|
+
schema: CollectionSchema | None
|
26
|
+
index_params: IndexParams | None
|
17
27
|
|
18
28
|
|
19
29
|
class FetchKwargs(TypedDict, total=False):
|
@@ -81,6 +91,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
|
|
81
91
|
|
82
92
|
default: T
|
83
93
|
max_validations: int
|
94
|
+
co_extractor: GenerateKwargs
|
84
95
|
|
85
96
|
|
86
97
|
# noinspection PyTypedDict
|
fabricatio/models/role.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
from typing import Any, Self, Set
|
4
4
|
|
5
5
|
from fabricatio.capabilities.correct import Correct
|
6
|
-
from fabricatio.capabilities.covalidate import CoValidate
|
7
6
|
from fabricatio.capabilities.task import HandleTask, ProposeTask
|
8
7
|
from fabricatio.core import env
|
9
8
|
from fabricatio.journal import logger
|
@@ -13,7 +12,7 @@ from fabricatio.models.tool import ToolBox
|
|
13
12
|
from pydantic import Field
|
14
13
|
|
15
14
|
|
16
|
-
class Role(ProposeTask, HandleTask, Correct
|
15
|
+
class Role(ProposeTask, HandleTask, Correct):
|
17
16
|
"""Class that represents a role with a registry of events and workflows.
|
18
17
|
|
19
18
|
A Role serves as a container for workflows, managing their registration to events
|
fabricatio/models/usages.py
CHANGED
@@ -12,9 +12,9 @@ from fabricatio.models.generic import ScopedConfig, WithBriefing
|
|
12
12
|
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
|
13
13
|
from fabricatio.models.task import Task
|
14
14
|
from fabricatio.models.tool import Tool, ToolBox
|
15
|
-
from fabricatio.models.utils import Messages
|
15
|
+
from fabricatio.models.utils import Messages, ok
|
16
16
|
from fabricatio.parser import GenericCapture, JsonCapture
|
17
|
-
from litellm import Router, stream_chunk_builder
|
17
|
+
from litellm import Router, stream_chunk_builder # pyright: ignore [reportPrivateImportUsage]
|
18
18
|
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
|
19
19
|
from litellm.types.utils import (
|
20
20
|
Choices,
|
@@ -70,14 +70,22 @@ class LLMUsage(ScopedConfig):
|
|
70
70
|
"""
|
71
71
|
# Call the underlying asynchronous completion function with the provided and default parameters
|
72
72
|
# noinspection PyTypeChecker,PydanticTypeChecker
|
73
|
-
|
74
73
|
return await self._deploy(
|
75
74
|
Deployment(
|
76
|
-
model_name=(
|
75
|
+
model_name=(
|
76
|
+
m_name := ok(
|
77
|
+
kwargs.get("model") or self.llm_model or configs.llm.model, "model name is not set at any place"
|
78
|
+
)
|
79
|
+
), # pyright: ignore [reportCallIssue]
|
77
80
|
litellm_params=(
|
78
81
|
p := LiteLLM_Params(
|
79
|
-
api_key=(
|
80
|
-
|
82
|
+
api_key=ok(
|
83
|
+
self.llm_api_key or configs.llm.api_key, "llm api key is not set at any place"
|
84
|
+
).get_secret_value(),
|
85
|
+
api_base=ok(
|
86
|
+
self.llm_api_endpoint or configs.llm.api_endpoint,
|
87
|
+
"llm api endpoint is not set at any place",
|
88
|
+
).unicode_string(),
|
81
89
|
model=m_name,
|
82
90
|
tpm=self.llm_tpm or configs.llm.tpm,
|
83
91
|
rpm=self.llm_rpm or configs.llm.rpm,
|
@@ -88,14 +96,14 @@ class LLMUsage(ScopedConfig):
|
|
88
96
|
model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
|
89
97
|
)
|
90
98
|
).acompletion(
|
91
|
-
messages=messages,
|
99
|
+
messages=messages, # pyright: ignore [reportArgumentType]
|
92
100
|
n=n or self.llm_generation_count or configs.llm.generation_count,
|
93
101
|
model=m_name,
|
94
102
|
temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
|
95
103
|
stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
|
96
104
|
top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
|
97
105
|
max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
|
98
|
-
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
106
|
+
stream=ok(kwargs.get("stream") or self.llm_stream or configs.llm.stream, "stream is not set at any place"),
|
99
107
|
cache={
|
100
108
|
"no-cache": kwargs.get("no_cache"),
|
101
109
|
"no-store": kwargs.get("no_store"),
|
@@ -196,15 +204,15 @@ class LLMUsage(ScopedConfig):
|
|
196
204
|
for q, sm in zip(q_seq, sm_seq, strict=True)
|
197
205
|
]
|
198
206
|
)
|
199
|
-
return [r[0].message.content for r in res]
|
207
|
+
return [r[0].message.content for r in res] # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
|
200
208
|
case (list(q_seq), str(sm)):
|
201
209
|
res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
|
202
|
-
return [r[0].message.content for r in res]
|
210
|
+
return [r[0].message.content for r in res] # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
|
203
211
|
case (str(q), list(sm_seq)):
|
204
212
|
res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
|
205
|
-
return [r[0].message.content for r in res]
|
213
|
+
return [r[0].message.content for r in res] # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
|
206
214
|
case (str(q), str(sm)):
|
207
|
-
return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
|
215
|
+
return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
|
208
216
|
case _:
|
209
217
|
raise RuntimeError("Should not reach here.")
|
210
218
|
|
@@ -215,6 +223,7 @@ class LLMUsage(ScopedConfig):
|
|
215
223
|
validator: Callable[[str], T | None],
|
216
224
|
default: T = ...,
|
217
225
|
max_validations: PositiveInt = 2,
|
226
|
+
co_extractor: Optional[GenerateKwargs] = None,
|
218
227
|
**kwargs: Unpack[GenerateKwargs],
|
219
228
|
) -> T: ...
|
220
229
|
@overload
|
@@ -224,6 +233,7 @@ class LLMUsage(ScopedConfig):
|
|
224
233
|
validator: Callable[[str], T | None],
|
225
234
|
default: T = ...,
|
226
235
|
max_validations: PositiveInt = 2,
|
236
|
+
co_extractor: Optional[GenerateKwargs] = None,
|
227
237
|
**kwargs: Unpack[GenerateKwargs],
|
228
238
|
) -> List[T]: ...
|
229
239
|
@overload
|
@@ -233,6 +243,7 @@ class LLMUsage(ScopedConfig):
|
|
233
243
|
validator: Callable[[str], T | None],
|
234
244
|
default: None = None,
|
235
245
|
max_validations: PositiveInt = 2,
|
246
|
+
co_extractor: Optional[GenerateKwargs] = None,
|
236
247
|
**kwargs: Unpack[GenerateKwargs],
|
237
248
|
) -> Optional[T]: ...
|
238
249
|
|
@@ -243,6 +254,7 @@ class LLMUsage(ScopedConfig):
|
|
243
254
|
validator: Callable[[str], T | None],
|
244
255
|
default: None = None,
|
245
256
|
max_validations: PositiveInt = 2,
|
257
|
+
co_extractor: Optional[GenerateKwargs] = None,
|
246
258
|
**kwargs: Unpack[GenerateKwargs],
|
247
259
|
) -> List[Optional[T]]: ...
|
248
260
|
|
@@ -252,6 +264,7 @@ class LLMUsage(ScopedConfig):
|
|
252
264
|
validator: Callable[[str], T | None],
|
253
265
|
default: Optional[T] = None,
|
254
266
|
max_validations: PositiveInt = 2,
|
267
|
+
co_extractor: Optional[GenerateKwargs] = None,
|
255
268
|
**kwargs: Unpack[GenerateKwargs],
|
256
269
|
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
257
270
|
"""Asynchronously asks a question and validates the response using a given validator.
|
@@ -261,6 +274,7 @@ class LLMUsage(ScopedConfig):
|
|
261
274
|
validator (Callable[[str], T | None]): A function to validate the response.
|
262
275
|
default (T | None): Default value to return if validation fails. Defaults to None.
|
263
276
|
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
|
277
|
+
co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
|
264
278
|
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
265
279
|
|
266
280
|
Returns:
|
@@ -274,6 +288,23 @@ class LLMUsage(ScopedConfig):
|
|
274
288
|
if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
|
275
289
|
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
276
290
|
return validated
|
291
|
+
|
292
|
+
if co_extractor and (
|
293
|
+
(
|
294
|
+
co_response := await self.aask(
|
295
|
+
question=(
|
296
|
+
TEMPLATE_MANAGER.render_template(
|
297
|
+
configs.templates.co_validation_template,
|
298
|
+
{"original_q": q, "original_a": response},
|
299
|
+
)
|
300
|
+
),
|
301
|
+
**co_extractor,
|
302
|
+
)
|
303
|
+
)
|
304
|
+
and (validated := validator(co_response))
|
305
|
+
):
|
306
|
+
logger.debug(f"Successfully validated the co-response at {lap}th attempt.")
|
307
|
+
return validated
|
277
308
|
except Exception as e: # noqa: BLE001
|
278
309
|
logger.error(f"Error during validation: \n{e}")
|
279
310
|
break
|
@@ -291,7 +322,7 @@ class LLMUsage(ScopedConfig):
|
|
291
322
|
|
292
323
|
async def aliststr(
|
293
324
|
self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
|
294
|
-
) -> List[str]:
|
325
|
+
) -> Optional[List[str]]:
|
295
326
|
"""Asynchronously generates a list of strings based on a given requirement.
|
296
327
|
|
297
328
|
Args:
|
@@ -311,7 +342,7 @@ class LLMUsage(ScopedConfig):
|
|
311
342
|
**kwargs,
|
312
343
|
)
|
313
344
|
|
314
|
-
async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> List[str]:
|
345
|
+
async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> Optional[List[str]]:
|
315
346
|
"""Asynchronously generates a list of strings based on a given requirement.
|
316
347
|
|
317
348
|
Args:
|
@@ -339,7 +370,7 @@ class LLMUsage(ScopedConfig):
|
|
339
370
|
Returns:
|
340
371
|
str: The validated response as a single string.
|
341
372
|
"""
|
342
|
-
return (
|
373
|
+
return ok(
|
343
374
|
await self.apathstr(
|
344
375
|
requirement,
|
345
376
|
k=1,
|
@@ -347,7 +378,7 @@ class LLMUsage(ScopedConfig):
|
|
347
378
|
)
|
348
379
|
).pop()
|
349
380
|
|
350
|
-
async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> str:
|
381
|
+
async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> Optional[str]:
|
351
382
|
"""Asynchronously generates a generic string based on a given requirement.
|
352
383
|
|
353
384
|
Args:
|
@@ -357,7 +388,7 @@ class LLMUsage(ScopedConfig):
|
|
357
388
|
Returns:
|
358
389
|
str: The generated string.
|
359
390
|
"""
|
360
|
-
return await self.aask_validate(
|
391
|
+
return await self.aask_validate( # pyright: ignore [reportReturnType]
|
361
392
|
TEMPLATE_MANAGER.render_template(
|
362
393
|
configs.templates.generic_string_template,
|
363
394
|
{"requirement": requirement, "language": GenericCapture.capture_type},
|
@@ -372,7 +403,7 @@ class LLMUsage(ScopedConfig):
|
|
372
403
|
choices: List[T],
|
373
404
|
k: NonNegativeInt = 0,
|
374
405
|
**kwargs: Unpack[ValidateKwargs[List[T]]],
|
375
|
-
) -> List[T]:
|
406
|
+
) -> Optional[List[T]]:
|
376
407
|
"""Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
|
377
408
|
|
378
409
|
Args:
|
@@ -437,13 +468,13 @@ class LLMUsage(ScopedConfig):
|
|
437
468
|
Raises:
|
438
469
|
ValueError: If validation fails after maximum attempts or if no valid selection is made.
|
439
470
|
"""
|
440
|
-
return (
|
471
|
+
return ok(
|
441
472
|
await self.achoose(
|
442
473
|
instruction=instruction,
|
443
474
|
choices=choices,
|
444
475
|
k=1,
|
445
476
|
**kwargs,
|
446
|
-
)
|
477
|
+
),
|
447
478
|
)[0]
|
448
479
|
|
449
480
|
async def ajudge(
|
@@ -500,7 +531,7 @@ class EmbeddingUsage(LLMUsage):
|
|
500
531
|
"""
|
501
532
|
# check seq length
|
502
533
|
max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
|
503
|
-
if any(len(t) > max_len for t in input_text):
|
534
|
+
if max_len and any(len(t) > max_len for t in input_text):
|
504
535
|
logger.error(err := f"Input text exceeds maximum sequence length {max_len}.")
|
505
536
|
raise ValueError(err)
|
506
537
|
|
@@ -514,10 +545,10 @@ class EmbeddingUsage(LLMUsage):
|
|
514
545
|
or configs.embedding.timeout
|
515
546
|
or self.llm_timeout
|
516
547
|
or configs.llm.timeout,
|
517
|
-
api_key=(
|
548
|
+
api_key=ok(
|
518
549
|
self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
|
519
550
|
).get_secret_value(),
|
520
|
-
api_base=(
|
551
|
+
api_base=ok(
|
521
552
|
self.embedding_api_endpoint
|
522
553
|
or configs.embedding.api_endpoint
|
523
554
|
or self.llm_api_endpoint
|
@@ -566,7 +597,7 @@ class ToolBoxUsage(LLMUsage):
|
|
566
597
|
self,
|
567
598
|
task: Task,
|
568
599
|
**kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
|
569
|
-
) -> List[ToolBox]:
|
600
|
+
) -> Optional[List[ToolBox]]:
|
570
601
|
"""Asynchronously executes a multi-choice decision-making process to choose toolboxes.
|
571
602
|
|
572
603
|
Args:
|
@@ -591,7 +622,7 @@ class ToolBoxUsage(LLMUsage):
|
|
591
622
|
task: Task,
|
592
623
|
toolbox: ToolBox,
|
593
624
|
**kwargs: Unpack[ChooseKwargs[List[Tool]]],
|
594
|
-
) -> List[Tool]:
|
625
|
+
) -> Optional[List[Tool]]:
|
595
626
|
"""Asynchronously executes a multi-choice decision-making process to choose tools.
|
596
627
|
|
597
628
|
Args:
|
@@ -631,11 +662,11 @@ class ToolBoxUsage(LLMUsage):
|
|
631
662
|
tool_choose_kwargs = tool_choose_kwargs or {}
|
632
663
|
|
633
664
|
# Choose the toolboxes
|
634
|
-
chosen_toolboxes = await self.choose_toolboxes(task, **box_choose_kwargs)
|
665
|
+
chosen_toolboxes = ok(await self.choose_toolboxes(task, **box_choose_kwargs))
|
635
666
|
# Choose the tools
|
636
667
|
chosen_tools = []
|
637
668
|
for toolbox in chosen_toolboxes:
|
638
|
-
chosen_tools.extend(await self.choose_tools(task, toolbox, **tool_choose_kwargs))
|
669
|
+
chosen_tools.extend(ok(await self.choose_tools(task, toolbox, **tool_choose_kwargs)))
|
639
670
|
return chosen_tools
|
640
671
|
|
641
672
|
async def gather_tools(self, task: Task, **kwargs: Unpack[ChooseKwargs]) -> List[Tool]:
|
fabricatio/models/utils.py
CHANGED
@@ -165,3 +165,24 @@ async def ask_edit(
|
|
165
165
|
if edited:
|
166
166
|
res.append(edited)
|
167
167
|
return res
|
168
|
+
|
169
|
+
|
170
|
+
def override_kwargs[T](kwargs: Dict[str, T], **overrides) -> Dict[str, T]:
|
171
|
+
"""Override the values in kwargs with the provided overrides."""
|
172
|
+
kwargs.update({k: v for k, v in overrides.items() if v is not None})
|
173
|
+
return kwargs
|
174
|
+
|
175
|
+
|
176
|
+
def ok[T](val: Optional[T], msg:str="Value is None") -> T:
|
177
|
+
"""Check if a value is None and raise a ValueError with the provided message if it is.
|
178
|
+
|
179
|
+
Args:
|
180
|
+
val: The value to check.
|
181
|
+
msg: The message to include in the ValueError if val is None.
|
182
|
+
|
183
|
+
Returns:
|
184
|
+
T: The value if it is not None.
|
185
|
+
"""
|
186
|
+
if val is None:
|
187
|
+
raise ValueError(msg)
|
188
|
+
return val
|
fabricatio/parser.py
CHANGED
Binary file
|
@@ -1,34 +1,33 @@
|
|
1
|
-
fabricatio-0.2.6.
|
2
|
-
fabricatio-0.2.6.
|
3
|
-
fabricatio-0.2.6.
|
4
|
-
fabricatio/actions/article.py,sha256=
|
1
|
+
fabricatio-0.2.6.dev7.dist-info/METADATA,sha256=KJKAdxeQyQdtZjg8fyqf5BCnAZA4EJkAAMaNsDgXYCQ,14085
|
2
|
+
fabricatio-0.2.6.dev7.dist-info/WHEEL,sha256=mDFV3bKFgwlxLHvOsPqpR9up9dUKYzsUQNKBdkW5c08,94
|
3
|
+
fabricatio-0.2.6.dev7.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
|
4
|
+
fabricatio/actions/article.py,sha256=LfIWnbFYB9e3Bq2YDPk1geWDbJTq7zCitLtpFhAhYHM,4563
|
5
5
|
fabricatio/actions/output.py,sha256=KSSLvEvXsA10ACN2mbqGo98QwKLVUAoMUJNKYk6HhGc,645
|
6
6
|
fabricatio/actions/rag.py,sha256=GpT7YlqOYznZyaT-6Y84_33HtZGT-5s71ZK8iroQA9g,813
|
7
7
|
fabricatio/capabilities/correct.py,sha256=0BYhjo9WrLwKsXQR8bTPvdQITbrMs7RX1xpzhuQt_yY,5222
|
8
|
-
fabricatio/capabilities/covalidate.py,sha256=zl0b0Z8ZC3XkpzISIZJY4CZZAdVsx4qd1rdTLrFHFz8,6621
|
9
8
|
fabricatio/capabilities/propose.py,sha256=y3kge5g6bb8HYuV8e9h4MdqOMTlsfAIZpqE_cagWPTY,1593
|
10
|
-
fabricatio/capabilities/rag.py,sha256=
|
9
|
+
fabricatio/capabilities/rag.py,sha256=R1yUD675CDEmGakXb2nzEzZe0vjN7edMS7VHtPOAriU,15771
|
11
10
|
fabricatio/capabilities/rating.py,sha256=R9otyZVE2E3kKxrOCTZMeesBCPbC-fSb7bXgZPMQzfU,14406
|
12
11
|
fabricatio/capabilities/review.py,sha256=XYzpSnFCT9HS2XytQT8HDgV4SjXehexoJgucZFMx6P8,11102
|
13
12
|
fabricatio/capabilities/task.py,sha256=MBiDyC3oHwTbTiLiGyqUEVfVGSN42lU03ndeapTpyjQ,4609
|
14
|
-
fabricatio/config.py,sha256=
|
13
|
+
fabricatio/config.py,sha256=f3B_Mwhc4mGEdECG4EqcxGww0Eu7KhCAwPXXJlHf1a8,16635
|
15
14
|
fabricatio/core.py,sha256=VQ_JKgUGIy2gZ8xsTBZCdr_IP7wC5aPg0_bsOmjQ588,6458
|
16
15
|
fabricatio/decorators.py,sha256=uzsP4tFKQNjDHBkofsjjoJA0IUAaYOtt6YVedoyOqlo,6551
|
17
16
|
fabricatio/fs/curd.py,sha256=N6l2MncjrFfnXBRtteRouXp5Rjy8EAKC_i29_G-zz98,4618
|
18
17
|
fabricatio/fs/readers.py,sha256=EZKN_AZdrp8DggJECP53QHw3uHeSDf-AwCAA_V7fNKU,1202
|
19
18
|
fabricatio/fs/__init__.py,sha256=PCf0s_9KDjVfNw7AfPoJzGt3jMq4gJOfbcT4pb0D0ZY,588
|
20
19
|
fabricatio/journal.py,sha256=stnEP88aUBA_GmU9gfTF2EZI8FS2OyMLGaMSTgK4QgA,476
|
21
|
-
fabricatio/models/action.py,sha256
|
20
|
+
fabricatio/models/action.py,sha256=dSmwIrW68JhCrkhWENRgTLIQ-0grVA4408QAUy23HZo,8210
|
22
21
|
fabricatio/models/events.py,sha256=QvlnS8FEELg6KNabcJMeh2GV_y0ZBzKOPphcteKYWYU,4183
|
23
|
-
fabricatio/models/extra.py,sha256=
|
22
|
+
fabricatio/models/extra.py,sha256=oPCrh80u-O5XoFMVvZ6D6SVpSSW0zkxw4zfaTeK_wLU,26263
|
24
23
|
fabricatio/models/generic.py,sha256=IdPJMf3qxZFq8yqd6OuAYKfCM0wBlJkozgxvxQZVEEc,14025
|
25
|
-
fabricatio/models/kwargs_types.py,sha256=
|
26
|
-
fabricatio/models/role.py,sha256=
|
24
|
+
fabricatio/models/kwargs_types.py,sha256=H6DI3Jdben-FER_kx7owiRzmbSFKuu0sFjCADA1LJB0,5008
|
25
|
+
fabricatio/models/role.py,sha256=mmQbJ6GKr2Gx3wtjEz8d-vYoXs09ffcEkT_eCXaDd3E,2782
|
27
26
|
fabricatio/models/task.py,sha256=8NaR7ojQWyM740EDTqt9stwHKdrD6axCRpLKo0QzS-I,10492
|
28
27
|
fabricatio/models/tool.py,sha256=4b-v4WIC_LuLOKzzXL9bvKXr8vmGZ8O2uAFv5-1KRA0,7052
|
29
|
-
fabricatio/models/usages.py,sha256
|
30
|
-
fabricatio/models/utils.py,sha256=
|
31
|
-
fabricatio/parser.py,sha256=
|
28
|
+
fabricatio/models/usages.py,sha256=-689ssQ5F1SmxDToDHbv0EH8YaPTjhkn14l_M6Aer-M,30859
|
29
|
+
fabricatio/models/utils.py,sha256=3HW0tM6WwOK8g14tnIzVWTXzIRLHjMKPjjSl9pMRWkw,5668
|
30
|
+
fabricatio/parser.py,sha256=9Jzw-yV6uKbFvf6sPna-XHdziVGVBZWvPctgX_6ODL8,6251
|
32
31
|
fabricatio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
32
|
fabricatio/toolboxes/arithmetic.py,sha256=WLqhY-Pikv11Y_0SGajwZx3WhsLNpHKf9drzAqOf_nY,1369
|
34
33
|
fabricatio/toolboxes/fs.py,sha256=l4L1CVxJmjw9Ld2XUpIlWfV0_Fu_2Og6d3E13I-S4aE,736
|
@@ -38,6 +37,6 @@ fabricatio/workflows/rag.py,sha256=-YYp2tlE9Vtfgpg6ROpu6QVO8j8yVSPa6yDzlN3qVxs,5
|
|
38
37
|
fabricatio/_rust.pyi,sha256=eawBfpyGrB-JtOh4I6RSbjFSq83SSl-0syBeZ-g8270,3491
|
39
38
|
fabricatio/_rust_instances.py,sha256=2GwF8aVfYNemRI2feBzH1CZfBGno-XJJE5imJokGEYw,314
|
40
39
|
fabricatio/__init__.py,sha256=SzBYsRhZeL77jLtfJEjmoHOSwHwUGyvMATX6xfndLDM,1135
|
41
|
-
fabricatio/_rust.cp39-win_amd64.pyd,sha256=
|
42
|
-
fabricatio-0.2.6.
|
43
|
-
fabricatio-0.2.6.
|
40
|
+
fabricatio/_rust.cp39-win_amd64.pyd,sha256=GvYOGn9Xya6YMX-nhmqv-w908ndgc2HSinAYMkhypKo,1826304
|
41
|
+
fabricatio-0.2.6.dev7.data/scripts/tdown.exe,sha256=5mZx7mp19U-nnWHwoZTyRmJun2iR77nar1wab1j_Jj8,3397632
|
42
|
+
fabricatio-0.2.6.dev7.dist-info/RECORD,,
|
@@ -1,160 +0,0 @@
|
|
1
|
-
"""Co-validation capability for LLMs."""
|
2
|
-
|
3
|
-
from asyncio import gather
|
4
|
-
from typing import Callable, List, Optional, Union, Unpack, overload
|
5
|
-
|
6
|
-
from fabricatio import TEMPLATE_MANAGER
|
7
|
-
from fabricatio.config import configs
|
8
|
-
from fabricatio.journal import logger
|
9
|
-
from fabricatio.models.kwargs_types import GenerateKwargs
|
10
|
-
from fabricatio.models.usages import LLMUsage
|
11
|
-
|
12
|
-
|
13
|
-
class CoValidate(LLMUsage):
|
14
|
-
"""Class that represents a co-validation capability using multiple LLMs.
|
15
|
-
|
16
|
-
This class provides methods to validate responses by attempting multiple approaches:
|
17
|
-
1. Using the primary LLM to generate a response
|
18
|
-
2. Using a secondary (co-) model to refine responses that fail validation
|
19
|
-
3. Trying multiple times if needed
|
20
|
-
"""
|
21
|
-
|
22
|
-
@overload
|
23
|
-
async def aask_covalidate[T](
|
24
|
-
self,
|
25
|
-
question: str,
|
26
|
-
validator: Callable[[str], T | None],
|
27
|
-
co_model: Optional[str] = None,
|
28
|
-
co_temperature: Optional[float] = None,
|
29
|
-
co_top_p: Optional[float] = None,
|
30
|
-
co_max_tokens: Optional[int] = None,
|
31
|
-
max_validations: int = 2,
|
32
|
-
default: None = None,
|
33
|
-
**kwargs: Unpack[GenerateKwargs],
|
34
|
-
) -> T | None: ...
|
35
|
-
|
36
|
-
@overload
|
37
|
-
async def aask_covalidate[T](
|
38
|
-
self,
|
39
|
-
question: str,
|
40
|
-
validator: Callable[[str], T | None],
|
41
|
-
co_model: Optional[str] = None,
|
42
|
-
co_temperature: Optional[float] = None,
|
43
|
-
co_top_p: Optional[float] = None,
|
44
|
-
co_max_tokens: Optional[int] = None,
|
45
|
-
max_validations: int = 2,
|
46
|
-
default: T = ...,
|
47
|
-
**kwargs: Unpack[GenerateKwargs],
|
48
|
-
) -> T: ...
|
49
|
-
|
50
|
-
@overload
|
51
|
-
async def aask_covalidate[T](
|
52
|
-
self,
|
53
|
-
question: List[str],
|
54
|
-
validator: Callable[[str], T | None],
|
55
|
-
co_model: Optional[str] = None,
|
56
|
-
co_temperature: Optional[float] = None,
|
57
|
-
co_top_p: Optional[float] = None,
|
58
|
-
co_max_tokens: Optional[int] = None,
|
59
|
-
max_validations: int = 2,
|
60
|
-
default: None = None,
|
61
|
-
**kwargs: Unpack[GenerateKwargs],
|
62
|
-
) -> List[T | None]: ...
|
63
|
-
|
64
|
-
@overload
|
65
|
-
async def aask_covalidate[T](
|
66
|
-
self,
|
67
|
-
question: List[str],
|
68
|
-
validator: Callable[[str], T | None],
|
69
|
-
co_model: Optional[str] = None,
|
70
|
-
co_temperature: Optional[float] = None,
|
71
|
-
co_top_p: Optional[float] = None,
|
72
|
-
co_max_tokens: Optional[int] = None,
|
73
|
-
max_validations: int = 2,
|
74
|
-
default: T = ...,
|
75
|
-
**kwargs: Unpack[GenerateKwargs],
|
76
|
-
) -> List[T]: ...
|
77
|
-
|
78
|
-
async def aask_covalidate[T](
|
79
|
-
self,
|
80
|
-
question: Union[str, List[str]],
|
81
|
-
validator: Callable[[str], T | None],
|
82
|
-
co_model: Optional[str] = None,
|
83
|
-
co_temperature: Optional[float] = None,
|
84
|
-
co_top_p: Optional[float] = None,
|
85
|
-
co_max_tokens: Optional[int] = None,
|
86
|
-
max_validations: int = 2,
|
87
|
-
default: Optional[T] = None,
|
88
|
-
**kwargs: Unpack[GenerateKwargs],
|
89
|
-
) -> Union[T | None, List[T | None]]:
|
90
|
-
"""Ask the LLM with co-validation to obtain a validated response.
|
91
|
-
|
92
|
-
This method attempts to generate a response that passes validation using two approaches:
|
93
|
-
1. First, it asks the primary LLM using the original question
|
94
|
-
2. If validation fails, it uses a secondary (co-) model with a template to improve the response
|
95
|
-
3. The process repeats up to max_validations times
|
96
|
-
|
97
|
-
Args:
|
98
|
-
question: String question or list of questions to ask
|
99
|
-
validator: Function that validates responses, returns result or None if invalid
|
100
|
-
co_model: Optional model name for the co-validator
|
101
|
-
co_temperature: Optional temperature setting for the co-validator
|
102
|
-
co_top_p: Optional top_p setting for the co-validator
|
103
|
-
co_max_tokens: Optional maximum tokens for the co-validator response
|
104
|
-
max_validations: Maximum number of validation attempts
|
105
|
-
default: Default value to return if validation fails
|
106
|
-
**kwargs: Additional keyword arguments passed to aask method
|
107
|
-
|
108
|
-
Returns:
|
109
|
-
The validated result (T) or default if validation fails.
|
110
|
-
If input is a list of questions, returns a list of results.
|
111
|
-
"""
|
112
|
-
|
113
|
-
async def validate_single_question(q: str) -> Optional[T]:
|
114
|
-
"""Process a single question with validation attempts."""
|
115
|
-
validation_kwargs = kwargs.copy()
|
116
|
-
|
117
|
-
for lap in range(max_validations):
|
118
|
-
try:
|
119
|
-
# First attempt: direct question to primary model
|
120
|
-
response = await self.aask(question=q, **validation_kwargs)
|
121
|
-
if response and (validated := validator(response)):
|
122
|
-
logger.debug(f"Successfully validated the primary response at {lap}th attempt.")
|
123
|
-
return validated
|
124
|
-
|
125
|
-
# Second attempt: use co-model with validation template
|
126
|
-
co_prompt = TEMPLATE_MANAGER.render_template(
|
127
|
-
configs.templates.co_validation_template,
|
128
|
-
{"original_q": q, "original_a": response},
|
129
|
-
)
|
130
|
-
co_response = await self.aask(
|
131
|
-
question=co_prompt,
|
132
|
-
model=co_model,
|
133
|
-
temperature=co_temperature,
|
134
|
-
top_p=co_top_p,
|
135
|
-
max_tokens=co_max_tokens,
|
136
|
-
)
|
137
|
-
|
138
|
-
if co_response and (validated := validator(co_response)):
|
139
|
-
logger.debug(f"Successfully validated the co-response at {lap}th attempt.")
|
140
|
-
return validated
|
141
|
-
|
142
|
-
except Exception as e: # noqa: BLE001
|
143
|
-
logger.error(f"Error during validation: \n{e}")
|
144
|
-
break
|
145
|
-
|
146
|
-
# Disable caching for subsequent attempts
|
147
|
-
if not validation_kwargs.get("no_cache"):
|
148
|
-
validation_kwargs["no_cache"] = True
|
149
|
-
logger.debug("Disabled cache for the next attempt")
|
150
|
-
|
151
|
-
if default is None:
|
152
|
-
logger.error(f"Failed to validate the response after {max_validations} attempts.")
|
153
|
-
return default
|
154
|
-
|
155
|
-
# Handle single question or list of questions
|
156
|
-
if isinstance(question, str):
|
157
|
-
return await validate_single_question(question)
|
158
|
-
|
159
|
-
# Process multiple questions in parallel
|
160
|
-
return await gather(*[validate_single_question(q) for q in question])
|
File without changes
|