fabricatio-0.2.6.dev5-cp39-cp39-win_amd64.whl → fabricatio-0.2.6.dev7-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/models/kwargs_types.py CHANGED
@@ -1,19 +1,29 @@
 """This module contains the types for the keyword arguments of the methods in the models module."""
 
+from importlib.util import find_spec
 from typing import Any, Required, TypedDict
 
 from litellm.caching.caching import CacheMode
 from litellm.types.caching import CachingSupportedCallTypes
 
+if find_spec("pymilvus"):
+    from pymilvus import CollectionSchema
+    from pymilvus.milvus_client import IndexParams
 
-class CollectionSimpleConfigKwargs(TypedDict, total=False):
-    """Configuration parameters for a vector collection.
+class CollectionConfigKwargs(TypedDict, total=False):
+    """Configuration parameters for a vector collection.
 
-    These arguments are typically used when configuring connections to vector databases.
-    """
+    These arguments are typically used when configuring connections to vector databases.
+    """
 
-    dimension: int | None
-    timeout: float
+    dimension: int | None
+    primary_field_name: str
+    id_type: str
+    vector_field_name: str
+    metric_type: str
+    timeout: float | None
+    schema: CollectionSchema | None
+    index_params: IndexParams | None
 
 
 class FetchKwargs(TypedDict, total=False):
@@ -81,6 +91,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
 
     default: T
     max_validations: int
+    co_extractor: GenerateKwargs
 
 
 # noinspection PyTypedDict
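Reviewer note: the widened CollectionConfigKwargs appears to mirror the keyword parameters of pymilvus' MilvusClient.create_collection. Below is a minimal sketch of how the new fields could flow through; the endpoint, collection name, and literal values are illustrative assumptions, not taken from this diff.

# Illustrative only: collect the expanded CollectionConfigKwargs fields and forward
# them to pymilvus. Every literal value and the URI below is a placeholder.
from pymilvus import MilvusClient

from fabricatio.models.kwargs_types import CollectionConfigKwargs

config: CollectionConfigKwargs = {
    "dimension": 1024,
    "primary_field_name": "id",
    "id_type": "int",
    "vector_field_name": "vector",
    "metric_type": "COSINE",
    "timeout": 30.0,
}

client = MilvusClient(uri="http://localhost:19530")  # placeholder endpoint
client.create_collection(collection_name="documents", **config)

Guarding the pymilvus imports with find_spec keeps the module importable when the optional vector-store dependency is absent; the CollectionSchema and IndexParams annotations only resolve when pymilvus is installed.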
fabricatio/models/role.py CHANGED
@@ -3,7 +3,6 @@
 from typing import Any, Self, Set
 
 from fabricatio.capabilities.correct import Correct
-from fabricatio.capabilities.covalidate import CoValidate
 from fabricatio.capabilities.task import HandleTask, ProposeTask
 from fabricatio.core import env
 from fabricatio.journal import logger
@@ -13,7 +12,7 @@ from fabricatio.models.tool import ToolBox
 from pydantic import Field
 
 
-class Role(ProposeTask, HandleTask, Correct, CoValidate):
+class Role(ProposeTask, HandleTask, Correct):
     """Class that represents a role with a registry of events and workflows.
 
     A Role serves as a container for workflows, managing their registration to events
fabricatio/models/usages.py CHANGED
@@ -12,9 +12,9 @@ from fabricatio.models.generic import ScopedConfig, WithBriefing
 from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
-from fabricatio.models.utils import Messages
+from fabricatio.models.utils import Messages, ok
 from fabricatio.parser import GenericCapture, JsonCapture
-from litellm import Router, stream_chunk_builder
+from litellm import Router, stream_chunk_builder  # pyright: ignore [reportPrivateImportUsage]
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
 from litellm.types.utils import (
     Choices,
@@ -70,14 +70,22 @@ class LLMUsage(ScopedConfig):
         """
         # Call the underlying asynchronous completion function with the provided and default parameters
         # noinspection PyTypeChecker,PydanticTypeChecker
-
         return await self._deploy(
             Deployment(
-                model_name=(m_name := kwargs.get("model") or self.llm_model or configs.llm.model),
+                model_name=(
+                    m_name := ok(
+                        kwargs.get("model") or self.llm_model or configs.llm.model, "model name is not set at any place"
+                    )
+                ),  # pyright: ignore [reportCallIssue]
                 litellm_params=(
                     p := LiteLLM_Params(
-                        api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
-                        api_base=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
+                        api_key=ok(
+                            self.llm_api_key or configs.llm.api_key, "llm api key is not set at any place"
+                        ).get_secret_value(),
+                        api_base=ok(
+                            self.llm_api_endpoint or configs.llm.api_endpoint,
+                            "llm api endpoint is not set at any place",
+                        ).unicode_string(),
                         model=m_name,
                         tpm=self.llm_tpm or configs.llm.tpm,
                         rpm=self.llm_rpm or configs.llm.rpm,
@@ -88,14 +96,14 @@ class LLMUsage(ScopedConfig):
                 model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
             )
         ).acompletion(
-            messages=messages,
+            messages=messages,  # pyright: ignore [reportArgumentType]
             n=n or self.llm_generation_count or configs.llm.generation_count,
             model=m_name,
             temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
             stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
             top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
             max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
-            stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
+            stream=ok(kwargs.get("stream") or self.llm_stream or configs.llm.stream, "stream is not set at any place"),
             cache={
                 "no-cache": kwargs.get("no_cache"),
                 "no-store": kwargs.get("no_store"),
@@ -196,15 +204,15 @@ class LLMUsage(ScopedConfig):
                         for q, sm in zip(q_seq, sm_seq, strict=True)
                     ]
                 )
-                return [r[0].message.content for r in res]
+                return [r[0].message.content for r in res]  # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
             case (list(q_seq), str(sm)):
                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
-                return [r[0].message.content for r in res]
+                return [r[0].message.content for r in res]  # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
             case (str(q), list(sm_seq)):
                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
-                return [r[0].message.content for r in res]
+                return [r[0].message.content for r in res]  # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
             case (str(q), str(sm)):
-                return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
+                return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content  # pyright: ignore [reportReturnType, reportAttributeAccessIssue]
             case _:
                 raise RuntimeError("Should not reach here.")
 
@@ -215,6 +223,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: T = ...,
         max_validations: PositiveInt = 2,
+        co_extractor: Optional[GenerateKwargs] = None,
         **kwargs: Unpack[GenerateKwargs],
     ) -> T: ...
     @overload
@@ -224,6 +233,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: T = ...,
         max_validations: PositiveInt = 2,
+        co_extractor: Optional[GenerateKwargs] = None,
         **kwargs: Unpack[GenerateKwargs],
     ) -> List[T]: ...
     @overload
@@ -233,6 +243,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: None = None,
         max_validations: PositiveInt = 2,
+        co_extractor: Optional[GenerateKwargs] = None,
         **kwargs: Unpack[GenerateKwargs],
     ) -> Optional[T]: ...
 
@@ -243,6 +254,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: None = None,
         max_validations: PositiveInt = 2,
+        co_extractor: Optional[GenerateKwargs] = None,
         **kwargs: Unpack[GenerateKwargs],
     ) -> List[Optional[T]]: ...
 
@@ -252,6 +264,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: Optional[T] = None,
         max_validations: PositiveInt = 2,
+        co_extractor: Optional[GenerateKwargs] = None,
         **kwargs: Unpack[GenerateKwargs],
     ) -> Optional[T] | List[Optional[T]] | List[T] | T:
         """Asynchronously asks a question and validates the response using a given validator.
@@ -261,6 +274,7 @@ class LLMUsage(ScopedConfig):
             validator (Callable[[str], T | None]): A function to validate the response.
             default (T | None): Default value to return if validation fails. Defaults to None.
             max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
+            co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
 
         Returns:
@@ -274,6 +288,23 @@ class LLMUsage(ScopedConfig):
                 if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
                     logger.debug(f"Successfully validated the response at {lap}th attempt.")
                     return validated
+
+                if co_extractor and (
+                    (
+                        co_response := await self.aask(
+                            question=(
+                                TEMPLATE_MANAGER.render_template(
+                                    configs.templates.co_validation_template,
+                                    {"original_q": q, "original_a": response},
+                                )
+                            ),
+                            **co_extractor,
+                        )
+                    )
+                    and (validated := validator(co_response))
+                ):
+                    logger.debug(f"Successfully validated the co-response at {lap}th attempt.")
+                    return validated
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error during validation: \n{e}")
                 break
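Reviewer note: a rough sketch of the new co_extractor path in aask_validate. If the primary answer fails the validator, the question and the failed answer are re-rendered through configs.templates.co_validation_template and sent again with the co_extractor kwargs (typically pointing at a second model), replacing the dedicated CoValidate capability removed below. The model name, validator, and question here are invented for illustration and the GenerateKwargs keys are assumed.

# Illustrative only: a validator plus a co_extractor that retries extraction on a
# hypothetical secondary model when the primary response fails validation.
from typing import Optional


def to_int(raw: str) -> Optional[int]:
    """Accept the response only if it is a bare integer."""
    text = raw.strip()
    return int(text) if text.isdigit() else None


async def leap_year_days(llm) -> Optional[int]:
    # `llm` is any LLMUsage-backed object, e.g. a Role instance.
    return await llm.aask_validate(
        "How many days are in a leap year? Reply with the number only.",
        validator=to_int,
        max_validations=2,
        co_extractor={"model": "openai/gpt-4o-mini", "temperature": 0.0},  # hypothetical model name
    )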
@@ -291,7 +322,7 @@
 
     async def aliststr(
         self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
-    ) -> List[str]:
+    ) -> Optional[List[str]]:
         """Asynchronously generates a list of strings based on a given requirement.
 
         Args:
@@ -311,7 +342,7 @@
             **kwargs,
         )
 
-    async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> List[str]:
+    async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> Optional[List[str]]:
         """Asynchronously generates a list of strings based on a given requirement.
 
         Args:
@@ -339,7 +370,7 @@
         Returns:
             str: The validated response as a single string.
         """
-        return (
+        return ok(
             await self.apathstr(
                 requirement,
                 k=1,
@@ -347,7 +378,7 @@
             )
         ).pop()
 
-    async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> str:
+    async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> Optional[str]:
         """Asynchronously generates a generic string based on a given requirement.
 
         Args:
@@ -357,7 +388,7 @@
         Returns:
             str: The generated string.
         """
-        return await self.aask_validate(
+        return await self.aask_validate(  # pyright: ignore [reportReturnType]
             TEMPLATE_MANAGER.render_template(
                 configs.templates.generic_string_template,
                 {"requirement": requirement, "language": GenericCapture.capture_type},
@@ -372,7 +403,7 @@
         choices: List[T],
         k: NonNegativeInt = 0,
         **kwargs: Unpack[ValidateKwargs[List[T]]],
-    ) -> List[T]:
+    ) -> Optional[List[T]]:
         """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
 
         Args:
@@ -437,13 +468,13 @@
         Raises:
             ValueError: If validation fails after maximum attempts or if no valid selection is made.
         """
-        return (
+        return ok(
             await self.achoose(
                 instruction=instruction,
                 choices=choices,
                 k=1,
                 **kwargs,
-            )
+            ),
         )[0]
 
     async def ajudge(
@@ -500,7 +531,7 @@ class EmbeddingUsage(LLMUsage):
         """
         # check seq length
         max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
-        if any(len(t) > max_len for t in input_text):
+        if max_len and any(len(t) > max_len for t in input_text):
             logger.error(err := f"Input text exceeds maximum sequence length 55,228.")
             raise ValueError(err)
 
@@ -514,10 +545,10 @@
             or configs.embedding.timeout
             or self.llm_timeout
             or configs.llm.timeout,
-            api_key=(
+            api_key=ok(
                 self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
             ).get_secret_value(),
-            api_base=(
+            api_base=ok(
                 self.embedding_api_endpoint
                 or configs.embedding.api_endpoint
                 or self.llm_api_endpoint
@@ -566,7 +597,7 @@ class ToolBoxUsage(LLMUsage):
         self,
         task: Task,
         **kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
-    ) -> List[ToolBox]:
+    ) -> Optional[List[ToolBox]]:
         """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
 
         Args:
@@ -591,7 +622,7 @@
         task: Task,
         toolbox: ToolBox,
         **kwargs: Unpack[ChooseKwargs[List[Tool]]],
-    ) -> List[Tool]:
+    ) -> Optional[List[Tool]]:
         """Asynchronously executes a multi-choice decision-making process to choose tools.
 
         Args:
@@ -631,11 +662,11 @@
         tool_choose_kwargs = tool_choose_kwargs or {}
 
         # Choose the toolboxes
-        chosen_toolboxes = await self.choose_toolboxes(task, **box_choose_kwargs)
+        chosen_toolboxes = ok(await self.choose_toolboxes(task, **box_choose_kwargs))
         # Choose the tools
        chosen_tools = []
         for toolbox in chosen_toolboxes:
-            chosen_tools.extend(await self.choose_tools(task, toolbox, **tool_choose_kwargs))
+            chosen_tools.extend(ok(await self.choose_tools(task, toolbox, **tool_choose_kwargs)))
         return chosen_tools
 
     async def gather_tools(self, task: Task, **kwargs: Unpack[ChooseKwargs]) -> List[Tool]:
fabricatio/models/utils.py CHANGED
@@ -165,3 +165,24 @@ async def ask_edit(
         if edited:
             res.append(edited)
     return res
+
+
+def override_kwargs[T](kwargs: Dict[str, T], **overrides) -> Dict[str, T]:
+    """Override the values in kwargs with the provided overrides."""
+    kwargs.update({k: v for k, v in overrides.items() if v is not None})
+    return kwargs
+
+
+def ok[T](val: Optional[T], msg:str="Value is None") -> T:
+    """Check if a value is None and raise a ValueError with the provided message if it is.
+
+    Args:
+        val: The value to check.
+        msg: The message to include in the ValueError if val is None.
+
+    Returns:
+        T: The value if it is not None.
+    """
+    if val is None:
+        raise ValueError(msg)
+    return val
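Reviewer note: ok() underpins most of the changes above, turning the newly Optional return values into hard failures with a readable message, and override_kwargs is added alongside it. A quick behavioural sketch follows; the values are arbitrary.

# Illustrative only: expected behaviour of the two new helpers.
from fabricatio.models.utils import ok, override_kwargs

# override_kwargs merges only the overrides that are not None, mutating and
# returning the original mapping.
base = {"temperature": 0.7, "top_p": 0.9}
assert override_kwargs(base, temperature=0.2, top_p=None) == {"temperature": 0.2, "top_p": 0.9}

# ok() narrows Optional[T] to T, raising ValueError with the given message on None.
assert ok("gpt-4o", "model name is not set at any place") == "gpt-4o"
try:
    ok(None, "llm api key is not set at any place")
except ValueError as e:
    print(e)  # llm api key is not set at any place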
fabricatio/parser.py CHANGED
@@ -52,6 +52,7 @@ class Capture(BaseModel):
             case _:
                 return text
 
+
     def capture(self, text: str) -> Tuple[str, ...] | str | None:
         """Capture the first occurrence of the pattern in the given text.
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fabricatio
-Version: 0.2.6.dev5
+Version: 0.2.6.dev7
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: 3.12
@@ -1,34 +1,33 @@
-fabricatio-0.2.6.dev5.dist-info/METADATA,sha256=W68vsoplegz0qZ6jTc47JXuhgTVJBXGOKJtxdMh7p_Y,14085
-fabricatio-0.2.6.dev5.dist-info/WHEEL,sha256=SmPT9fUKPAPiE6hwAZ9_NHUVRjWSQ_RENTrzrvPx4p0,94
-fabricatio-0.2.6.dev5.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
-fabricatio/actions/article.py,sha256=d3i4i7h88yG8gvEW2SLDKM-FSsUv7qqiLgWpe6AJfSk,4463
+fabricatio-0.2.6.dev7.dist-info/METADATA,sha256=KJKAdxeQyQdtZjg8fyqf5BCnAZA4EJkAAMaNsDgXYCQ,14085
+fabricatio-0.2.6.dev7.dist-info/WHEEL,sha256=mDFV3bKFgwlxLHvOsPqpR9up9dUKYzsUQNKBdkW5c08,94
+fabricatio-0.2.6.dev7.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
+fabricatio/actions/article.py,sha256=LfIWnbFYB9e3Bq2YDPk1geWDbJTq7zCitLtpFhAhYHM,4563
 fabricatio/actions/output.py,sha256=KSSLvEvXsA10ACN2mbqGo98QwKLVUAoMUJNKYk6HhGc,645
 fabricatio/actions/rag.py,sha256=GpT7YlqOYznZyaT-6Y84_33HtZGT-5s71ZK8iroQA9g,813
 fabricatio/capabilities/correct.py,sha256=0BYhjo9WrLwKsXQR8bTPvdQITbrMs7RX1xpzhuQt_yY,5222
-fabricatio/capabilities/covalidate.py,sha256=zl0b0Z8ZC3XkpzISIZJY4CZZAdVsx4qd1rdTLrFHFz8,6621
 fabricatio/capabilities/propose.py,sha256=y3kge5g6bb8HYuV8e9h4MdqOMTlsfAIZpqE_cagWPTY,1593
-fabricatio/capabilities/rag.py,sha256=OebdGps8LGniN_HkRAOuwZd1ZQsyQe3WrduNAmBSxLM,15773
+fabricatio/capabilities/rag.py,sha256=R1yUD675CDEmGakXb2nzEzZe0vjN7edMS7VHtPOAriU,15771
 fabricatio/capabilities/rating.py,sha256=R9otyZVE2E3kKxrOCTZMeesBCPbC-fSb7bXgZPMQzfU,14406
 fabricatio/capabilities/review.py,sha256=XYzpSnFCT9HS2XytQT8HDgV4SjXehexoJgucZFMx6P8,11102
 fabricatio/capabilities/task.py,sha256=MBiDyC3oHwTbTiLiGyqUEVfVGSN42lU03ndeapTpyjQ,4609
-fabricatio/config.py,sha256=orbMYzGZqdQy5cRlw0M6kOYmrRDq7uo8yEEZIjBfR18,16594
+fabricatio/config.py,sha256=f3B_Mwhc4mGEdECG4EqcxGww0Eu7KhCAwPXXJlHf1a8,16635
 fabricatio/core.py,sha256=VQ_JKgUGIy2gZ8xsTBZCdr_IP7wC5aPg0_bsOmjQ588,6458
 fabricatio/decorators.py,sha256=uzsP4tFKQNjDHBkofsjjoJA0IUAaYOtt6YVedoyOqlo,6551
 fabricatio/fs/curd.py,sha256=N6l2MncjrFfnXBRtteRouXp5Rjy8EAKC_i29_G-zz98,4618
 fabricatio/fs/readers.py,sha256=EZKN_AZdrp8DggJECP53QHw3uHeSDf-AwCAA_V7fNKU,1202
 fabricatio/fs/__init__.py,sha256=PCf0s_9KDjVfNw7AfPoJzGt3jMq4gJOfbcT4pb0D0ZY,588
 fabricatio/journal.py,sha256=stnEP88aUBA_GmU9gfTF2EZI8FS2OyMLGaMSTgK4QgA,476
-fabricatio/models/action.py,sha256=-WutQKRKBSLU6TNscT-puS6TnQdLxSmatof4JxltXWU,8281
+fabricatio/models/action.py,sha256=dSmwIrW68JhCrkhWENRgTLIQ-0grVA4408QAUy23HZo,8210
 fabricatio/models/events.py,sha256=QvlnS8FEELg6KNabcJMeh2GV_y0ZBzKOPphcteKYWYU,4183
-fabricatio/models/extra.py,sha256=s2y_zFH9kdCFHjrJadrHTODh_wFNcr6d_RDCLixgeyw,7390
+fabricatio/models/extra.py,sha256=oPCrh80u-O5XoFMVvZ6D6SVpSSW0zkxw4zfaTeK_wLU,26263
 fabricatio/models/generic.py,sha256=IdPJMf3qxZFq8yqd6OuAYKfCM0wBlJkozgxvxQZVEEc,14025
-fabricatio/models/kwargs_types.py,sha256=Dfmd18SABDeV9JsI1JfPNpoB8FtB6qVYgJshZBsN1P0,4593
-fabricatio/models/role.py,sha256=GVe8Rzjxzn9Fiava82XWJJQKjbxl6Fxi3U1pebiWeYU,2853
+fabricatio/models/kwargs_types.py,sha256=H6DI3Jdben-FER_kx7owiRzmbSFKuu0sFjCADA1LJB0,5008
+fabricatio/models/role.py,sha256=mmQbJ6GKr2Gx3wtjEz8d-vYoXs09ffcEkT_eCXaDd3E,2782
 fabricatio/models/task.py,sha256=8NaR7ojQWyM740EDTqt9stwHKdrD6axCRpLKo0QzS-I,10492
 fabricatio/models/tool.py,sha256=4b-v4WIC_LuLOKzzXL9bvKXr8vmGZ8O2uAFv5-1KRA0,7052
-fabricatio/models/usages.py,sha256=Rp4AGYFgpm5L2McNSaztWMZplmE_PnMmRc9NdcD5Jmw,28702
-fabricatio/models/utils.py,sha256=1bCqeB6za7ecCAM3cU1raNWuN56732m45rXtlIlc3I4,5017
-fabricatio/parser.py,sha256=b1Em7zoEepQIrxgM51Damnbsx_AhPab-BefIWHGo1Ss,6249
+fabricatio/models/usages.py,sha256=-689ssQ5F1SmxDToDHbv0EH8YaPTjhkn14l_M6Aer-M,30859
+fabricatio/models/utils.py,sha256=3HW0tM6WwOK8g14tnIzVWTXzIRLHjMKPjjSl9pMRWkw,5668
+fabricatio/parser.py,sha256=9Jzw-yV6uKbFvf6sPna-XHdziVGVBZWvPctgX_6ODL8,6251
 fabricatio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 fabricatio/toolboxes/arithmetic.py,sha256=WLqhY-Pikv11Y_0SGajwZx3WhsLNpHKf9drzAqOf_nY,1369
 fabricatio/toolboxes/fs.py,sha256=l4L1CVxJmjw9Ld2XUpIlWfV0_Fu_2Og6d3E13I-S4aE,736
@@ -38,6 +37,6 @@ fabricatio/workflows/rag.py,sha256=-YYp2tlE9Vtfgpg6ROpu6QVO8j8yVSPa6yDzlN3qVxs,5
 fabricatio/_rust.pyi,sha256=eawBfpyGrB-JtOh4I6RSbjFSq83SSl-0syBeZ-g8270,3491
 fabricatio/_rust_instances.py,sha256=2GwF8aVfYNemRI2feBzH1CZfBGno-XJJE5imJokGEYw,314
 fabricatio/__init__.py,sha256=SzBYsRhZeL77jLtfJEjmoHOSwHwUGyvMATX6xfndLDM,1135
-fabricatio/_rust.cp39-win_amd64.pyd,sha256=rDkMzkreCAsU9bmxconUNqFNOaSLgSr7m--MqnSbXUU,1826304
-fabricatio-0.2.6.dev5.data/scripts/tdown.exe,sha256=ZrN_zXI0oDCkHx_r7-QrdqncIH0mFqq19mN9sE81kxs,3397632
-fabricatio-0.2.6.dev5.dist-info/RECORD,,
+fabricatio/_rust.cp39-win_amd64.pyd,sha256=GvYOGn9Xya6YMX-nhmqv-w908ndgc2HSinAYMkhypKo,1826304
+fabricatio-0.2.6.dev7.data/scripts/tdown.exe,sha256=5mZx7mp19U-nnWHwoZTyRmJun2iR77nar1wab1j_Jj8,3397632
+fabricatio-0.2.6.dev7.dist-info/RECORD,,
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: maturin (1.8.2)
+Generator: maturin (1.8.3)
 Root-Is-Purelib: false
 Tag: cp39-cp39-win_amd64
fabricatio/capabilities/covalidate.py DELETED
@@ -1,160 +0,0 @@
-"""Co-validation capability for LLMs."""
-
-from asyncio import gather
-from typing import Callable, List, Optional, Union, Unpack, overload
-
-from fabricatio import TEMPLATE_MANAGER
-from fabricatio.config import configs
-from fabricatio.journal import logger
-from fabricatio.models.kwargs_types import GenerateKwargs
-from fabricatio.models.usages import LLMUsage
-
-
-class CoValidate(LLMUsage):
-    """Class that represents a co-validation capability using multiple LLMs.
-
-    This class provides methods to validate responses by attempting multiple approaches:
-    1. Using the primary LLM to generate a response
-    2. Using a secondary (co-) model to refine responses that fail validation
-    3. Trying multiple times if needed
-    """
-
-    @overload
-    async def aask_covalidate[T](
-        self,
-        question: str,
-        validator: Callable[[str], T | None],
-        co_model: Optional[str] = None,
-        co_temperature: Optional[float] = None,
-        co_top_p: Optional[float] = None,
-        co_max_tokens: Optional[int] = None,
-        max_validations: int = 2,
-        default: None = None,
-        **kwargs: Unpack[GenerateKwargs],
-    ) -> T | None: ...
-
-    @overload
-    async def aask_covalidate[T](
-        self,
-        question: str,
-        validator: Callable[[str], T | None],
-        co_model: Optional[str] = None,
-        co_temperature: Optional[float] = None,
-        co_top_p: Optional[float] = None,
-        co_max_tokens: Optional[int] = None,
-        max_validations: int = 2,
-        default: T = ...,
-        **kwargs: Unpack[GenerateKwargs],
-    ) -> T: ...
-
-    @overload
-    async def aask_covalidate[T](
-        self,
-        question: List[str],
-        validator: Callable[[str], T | None],
-        co_model: Optional[str] = None,
-        co_temperature: Optional[float] = None,
-        co_top_p: Optional[float] = None,
-        co_max_tokens: Optional[int] = None,
-        max_validations: int = 2,
-        default: None = None,
-        **kwargs: Unpack[GenerateKwargs],
-    ) -> List[T | None]: ...
-
-    @overload
-    async def aask_covalidate[T](
-        self,
-        question: List[str],
-        validator: Callable[[str], T | None],
-        co_model: Optional[str] = None,
-        co_temperature: Optional[float] = None,
-        co_top_p: Optional[float] = None,
-        co_max_tokens: Optional[int] = None,
-        max_validations: int = 2,
-        default: T = ...,
-        **kwargs: Unpack[GenerateKwargs],
-    ) -> List[T]: ...
-
-    async def aask_covalidate[T](
-        self,
-        question: Union[str, List[str]],
-        validator: Callable[[str], T | None],
-        co_model: Optional[str] = None,
-        co_temperature: Optional[float] = None,
-        co_top_p: Optional[float] = None,
-        co_max_tokens: Optional[int] = None,
-        max_validations: int = 2,
-        default: Optional[T] = None,
-        **kwargs: Unpack[GenerateKwargs],
-    ) -> Union[T | None, List[T | None]]:
-        """Ask the LLM with co-validation to obtain a validated response.
-
-        This method attempts to generate a response that passes validation using two approaches:
-        1. First, it asks the primary LLM using the original question
-        2. If validation fails, it uses a secondary (co-) model with a template to improve the response
-        3. The process repeats up to max_validations times
-
-        Args:
-            question: String question or list of questions to ask
-            validator: Function that validates responses, returns result or None if invalid
-            co_model: Optional model name for the co-validator
-            co_temperature: Optional temperature setting for the co-validator
-            co_top_p: Optional top_p setting for the co-validator
-            co_max_tokens: Optional maximum tokens for the co-validator response
-            max_validations: Maximum number of validation attempts
-            default: Default value to return if validation fails
-            **kwargs: Additional keyword arguments passed to aask method
-
-        Returns:
-            The validated result (T) or default if validation fails.
-            If input is a list of questions, returns a list of results.
-        """
-
-        async def validate_single_question(q: str) -> Optional[T]:
-            """Process a single question with validation attempts."""
-            validation_kwargs = kwargs.copy()
-
-            for lap in range(max_validations):
-                try:
-                    # First attempt: direct question to primary model
-                    response = await self.aask(question=q, **validation_kwargs)
-                    if response and (validated := validator(response)):
-                        logger.debug(f"Successfully validated the primary response at {lap}th attempt.")
-                        return validated
-
-                    # Second attempt: use co-model with validation template
-                    co_prompt = TEMPLATE_MANAGER.render_template(
-                        configs.templates.co_validation_template,
-                        {"original_q": q, "original_a": response},
-                    )
-                    co_response = await self.aask(
-                        question=co_prompt,
-                        model=co_model,
-                        temperature=co_temperature,
-                        top_p=co_top_p,
-                        max_tokens=co_max_tokens,
-                    )
-
-                    if co_response and (validated := validator(co_response)):
-                        logger.debug(f"Successfully validated the co-response at {lap}th attempt.")
-                        return validated
-
-                except Exception as e:  # noqa: BLE001
-                    logger.error(f"Error during validation: \n{e}")
-                    break
-
-                # Disable caching for subsequent attempts
-                if not validation_kwargs.get("no_cache"):
-                    validation_kwargs["no_cache"] = True
-                    logger.debug("Disabled cache for the next attempt")
-
-            if default is None:
-                logger.error(f"Failed to validate the response after {max_validations} attempts.")
-            return default
-
-        # Handle single question or list of questions
-        if isinstance(question, str):
-            return await validate_single_question(question)
-
-        # Process multiple questions in parallel
-        return await gather(*[validate_single_question(q) for q in question])