fabricatio 0.2.5.dev3__cp312-cp312-win_amd64.whl → 0.2.5.dev5__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,90 +1,90 @@
1
1
  """This module contains the types for the keyword arguments of the methods in the models module."""
2
2
 
3
- from typing import Any, List, NotRequired, Set, TypedDict
3
+ from typing import Any, TypedDict
4
4
 
5
5
  from litellm.caching.caching import CacheMode
6
6
  from litellm.types.caching import CachingSupportedCallTypes
7
- from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
8
7
 
9
8
 
10
- class CollectionSimpleConfigKwargs(TypedDict):
9
+ class CollectionSimpleConfigKwargs(TypedDict, total=False):
11
10
  """Configuration parameters for a vector collection.
12
11
 
13
12
  These arguments are typically used when configuring connections to vector databases.
14
13
  """
15
14
 
16
- dimension: NotRequired[int]
17
- timeout: NotRequired[float]
15
+ dimension: int
16
+ timeout: float
18
17
 
19
18
 
20
- class FetchKwargs(TypedDict):
19
+ class FetchKwargs(TypedDict, total=False):
21
20
  """Arguments for fetching data from vector collections.
22
21
 
23
22
  Controls how data is retrieved from vector databases, including filtering
24
23
  and result limiting parameters.
25
24
  """
26
25
 
27
- collection_name: NotRequired[str]
28
- similarity_threshold: NotRequired[float]
29
- result_per_query: NotRequired[int]
26
+ collection_name: str
27
+ similarity_threshold: float
28
+ result_per_query: int
30
29
 
31
30
 
32
- class EmbeddingKwargs(TypedDict):
31
+ class EmbeddingKwargs(TypedDict, total=False):
33
32
  """Configuration parameters for text embedding operations.
34
33
 
35
34
  These settings control the behavior of embedding models that convert text
36
35
  to vector representations.
37
36
  """
38
37
 
39
- model: NotRequired[str]
40
- dimensions: NotRequired[int]
41
- timeout: NotRequired[PositiveInt]
42
- caching: NotRequired[bool]
38
+ model: str
39
+ dimensions: int
40
+ timeout: int
41
+ caching: bool
43
42
 
44
43
 
45
- class LLMKwargs(TypedDict):
44
+ class LLMKwargs(TypedDict, total=False):
46
45
  """Configuration parameters for language model inference.
47
46
 
48
47
  These arguments control the behavior of large language model calls,
49
48
  including generation parameters and caching options.
50
49
  """
51
50
 
52
- model: NotRequired[str]
53
- temperature: NotRequired[NonNegativeFloat]
54
- stop: NotRequired[str | List[str]]
55
- top_p: NotRequired[NonNegativeFloat]
56
- max_tokens: NotRequired[PositiveInt]
57
- stream: NotRequired[bool]
58
- timeout: NotRequired[PositiveInt]
59
- max_retries: NotRequired[PositiveInt]
60
- no_cache: NotRequired[bool] # If use cache in this call
61
- no_store: NotRequired[bool] # If store the response of this call to cache
62
- cache_ttl: NotRequired[int] # how long the stored cache is alive, in seconds
63
- s_maxage: NotRequired[int] # max accepted age of cached response, in seconds
64
-
65
-
66
- class ValidateKwargs[T](LLMKwargs):
67
- """Arguments for content validation operations.
51
+ model: str
52
+ temperature: float
53
+ stop: str | list[str]
54
+ top_p: float
55
+ max_tokens: int
56
+ stream: bool
57
+ timeout: int
58
+ max_retries: int
59
+ no_cache: bool # if True, bypass the cache for this call
60
+ no_store: bool # if True, do not store the response of this call in the cache
61
+ cache_ttl: int # how long the stored cache is alive, in seconds
62
+ s_maxage: int # max accepted age of cached response, in seconds
63
+
64
+
65
+ class GenerateKwargs(LLMKwargs, total=False):
66
+ """Arguments for content generation operations.
68
67
 
69
- Extends LLMKwargs with additional parameters specific to validation tasks,
70
- such as limiting the number of validation attempts.
68
+ Extends LLMKwargs with additional parameters specific to generation tasks,
69
+ such as the system message.
71
70
  """
72
71
 
73
- default: NotRequired[T]
74
- max_validations: NotRequired[PositiveInt]
72
+ system_message: str
75
73
 
76
74
 
77
- class GenerateKwargs(ValidateKwargs):
78
- """Arguments for content generation operations.
75
+ class ValidateKwargs[T](GenerateKwargs, total=False):
76
+ """Arguments for content validation operations.
79
77
 
80
- Extends ValidateKwargs with parameters specific to text generation,
81
- including system prompt configuration.
78
+ Extends GenerateKwargs with additional parameters specific to validation tasks,
79
+ such as limiting the number of validation attempts.
82
80
  """
83
81
 
84
- system_message: NotRequired[str]
82
+ default: T
83
+ max_validations: int
85
84
 
86
85
 
87
- class ReviewKwargs(GenerateKwargs):
86
+ # noinspection PyTypedDict
87
+ class ReviewKwargs[T](ValidateKwargs[T], total=False):
88
88
  """Arguments for content review operations.
89
89
 
90
90
  Extends ValidateKwargs with parameters for evaluating content against
@@ -92,17 +92,18 @@ class ReviewKwargs(GenerateKwargs):
92
92
  """
93
93
 
94
94
  topic: str
95
- criteria: NotRequired[Set[str]]
95
+ criteria: set[str]
96
96
 
97
97
 
98
- class ChooseKwargs(GenerateKwargs):
98
+ # noinspection PyTypedDict
99
+ class ChooseKwargs[T](ValidateKwargs[T], total=False):
99
100
  """Arguments for selection operations.
100
101
 
101
102
  Extends ValidateKwargs with parameters for selecting among options,
102
103
  such as the number of items to choose.
103
104
  """
104
105
 
105
- k: NotRequired[NonNegativeInt]
106
+ k: int
106
107
 
107
108
 
108
109
  class CacheKwargs(TypedDict, total=False):
@@ -112,35 +113,35 @@ class CacheKwargs(TypedDict, total=False):
112
113
  including in-memory, Redis, S3, and vector database caching options.
113
114
  """
114
115
 
115
- mode: NotRequired[CacheMode] # when default_on cache is always on, when default_off cache is opt in
116
- host: NotRequired[str]
117
- port: NotRequired[str]
118
- password: NotRequired[str]
119
- namespace: NotRequired[str]
120
- ttl: NotRequired[float]
121
- default_in_memory_ttl: NotRequired[float]
122
- default_in_redis_ttl: NotRequired[float]
123
- similarity_threshold: NotRequired[float]
124
- supported_call_types: NotRequired[List[CachingSupportedCallTypes]]
116
+ mode: CacheMode # when default_on cache is always on, when default_off cache is opt in
117
+ host: str
118
+ port: str
119
+ password: str
120
+ namespace: str
121
+ ttl: float
122
+ default_in_memory_ttl: float
123
+ default_in_redis_ttl: float
124
+ similarity_threshold: float
125
+ supported_call_types: list[CachingSupportedCallTypes]
125
126
  # s3 Bucket, boto3 configuration
126
- s3_bucket_name: NotRequired[str]
127
- s3_region_name: NotRequired[str]
128
- s3_api_version: NotRequired[str]
129
- s3_use_ssl: NotRequired[bool]
130
- s3_verify: NotRequired[bool | str]
131
- s3_endpoint_url: NotRequired[str]
132
- s3_aws_access_key_id: NotRequired[str]
133
- s3_aws_secret_access_key: NotRequired[str]
134
- s3_aws_session_token: NotRequired[str]
135
- s3_config: NotRequired[Any]
136
- s3_path: NotRequired[str]
127
+ s3_bucket_name: str
128
+ s3_region_name: str
129
+ s3_api_version: str
130
+ s3_use_ssl: bool
131
+ s3_verify: bool | str
132
+ s3_endpoint_url: str
133
+ s3_aws_access_key_id: str
134
+ s3_aws_secret_access_key: str
135
+ s3_aws_session_token: str
136
+ s3_config: Any
137
+ s3_path: str
137
138
  redis_semantic_cache_use_async: bool
138
139
  redis_semantic_cache_embedding_model: str
139
- redis_flush_size: NotRequired[int]
140
- redis_startup_nodes: NotRequired[List]
140
+ redis_flush_size: int
141
+ redis_startup_nodes: list
141
142
  disk_cache_dir: Any
142
- qdrant_api_base: NotRequired[str]
143
- qdrant_api_key: NotRequired[str]
144
- qdrant_collection_name: NotRequired[str]
145
- qdrant_quantization_config: NotRequired[str]
143
+ qdrant_api_base: str
144
+ qdrant_api_key: str
145
+ qdrant_collection_name: str
146
+ qdrant_quantization_config: str
146
147
  qdrant_semantic_cache_embedding_model: str
fabricatio/models/tool.py CHANGED
@@ -4,7 +4,7 @@ from importlib.machinery import ModuleSpec
4
4
  from importlib.util import module_from_spec
5
5
  from inspect import iscoroutinefunction, signature
6
6
  from types import CodeType, ModuleType
7
- from typing import Any, Callable, Dict, List, Optional, Self, overload
7
+ from typing import Any, Callable, Dict, List, Optional, Self, cast, overload
8
8
 
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.decorators import logging_execution_info, use_temp_module
@@ -136,7 +136,7 @@ class ToolExecutor(BaseModel):
136
136
 
137
137
  def inject_tools[M: ModuleType](self, module: Optional[M] = None) -> M:
138
138
  """Inject the tools into the provided module or default."""
139
- module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
139
+ module = module or cast(M, module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None)))
140
140
  for tool in self.candidates:
141
141
  logger.debug(f"Injecting tool: {tool.name}")
142
142
  setattr(module, tool.name, tool.invoke)
@@ -144,7 +144,7 @@ class ToolExecutor(BaseModel):
144
144
 
145
145
  def inject_data[M: ModuleType](self, module: Optional[M] = None) -> M:
146
146
  """Inject the data into the provided module or default."""
147
- module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None))
147
+ module = module or cast(M,module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None)))
148
148
  for key, value in self.data.items():
149
149
  logger.debug(f"Injecting data: {key}")
150
150
  setattr(module, key, value)
@@ -184,6 +184,6 @@ class ToolExecutor(BaseModel):
184
184
  tools = []
185
185
  while tool_name := recipe.pop(0):
186
186
  for toolbox in toolboxes:
187
- tools.append(toolbox[tool_name])
187
+ tools.append(toolbox.get(tool_name))
188
188
 
189
189
  return cls(candidates=tools)
@@ -1,7 +1,7 @@
1
1
  """This module contains classes that manage the usage of language models and tools in tasks."""
2
2
 
3
3
  from asyncio import gather
4
- from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
4
+ from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Type, Union, Unpack, overload
5
5
 
6
6
  import asyncstdlib
7
7
  import litellm
@@ -9,7 +9,7 @@ from fabricatio._rust_instances import template_manager
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.journal import logger
11
11
  from fabricatio.models.generic import ScopedConfig, WithBriefing
12
- from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
12
+ from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
13
13
  from fabricatio.models.task import Task
14
14
  from fabricatio.models.tool import Tool, ToolBox
15
15
  from fabricatio.models.utils import Messages
@@ -20,12 +20,13 @@ from litellm.types.utils import (
20
20
  EmbeddingResponse,
21
21
  ModelResponse,
22
22
  StreamingChoices,
23
+ TextChoices,
23
24
  )
24
25
  from litellm.utils import CustomStreamWrapper
25
26
  from more_itertools import duplicates_everseen
26
27
  from pydantic import Field, NonNegativeInt, PositiveInt
27
28
 
28
- if configs.cache.enabled:
29
+ if configs.cache.enabled and configs.cache.type:
29
30
  litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
30
31
  logger.success(f"{configs.cache.type.name} Cache enabled")
31
32
 
@@ -42,7 +43,7 @@ class LLMUsage(ScopedConfig):
42
43
  messages: List[Dict[str, str]],
43
44
  n: PositiveInt | None = None,
44
45
  **kwargs: Unpack[LLMKwargs],
45
- ) -> ModelResponse | CustomStreamWrapper:
46
+ ) -> ModelResponse:
46
47
  """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
47
48
 
48
49
  Args:
@@ -81,7 +82,7 @@ class LLMUsage(ScopedConfig):
81
82
  system_message: str = "",
82
83
  n: PositiveInt | None = None,
83
84
  **kwargs: Unpack[LLMKwargs],
84
- ) -> List[Choices | StreamingChoices]:
85
+ ) -> Sequence[TextChoices | Choices | StreamingChoices]:
85
86
  """Asynchronously invokes the language model with a question and optional system message.
86
87
 
87
88
  Args:
@@ -101,13 +102,14 @@ class LLMUsage(ScopedConfig):
101
102
  if isinstance(resp, ModelResponse):
102
103
  return resp.choices
103
104
  if isinstance(resp, CustomStreamWrapper):
104
- if not configs.debug.streaming_visible:
105
- return stream_chunk_builder(await asyncstdlib.list()).choices
105
+ if not configs.debug.streaming_visible and (pack := stream_chunk_builder(await asyncstdlib.list())):
106
+ return pack.choices
106
107
  chunks = []
107
108
  async for chunk in resp:
108
109
  chunks.append(chunk)
109
110
  print(chunk.choices[0].delta.content or "", end="") # noqa: T201
110
- return stream_chunk_builder(chunks).choices
111
+ if pack := stream_chunk_builder(chunks):
112
+ return pack.choices
111
113
  logger.critical(err := f"Unexpected response type: {type(resp)}")
112
114
  raise ValueError(err)
113
115
 
@@ -166,15 +168,15 @@ class LLMUsage(ScopedConfig):
166
168
  for q, sm in zip(q_seq, sm_seq, strict=True)
167
169
  ]
168
170
  )
169
- return [r.pop().message.content for r in res]
171
+ return [r[0].message.content for r in res]
170
172
  case (list(q_seq), str(sm)):
171
173
  res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
172
- return [r.pop().message.content for r in res]
174
+ return [r[0].message.content for r in res]
173
175
  case (str(q), list(sm_seq)):
174
176
  res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
175
- return [r.pop().message.content for r in res]
177
+ return [r[0].message.content for r in res]
176
178
  case (str(q), str(sm)):
177
- return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs)).pop()).message.content
179
+ return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
178
180
  case _:
179
181
  raise RuntimeError("Should not reach here.")
180
182
 
@@ -185,19 +187,45 @@ class LLMUsage(ScopedConfig):
185
187
  validator: Callable[[str], T | None],
186
188
  default: T,
187
189
  max_validations: PositiveInt = 2,
188
- system_message: str = "",
189
- **kwargs: Unpack[LLMKwargs],
190
+ **kwargs: Unpack[GenerateKwargs],
190
191
  ) -> T: ...
191
-
192
+ @overload
192
193
  async def aask_validate[T](
193
194
  self,
194
195
  question: str,
195
196
  validator: Callable[[str], T | None],
197
+ default: None = None,
198
+ max_validations: PositiveInt = 2,
199
+ **kwargs: Unpack[GenerateKwargs],
200
+ ) -> Optional[T]: ...
201
+
202
+ @overload
203
+ async def aask_validate[T](
204
+ self,
205
+ question: List[str],
206
+ validator: Callable[[str], T | None],
207
+ default: None = None,
208
+ max_validations: PositiveInt = 2,
209
+ **kwargs: Unpack[GenerateKwargs],
210
+ ) -> List[Optional[T]]: ...
211
+ @overload
212
+ async def aask_validate[T](
213
+ self,
214
+ question: List[str],
215
+ validator: Callable[[str], T | None],
216
+ default: T,
217
+ max_validations: PositiveInt = 2,
218
+ **kwargs: Unpack[GenerateKwargs],
219
+ ) -> List[T]: ...
220
+
221
+ async def aask_validate[T](
222
+ self,
223
+ question: str | List[str],
224
+ validator: Callable[[str], T | None],
196
225
  default: Optional[T] = None,
197
226
  max_validations: PositiveInt = 2,
198
- system_message: str = "",
199
- **kwargs: Unpack[LLMKwargs],
200
- ) -> Optional[T]:
227
+ **kwargs: Unpack[GenerateKwargs],
228
+ ) -> Optional[T] | List[Optional[T]] | List[T] | T:
201
229
  """Asynchronously asks a question and validates the response using a given validator.
202
230
 
203
231
  Args:
@@ -205,59 +233,42 @@ class LLMUsage(ScopedConfig):
205
233
  validator (Callable[[str], T | None]): A function to validate the response.
206
234
  default (T | None): Default value to return if validation fails. Defaults to None.
207
235
  max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
208
- system_message (str): System message to include in the request. Defaults to an empty string.
209
236
  **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
210
237
 
211
238
  Returns:
212
239
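  T | List[T]: The validated response (one per question when a list of questions is given); `default` is returned wherever validation fails.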
213
240
 
214
241
  """
215
- for i in range(max_validations):
216
- if (
217
- response := await self.aask(
218
- question=question,
219
- system_message=system_message,
220
- **kwargs,
221
- )
222
- ) and (validated := validator(response)):
223
- logger.debug(f"Successfully validated the response at {i}th attempt.")
224
- return validated
225
- kwargs["no_cache"] = True
226
- logger.debug("Closed the cache for the next attempt")
227
- if default is None:
228
- logger.error(f"Failed to validate the response after {max_validations} attempts.")
229
- return default
230
-
231
- async def aask_validate_batch[T](
232
- self,
233
- questions: List[str],
234
- validator: Callable[[str], T | None],
235
- **kwargs: Unpack[GenerateKwargs[T]],
236
- ) -> List[T]:
237
- """Asynchronously asks a batch of questions and validates the responses using a given validator.
238
242
 
239
- Args:
240
- questions (List[str]): The list of questions to ask.
241
- validator (Callable[[str], T | None]): A function to validate the response.
242
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
243
-
244
- Returns:
245
- T: The validated response.
246
-
247
- Raises:
248
- ValueError: If the response fails to validate after the maximum number of attempts.
249
- """
250
- return await gather(*[self.aask_validate(question, validator, **kwargs) for question in questions])
243
+ async def _inner(q: str) -> Optional[T]:
244
+ for lap in range(max_validations):
245
+ try:
246
+ if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
247
+ logger.debug(f"Successfully validated the response at {lap}th attempt.")
248
+ return validated
249
+ except Exception as e: # noqa: BLE001
250
+ logger.error(f"Error during validation: \n{e}")
251
+ break
252
+ kwargs["no_cache"] = True
253
+ logger.debug("Closed the cache for the next attempt")
254
+ if default is None:
255
+ logger.error(f"Failed to validate the response after {max_validations} attempts.")
256
+ return default
257
+
258
+ if isinstance(question, str):
259
+ return await _inner(question)
260
+
261
+ return await gather(*[_inner(q) for q in question])
251
262
 
252
263
  async def aliststr(
253
- self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[GenerateKwargs[List[str]]]
264
+ self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
254
265
  ) -> List[str]:
255
266
  """Asynchronously generates a list of strings based on a given requirement.
256
267
 
257
268
  Args:
258
269
  requirement (str): The requirement for the list of strings.
259
270
  k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
260
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
271
+ **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
261
272
 
262
273
  Returns:
263
274
  List[str]: The validated response as a list of strings.
@@ -289,12 +300,12 @@ class LLMUsage(ScopedConfig):
289
300
  **kwargs,
290
301
  )
291
302
 
292
- async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[GenerateKwargs[List[str]]]) -> str:
303
+ async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[ValidateKwargs[List[str]]]) -> str:
293
304
  """Asynchronously generates a single path string based on a given requirement.
294
305
 
295
306
  Args:
296
307
  requirement (str): The requirement for the list of strings.
297
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
308
+ **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
298
309
 
299
310
  Returns:
300
311
  str: The validated response as a single string.
@@ -312,7 +323,7 @@ class LLMUsage(ScopedConfig):
312
323
  instruction: str,
313
324
  choices: List[T],
314
325
  k: NonNegativeInt = 0,
315
- **kwargs: Unpack[GenerateKwargs[List[T]]],
326
+ **kwargs: Unpack[ValidateKwargs[List[T]]],
316
327
  ) -> List[T]:
317
328
  """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
318
329
 
@@ -320,7 +331,7 @@ class LLMUsage(ScopedConfig):
320
331
  instruction (str): The user-provided instruction/question description.
321
332
  choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
322
333
  k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
323
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
334
+ **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
324
335
 
325
336
  Returns:
326
337
  List[T]: The final validated selection result list, with element types matching the input `choices`.
@@ -363,14 +374,14 @@ class LLMUsage(ScopedConfig):
363
374
  self,
364
375
  instruction: str,
365
376
  choices: List[T],
366
- **kwargs: Unpack[GenerateKwargs[List[T]]],
377
+ **kwargs: Unpack[ValidateKwargs[List[T]]],
367
378
  ) -> T:
368
379
  """Asynchronously picks a single choice from a list of options using AI validation.
369
380
 
370
381
  Args:
371
382
  instruction (str): The user-provided instruction/question description.
372
383
  choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
373
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
384
+ **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
374
385
 
375
386
  Returns:
376
387
  T: The single selected item from the choices list.
@@ -392,7 +403,7 @@ class LLMUsage(ScopedConfig):
392
403
  prompt: str,
393
404
  affirm_case: str = "",
394
405
  deny_case: str = "",
395
- **kwargs: Unpack[GenerateKwargs[bool]],
406
+ **kwargs: Unpack[ValidateKwargs[bool]],
396
407
  ) -> bool:
397
408
  """Asynchronously judges a prompt using AI validation.
398
409
 
@@ -400,7 +411,7 @@ class LLMUsage(ScopedConfig):
400
411
  prompt (str): The input prompt to be judged.
401
412
  affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
402
413
  deny_case (str): The negative case for the AI model. Defaults to an empty string.
403
- **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
414
+ **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
404
415
 
405
416
  Returns:
406
417
  bool: The judgment result (True or False) based on the AI's response.
@@ -506,7 +517,6 @@ class ToolBoxUsage(LLMUsage):
506
517
  async def choose_toolboxes(
507
518
  self,
508
519
  task: Task,
509
- system_message: str = "",
510
520
  **kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
511
521
  ) -> List[ToolBox]:
512
522
  """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
@@ -525,7 +535,6 @@ class ToolBoxUsage(LLMUsage):
525
535
  return await self.achoose(
526
536
  instruction=task.briefing,
527
537
  choices=list(self.toolboxes),
528
- system_message=system_message,
529
538
  **kwargs,
530
539
  )
531
540
 
fabricatio/parser.py CHANGED
@@ -1,12 +1,14 @@
1
1
  """A module to parse text using regular expressions."""
2
2
 
3
- from typing import Any, Callable, Optional, Self, Tuple, Type
3
+ from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
4
4
 
5
5
  import orjson
6
6
  import regex
7
+ from json_repair import repair_json
7
8
  from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
8
9
  from regex import Pattern, compile
9
10
 
11
+ from fabricatio.config import configs
10
12
  from fabricatio.journal import logger
11
13
 
12
14
 
@@ -25,12 +27,31 @@ class Capture(BaseModel):
25
27
  """The regular expression pattern to search for."""
26
28
  flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
27
29
  """The flags to use when compiling the regular expression pattern."""
30
+ capture_type: Optional[str] = None
31
+ """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
28
32
  _compiled: Pattern = PrivateAttr()
29
33
 
30
34
  def model_post_init(self, __context: Any) -> None:
31
35
  """Initialize the compiled pattern."""
32
36
  self._compiled = compile(self.pattern, self.flags)
33
37
 
38
+ def fix[T](self, text: str | Iterable[str]|T) -> str | List[str]|T:
39
+ Fix the text according to the capture type (e.g. repair malformed JSON when the type is 'json').
40
+
41
+ Args:
42
+ text (str | List[str]): The text to fix.
43
+
44
+ Returns:
45
+ str | List[str]: The fixed text: a string for string input, a list of fixed strings for any other iterable, or the input unchanged when no fixer applies.
46
+ """
47
+ match self.capture_type:
48
+ case "json":
49
+ if isinstance(text, str):
50
+ return repair_json(text,ensure_ascii=False)
51
+ return [repair_json(item) for item in text]
52
+ case _:
53
+ return text
54
+
34
55
  def capture(self, text: str) -> Tuple[str, ...] | str | None:
35
56
  """Capture the first occurrence of the pattern in the given text.
36
57
 
@@ -44,12 +65,12 @@ class Capture(BaseModel):
44
65
  match = self._compiled.search(text)
45
66
  if match is None:
46
67
  return None
47
-
68
+ groups = self.fix(match.groups()) if configs.general.use_json_repair else match.groups()
48
69
  if self.target_groups:
49
- cap = tuple(match.group(g) for g in self.target_groups)
70
+ cap = tuple(groups[g - 1] for g in self.target_groups)
50
71
  logger.debug(f"Captured text: {'\n\n'.join(cap)}")
51
72
  return cap
52
- cap = match.group(1)
73
+ cap = groups[0]
53
74
  logger.debug(f"Captured text: \n{cap}")
54
75
  return cap
55
76
 
@@ -111,7 +132,7 @@ class Capture(BaseModel):
111
132
  Returns:
112
133
  Self: The instance of the class with the captured code block.
113
134
  """
114
- return cls(pattern=f"```{language}\n(.*?)\n```")
135
+ return cls(pattern=f"```{language}\n(.*?)\n```", capture_type=language)
115
136
 
116
137
 
117
138
  JsonCapture = Capture.capture_code_block("json")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fabricatio
3
- Version: 0.2.5.dev3
3
+ Version: 0.2.5.dev5
4
4
  Classifier: License :: OSI Approved :: MIT License
5
5
  Classifier: Programming Language :: Rust
6
6
  Classifier: Programming Language :: Python :: 3.12
@@ -11,6 +11,7 @@ Classifier: Typing :: Typed
11
11
  Requires-Dist: appdirs>=1.4.4
12
12
  Requires-Dist: asyncio>=3.4.3
13
13
  Requires-Dist: asyncstdlib>=3.13.0
14
+ Requires-Dist: json-repair>=0.39.1
14
15
  Requires-Dist: litellm>=1.60.0
15
16
  Requires-Dist: loguru>=0.7.3
16
17
  Requires-Dist: magika>=0.5.1