fabricatio-0.2.5.dev4-cp312-cp312-win_amd64.whl → fabricatio-0.2.5.dev5-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/models/kwargs_types.py CHANGED
@@ -1,92 +1,90 @@
  """This module contains the types for the keyword arguments of the methods in the models module."""

- from typing import Any, List, NotRequired, Set, TypedDict
+ from typing import Any, TypedDict

  from litellm.caching.caching import CacheMode
  from litellm.types.caching import CachingSupportedCallTypes
- from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt


- class CollectionSimpleConfigKwargs(TypedDict):
+ class CollectionSimpleConfigKwargs(TypedDict, total=False):
      """Configuration parameters for a vector collection.

      These arguments are typically used when configuring connections to vector databases.
      """

-     dimension: NotRequired[int]
-     timeout: NotRequired[float]
+     dimension: int
+     timeout: float


- class FetchKwargs(TypedDict):
+ class FetchKwargs(TypedDict, total=False):
      """Arguments for fetching data from vector collections.

      Controls how data is retrieved from vector databases, including filtering
      and result limiting parameters.
      """

-     collection_name: NotRequired[str]
-     similarity_threshold: NotRequired[float]
-     result_per_query: NotRequired[int]
+     collection_name: str
+     similarity_threshold: float
+     result_per_query: int


- class EmbeddingKwargs(TypedDict):
+ class EmbeddingKwargs(TypedDict, total=False):
      """Configuration parameters for text embedding operations.

      These settings control the behavior of embedding models that convert text
      to vector representations.
      """

-     model: NotRequired[str]
-     dimensions: NotRequired[int]
-     timeout: NotRequired[PositiveInt]
-     caching: NotRequired[bool]
+     model: str
+     dimensions: int
+     timeout: int
+     caching: bool


- class LLMKwargs(TypedDict):
+ class LLMKwargs(TypedDict, total=False):
      """Configuration parameters for language model inference.

      These arguments control the behavior of large language model calls,
      including generation parameters and caching options.
      """

-     model: NotRequired[str]
-     temperature: NotRequired[NonNegativeFloat]
-     stop: NotRequired[str | List[str]]
-     top_p: NotRequired[NonNegativeFloat]
-     max_tokens: NotRequired[PositiveInt]
-     stream: NotRequired[bool]
-     timeout: NotRequired[PositiveInt]
-     max_retries: NotRequired[PositiveInt]
-     no_cache: NotRequired[bool] # If use cache in this call
-     no_store: NotRequired[bool] # If store the response of this call to cache
-     cache_ttl: NotRequired[int] # how long the stored cache is alive, in seconds
-     s_maxage: NotRequired[int] # max accepted age of cached response, in seconds
-
-
- class ValidateKwargs[T](LLMKwargs):
-     """Arguments for content validation operations.
+     model: str
+     temperature: float
+     stop: str | list[str]
+     top_p: float
+     max_tokens: int
+     stream: bool
+     timeout: int
+     max_retries: int
+     no_cache: bool # if the req uses cache in this call
+     no_store: bool # If store the response of this call to cache
+     cache_ttl: int # how long the stored cache is alive, in seconds
+     s_maxage: int # max accepted age of cached response, in seconds
+
+
+ class GenerateKwargs(LLMKwargs, total=False):
+     """Arguments for content generation operations.

-     Extends LLMKwargs with additional parameters specific to validation tasks,
-     such as limiting the number of validation attempts.
+     Extends LLMKwargs with additional parameters specific to generation tasks,
+     such as the number of generated items and the system message.
      """

-     default: NotRequired[T]
-     max_validations: NotRequired[PositiveInt]
+     system_message: str


- # noinspection PyTypedDict
- class GenerateKwargs[T](ValidateKwargs[T]):
-     """Arguments for content generation operations.
+ class ValidateKwargs[T](GenerateKwargs, total=False):
+     """Arguments for content validation operations.

-     Extends ValidateKwargs with parameters specific to text generation,
-     including system prompt configuration.
+     Extends LLMKwargs with additional parameters specific to validation tasks,
+     such as limiting the number of validation attempts.
      """

-     system_message: NotRequired[str]
+     default: T
+     max_validations: int


  # noinspection PyTypedDict
- class ReviewKwargs[T](GenerateKwargs[T]):
+ class ReviewKwargs[T](ValidateKwargs[T], total=False):
      """Arguments for content review operations.

      Extends GenerateKwargs with parameters for evaluating content against
@@ -94,18 +92,18 @@ class ReviewKwargs[T](GenerateKwargs[T]):
      """

      topic: str
-     criteria: NotRequired[Set[str]]
+     criteria: set[str]


  # noinspection PyTypedDict
- class ChooseKwargs[T](GenerateKwargs[T]):
+ class ChooseKwargs[T](ValidateKwargs[T], total=False):
      """Arguments for selection operations.

      Extends GenerateKwargs with parameters for selecting among options,
      such as the number of items to choose.
      """

-     k: NotRequired[NonNegativeInt]
+     k: int


  class CacheKwargs(TypedDict, total=False):
@@ -115,35 +113,35 @@ class CacheKwargs(TypedDict, total=False):
      including in-memory, Redis, S3, and vector database caching options.
      """

-     mode: NotRequired[CacheMode] # when default_on cache is always on, when default_off cache is opt in
-     host: NotRequired[str]
-     port: NotRequired[str]
-     password: NotRequired[str]
-     namespace: NotRequired[str]
-     ttl: NotRequired[float]
-     default_in_memory_ttl: NotRequired[float]
-     default_in_redis_ttl: NotRequired[float]
-     similarity_threshold: NotRequired[float]
-     supported_call_types: NotRequired[List[CachingSupportedCallTypes]]
+     mode: CacheMode # when default_on cache is always on, when default_off cache is opt in
+     host: str
+     port: str
+     password: str
+     namespace: str
+     ttl: float
+     default_in_memory_ttl: float
+     default_in_redis_ttl: float
+     similarity_threshold: float
+     supported_call_types: list[CachingSupportedCallTypes]
      # s3 Bucket, boto3 configuration
-     s3_bucket_name: NotRequired[str]
-     s3_region_name: NotRequired[str]
-     s3_api_version: NotRequired[str]
-     s3_use_ssl: NotRequired[bool]
-     s3_verify: NotRequired[bool | str]
-     s3_endpoint_url: NotRequired[str]
-     s3_aws_access_key_id: NotRequired[str]
-     s3_aws_secret_access_key: NotRequired[str]
-     s3_aws_session_token: NotRequired[str]
-     s3_config: NotRequired[Any]
-     s3_path: NotRequired[str]
+     s3_bucket_name: str
+     s3_region_name: str
+     s3_api_version: str
+     s3_use_ssl: bool
+     s3_verify: bool | str
+     s3_endpoint_url: str
+     s3_aws_access_key_id: str
+     s3_aws_secret_access_key: str
+     s3_aws_session_token: str
+     s3_config: Any
+     s3_path: str
      redis_semantic_cache_use_async: bool
      redis_semantic_cache_embedding_model: str
-     redis_flush_size: NotRequired[int]
-     redis_startup_nodes: NotRequired[List]
+     redis_flush_size: int
+     redis_startup_nodes: list
      disk_cache_dir: Any
-     qdrant_api_base: NotRequired[str]
-     qdrant_api_key: NotRequired[str]
-     qdrant_collection_name: NotRequired[str]
-     qdrant_quantization_config: NotRequired[str]
+     qdrant_api_base: str
+     qdrant_api_key: str
+     qdrant_collection_name: str
+     qdrant_quantization_config: str
      qdrant_semantic_cache_embedding_model: str
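In plain terms, every kwargs TypedDict above now declares `total=False` instead of wrapping each field in `NotRequired[...]`, and the pydantic constrained aliases (`PositiveInt`, `NonNegativeFloat`, ...) give way to builtin `int`/`float`. A minimal, self-contained sketch of why the two spellings are interchangeable for callers (the class and function names here are illustrative, not part of fabricatio):

```python
from typing import NotRequired, TypedDict, Unpack


class LLMKwargsOld(TypedDict):
    """Old spelling: each optional key wrapped individually."""

    model: NotRequired[str]
    temperature: NotRequired[float]


class LLMKwargsNew(TypedDict, total=False):
    """New spelling: total=False makes every key optional at once."""

    model: str
    temperature: float


def call_llm(**kwargs: Unpack[LLMKwargsNew]) -> str:
    # Omitted keys are simply absent, so defaults live at the call site.
    return f"{kwargs.get('model', 'default-model')} @ temperature {kwargs.get('temperature', 0.0)}"


print(call_llm(model="gpt-4o"))  # temperature omitted, still type-checks
print(call_llm())                # an empty call is fine as well
```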
fabricatio/models/tool.py CHANGED
@@ -4,7 +4,7 @@ from importlib.machinery import ModuleSpec
  from importlib.util import module_from_spec
  from inspect import iscoroutinefunction, signature
  from types import CodeType, ModuleType
- from typing import Any, Callable, Dict, List, Optional, Self, overload
+ from typing import Any, Callable, Dict, List, Optional, Self, cast, overload

  from fabricatio.config import configs
  from fabricatio.decorators import logging_execution_info, use_temp_module
@@ -136,7 +136,7 @@ class ToolExecutor(BaseModel):

      def inject_tools[M: ModuleType](self, module: Optional[M] = None) -> M:
          """Inject the tools into the provided module or default."""
-         module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
+         module = module or cast(M, module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None)))
          for tool in self.candidates:
              logger.debug(f"Injecting tool: {tool.name}")
              setattr(module, tool.name, tool.invoke)
@@ -144,7 +144,7 @@ class ToolExecutor(BaseModel):

      def inject_data[M: ModuleType](self, module: Optional[M] = None) -> M:
          """Inject the data into the provided module or default."""
-         module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None))
+         module = module or cast(M,module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None)))
          for key, value in self.data.items():
              logger.debug(f"Injecting data: {key}")
              setattr(module, key, value)
@@ -184,6 +184,6 @@ class ToolExecutor(BaseModel):
          tools = []
          while tool_name := recipe.pop(0):
              for toolbox in toolboxes:
-                 tools.append(toolbox[tool_name])
+                 tools.append(toolbox.get(tool_name))

          return cls(candidates=tools)
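The `cast(M, ...)` added in `inject_tools`/`inject_data` is purely for the type checker: `module_from_spec()` is typed as returning a plain `ModuleType`, which does not automatically satisfy the type parameter `M` when no module was passed in. A standalone sketch of the pattern, assuming nothing from fabricatio (the function and attribute names are made up for illustration):

```python
from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec
from types import ModuleType
from typing import Optional, cast


def inject_marker[M: ModuleType](module: Optional[M] = None) -> M:
    # module_from_spec() is typed as returning ModuleType, so cast() is what
    # convinces the checker that the freshly built fallback still matches M.
    module = module or cast(M, module_from_spec(ModuleSpec(name="scratch_module", loader=None)))
    setattr(module, "marker", True)  # stand-in for injecting tools or data
    return module


mod = inject_marker()
print(getattr(mod, "marker"))  # True
```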
@@ -1,7 +1,7 @@
  """This module contains classes that manage the usage of language models and tools in tasks."""

  from asyncio import gather
- from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
+ from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Type, Union, Unpack, overload

  import asyncstdlib
  import litellm
@@ -9,7 +9,7 @@ from fabricatio._rust_instances import template_manager
  from fabricatio.config import configs
  from fabricatio.journal import logger
  from fabricatio.models.generic import ScopedConfig, WithBriefing
- from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
+ from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
  from fabricatio.models.task import Task
  from fabricatio.models.tool import Tool, ToolBox
  from fabricatio.models.utils import Messages
@@ -20,12 +20,13 @@ from litellm.types.utils import (
      EmbeddingResponse,
      ModelResponse,
      StreamingChoices,
+     TextChoices,
  )
  from litellm.utils import CustomStreamWrapper
  from more_itertools import duplicates_everseen
  from pydantic import Field, NonNegativeInt, PositiveInt

- if configs.cache.enabled:
+ if configs.cache.enabled and configs.cache.type:
      litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
      logger.success(f"{configs.cache.type.name} Cache enabled")
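The tightened guard on cache setup matters because the success log dereferences `configs.cache.type.name`: a truthy `enabled` flag with no configured cache type would otherwise fail at import time. A small sketch of the same both-conditions pattern with a stand-in config object (all names below are illustrative):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class CacheType(Enum):
    REDIS = "redis"
    LOCAL = "local"


@dataclass
class CacheConfig:
    enabled: bool = False
    type: Optional[CacheType] = None  # None means "flag set, but no backend chosen"


def enable_cache_if_configured(cfg: CacheConfig) -> str:
    # Both conditions are required: without the second, cfg.type.name below
    # would raise AttributeError when type is None.
    if cfg.enabled and cfg.type:
        return f"{cfg.type.name} Cache enabled"
    return "Cache disabled"


print(enable_cache_if_configured(CacheConfig(enabled=True)))                        # Cache disabled
print(enable_cache_if_configured(CacheConfig(enabled=True, type=CacheType.REDIS)))  # REDIS Cache enabled
```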
 
@@ -42,7 +43,7 @@ class LLMUsage(ScopedConfig):
          messages: List[Dict[str, str]],
          n: PositiveInt | None = None,
          **kwargs: Unpack[LLMKwargs],
-     ) -> ModelResponse | CustomStreamWrapper:
+     ) -> ModelResponse:
          """Asynchronously queries the language model to generate a response based on the provided messages and parameters.

          Args:
@@ -81,7 +82,7 @@ class LLMUsage(ScopedConfig):
          system_message: str = "",
          n: PositiveInt | None = None,
          **kwargs: Unpack[LLMKwargs],
-     ) -> List[Choices | StreamingChoices]:
+     ) -> Sequence[TextChoices | Choices | StreamingChoices]:
          """Asynchronously invokes the language model with a question and optional system message.

          Args:
@@ -101,13 +102,14 @@ class LLMUsage(ScopedConfig):
          if isinstance(resp, ModelResponse):
              return resp.choices
          if isinstance(resp, CustomStreamWrapper):
-             if not configs.debug.streaming_visible:
-                 return stream_chunk_builder(await asyncstdlib.list()).choices
+             if not configs.debug.streaming_visible and (pack := stream_chunk_builder(await asyncstdlib.list())):
+                 return pack.choices
              chunks = []
              async for chunk in resp:
                  chunks.append(chunk)
                  print(chunk.choices[0].delta.content or "", end="") # noqa: T201
-             return stream_chunk_builder(chunks).choices
+             if pack := stream_chunk_builder(chunks):
+                 return pack.choices
          logger.critical(err := f"Unexpected response type: {type(resp)}")
          raise ValueError(err)
@@ -166,15 +168,15 @@ class LLMUsage(ScopedConfig):
                      for q, sm in zip(q_seq, sm_seq, strict=True)
                  ]
              )
-             return [r.pop().message.content for r in res]
+             return [r[0].message.content for r in res]
          case (list(q_seq), str(sm)):
              res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
-             return [r.pop().message.content for r in res]
+             return [r[0].message.content for r in res]
          case (str(q), list(sm_seq)):
              res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
-             return [r.pop().message.content for r in res]
+             return [r[0].message.content for r in res]
          case (str(q), str(sm)):
-             return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs)).pop()).message.content
+             return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
          case _:
              raise RuntimeError("Should not reach here.")

@@ -185,8 +187,7 @@ class LLMUsage(ScopedConfig):
          validator: Callable[[str], T | None],
          default: T,
          max_validations: PositiveInt = 2,
-         system_message: str = "",
-         **kwargs: Unpack[LLMKwargs],
+         **kwargs: Unpack[GenerateKwargs],
      ) -> T: ...
      @overload
      async def aask_validate[T](
@@ -195,19 +196,36 @@ class LLMUsage(ScopedConfig):
          validator: Callable[[str], T | None],
          default: None = None,
          max_validations: PositiveInt = 2,
-         system_message: str = "",
-         **kwargs: Unpack[LLMKwargs],
+         **kwargs: Unpack[GenerateKwargs],
      ) -> Optional[T]: ...

+     @overload
      async def aask_validate[T](
          self,
-         question: str,
+         question: List[str],
+         validator: Callable[[str], T | None],
+         default: None = None,
+         max_validations: PositiveInt = 2,
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> List[Optional[T]]: ...
+     @overload
+     async def aask_validate[T](
+         self,
+         question: List[str],
+         validator: Callable[[str], T | None],
+         default: T,
+         max_validations: PositiveInt = 2,
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> List[T]: ...
+
+     async def aask_validate[T](
+         self,
+         question: str | List[str],
          validator: Callable[[str], T | None],
          default: Optional[T] = None,
          max_validations: PositiveInt = 2,
-         system_message: str = "",
-         **kwargs: Unpack[LLMKwargs],
-     ) -> Optional[T]:
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> Optional[T] | List[Optional[T]] | List[T] | T:
          """Asynchronously asks a question and validates the response using a given validator.

          Args:
@@ -215,59 +233,42 @@ class LLMUsage(ScopedConfig):
              validator (Callable[[str], T | None]): A function to validate the response.
              default (T | None): Default value to return if validation fails. Defaults to None.
              max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
-             system_message (str): System message to include in the request. Defaults to an empty string.
              **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              T: The validated response.

          """
-         for i in range(max_validations):
-             if (
-                 response := await self.aask(
-                     question=question,
-                     system_message=system_message,
-                     **kwargs,
-                 )
-             ) and (validated := validator(response)):
-                 logger.debug(f"Successfully validated the response at {i}th attempt.")
-                 return validated
-             kwargs["no_cache"] = True
-             logger.debug("Closed the cache for the next attempt")
-         if default is None:
-             logger.error(f"Failed to validate the response after {max_validations} attempts.")
-         return default
-
-     async def aask_validate_batch[T](
-         self,
-         questions: List[str],
-         validator: Callable[[str], T | None],
-         **kwargs: Unpack[GenerateKwargs[T]],
-     ) -> List[T]:
-         """Asynchronously asks a batch of questions and validates the responses using a given validator.

-         Args:
-             questions (List[str]): The list of questions to ask.
-             validator (Callable[[str], T | None]): A function to validate the response.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             T: The validated response.
-
-         Raises:
-             ValueError: If the response fails to validate after the maximum number of attempts.
-         """
-         return await gather(*[self.aask_validate(question, validator, **kwargs) for question in questions])
+         async def _inner(q: str) -> Optional[T]:
+             for lap in range(max_validations):
+                 try:
+                     if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
+                         logger.debug(f"Successfully validated the response at {lap}th attempt.")
+                         return validated
+                 except Exception as e: # noqa: BLE001
+                     logger.error(f"Error during validation: \n{e}")
+                     break
+                 kwargs["no_cache"] = True
+                 logger.debug("Closed the cache for the next attempt")
+             if default is None:
+                 logger.error(f"Failed to validate the response after {max_validations} attempts.")
+             return default
+
+         if isinstance(question, str):
+             return await _inner(question)
+
+         return await gather(*[_inner(q) for q in question])

      async def aliststr(
-         self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[GenerateKwargs[List[str]]]
+         self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
      ) -> List[str]:
          """Asynchronously generates a list of strings based on a given requirement.

          Args:
              requirement (str): The requirement for the list of strings.
              k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              List[str]: The validated response as a list of strings.
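The rewritten `aask_validate` absorbs the old `aask_validate_batch`: an inner coroutine retries a single question against the validator, and a list input simply fans that helper out over `gather`, with the overloads preserving scalar-in/scalar-out and list-in/list-out typing. A rough, self-contained sketch of that shape, using a dummy `ask` coroutine in place of the real LLM call (every name below is hypothetical):

```python
import asyncio
from typing import Callable, List, Optional


async def ask(question: str) -> str:
    # Stand-in for the real LLM call; imagine network I/O here.
    return f"echo: {question}"


async def ask_validated[T](
    question: str | List[str],
    validator: Callable[[str], Optional[T]],
    default: Optional[T] = None,
    max_validations: int = 2,
) -> Optional[T] | List[Optional[T]]:
    """Retry-until-valid loop shaped like the refactored aask_validate."""

    async def _inner(q: str) -> Optional[T]:
        for _ in range(max_validations):
            try:
                if (response := await ask(q)) and (validated := validator(response)):
                    return validated
            except Exception:  # a validator error ends the retry loop early
                break
            # (the real method also sets no_cache=True here before retrying)
        return default

    if isinstance(question, str):
        return await _inner(question)
    # A list of questions fans out into concurrent validations, one per question.
    return list(await asyncio.gather(*[_inner(q) for q in question]))


print(asyncio.run(ask_validated(["alpha", "beta"], validator=str.upper)))
# ['ECHO: ALPHA', 'ECHO: BETA']
```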
@@ -299,12 +300,12 @@ class LLMUsage(ScopedConfig):
              **kwargs,
          )

-     async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[GenerateKwargs[List[str]]]) -> str:
+     async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[ValidateKwargs[List[str]]]) -> str:
          """Asynchronously generates a single path string based on a given requirement.

          Args:
              requirement (str): The requirement for the list of strings.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              str: The validated response as a single string.
@@ -322,7 +323,7 @@ class LLMUsage(ScopedConfig):
          instruction: str,
          choices: List[T],
          k: NonNegativeInt = 0,
-         **kwargs: Unpack[GenerateKwargs[List[T]]],
+         **kwargs: Unpack[ValidateKwargs[List[T]]],
      ) -> List[T]:
          """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.

@@ -330,7 +331,7 @@ class LLMUsage(ScopedConfig):
              instruction (str): The user-provided instruction/question description.
              choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
              k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              List[T]: The final validated selection result list, with element types matching the input `choices`.
@@ -373,14 +374,14 @@ class LLMUsage(ScopedConfig):
          self,
          instruction: str,
          choices: List[T],
-         **kwargs: Unpack[GenerateKwargs[List[T]]],
+         **kwargs: Unpack[ValidateKwargs[List[T]]],
      ) -> T:
          """Asynchronously picks a single choice from a list of options using AI validation.

          Args:
              instruction (str): The user-provided instruction/question description.
              choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              T: The single selected item from the choices list.
@@ -402,7 +403,7 @@ class LLMUsage(ScopedConfig):
          prompt: str,
          affirm_case: str = "",
          deny_case: str = "",
-         **kwargs: Unpack[GenerateKwargs[bool]],
+         **kwargs: Unpack[ValidateKwargs[bool]],
      ) -> bool:
          """Asynchronously judges a prompt using AI validation.

@@ -410,7 +411,7 @@ class LLMUsage(ScopedConfig):
              prompt (str): The input prompt to be judged.
              affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
              deny_case (str): The negative case for the AI model. Defaults to an empty string.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

          Returns:
              bool: The judgment result (True or False) based on the AI's response.
@@ -516,7 +517,6 @@ class ToolBoxUsage(LLMUsage):
      async def choose_toolboxes(
          self,
          task: Task,
-         system_message: str = "",
          **kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
      ) -> List[ToolBox]:
          """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
@@ -535,7 +535,6 @@ class ToolBoxUsage(LLMUsage):
          return await self.achoose(
              instruction=task.briefing,
              choices=list(self.toolboxes),
-             system_message=system_message,
              **kwargs,
          )
fabricatio/parser.py CHANGED
@@ -1,12 +1,14 @@
  """A module to parse text using regular expressions."""

- from typing import Any, Callable, Optional, Self, Tuple, Type
+ from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type

  import orjson
  import regex
+ from json_repair import repair_json
  from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
  from regex import Pattern, compile

+ from fabricatio.config import configs
  from fabricatio.journal import logger


@@ -25,12 +27,31 @@ class Capture(BaseModel):
      """The regular expression pattern to search for."""
      flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
      """The flags to use when compiling the regular expression pattern."""
+     capture_type: Optional[str] = None
+     """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
      _compiled: Pattern = PrivateAttr()

      def model_post_init(self, __context: Any) -> None:
          """Initialize the compiled pattern."""
          self._compiled = compile(self.pattern, self.flags)

+     def fix[T](self, text: str | Iterable[str]|T) -> str | List[str]|T:
+         """Fix the text using the pattern.
+
+         Args:
+             text (str | List[str]): The text to fix.
+
+         Returns:
+             str | List[str]: The fixed text with the same type as input.
+         """
+         match self.capture_type:
+             case "json":
+                 if isinstance(text, str):
+                     return repair_json(text,ensure_ascii=False)
+                 return [repair_json(item) for item in text]
+             case _:
+                 return text
+
      def capture(self, text: str) -> Tuple[str, ...] | str | None:
          """Capture the first occurrence of the pattern in the given text.
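`Capture.fix` routes JSON captures through `json_repair.repair_json`, which rewrites almost-JSON (trailing commas, single quotes, unclosed braces) into something `loads` will accept. A brief sketch of that call on its own, assuming only that the `json-repair` package from the new dependency list is installed:

```python
import json

from json_repair import repair_json

# Typical LLM output: single quotes, a trailing comma, and a missing closing brace.
broken = "{'answer': 42, 'unit': 'meters',"

repaired = repair_json(broken, ensure_ascii=False)  # returns a repaired JSON string
print(repaired)
print(json.loads(repaired)["answer"])  # 42
```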
 
@@ -44,12 +65,12 @@ class Capture(BaseModel):
          match = self._compiled.search(text)
          if match is None:
              return None
-
+         groups = self.fix(match.groups()) if configs.general.use_json_repair else match.groups()
          if self.target_groups:
-             cap = tuple(match.group(g) for g in self.target_groups)
+             cap = tuple(groups[g - 1] for g in self.target_groups)
              logger.debug(f"Captured text: {'\n\n'.join(cap)}")
              return cap
-         cap = match.group(1)
+         cap = groups[0]
          logger.debug(f"Captured text: \n{cap}")
          return cap

@@ -111,7 +132,7 @@ class Capture(BaseModel):
          Returns:
              Self: The instance of the class with the captured code block.
          """
-         return cls(pattern=f"```{language}\n(.*?)\n```")
+         return cls(pattern=f"```{language}\n(.*?)\n```", capture_type=language)


  JsonCapture = Capture.capture_code_block("json")
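`capture_code_block` now also records the language as `capture_type`, which is what lets `JsonCapture` run its matches through the JSON fixer when `configs.general.use_json_repair` is enabled. The pattern itself is ordinary regex; a standalone sketch of what `JsonCapture` extracts from a model reply, using the stdlib `re` module rather than fabricatio's `Capture` (variable names are illustrative):

```python
import json
import re

# Same shape of pattern that capture_code_block builds for a given language.
language = "json"
pattern = re.compile(rf"```{language}\n(.*?)\n```", re.DOTALL | re.MULTILINE | re.IGNORECASE)

reply = 'Here is the result:\n```json\n{"answer": 42}\n```\nThanks!'
match = pattern.search(reply)
if match:
    payload = json.loads(match.group(1))  # the captured group is just the JSON body
    print(payload["answer"])  # 42
```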
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fabricatio
- Version: 0.2.5.dev4
+ Version: 0.2.5.dev5
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Rust
  Classifier: Programming Language :: Python :: 3.12
@@ -11,6 +11,7 @@ Classifier: Typing :: Typed
  Requires-Dist: appdirs>=1.4.4
  Requires-Dist: asyncio>=3.4.3
  Requires-Dist: asyncstdlib>=3.13.0
+ Requires-Dist: json-repair>=0.39.1
  Requires-Dist: litellm>=1.60.0
  Requires-Dist: loguru>=0.7.3
  Requires-Dist: magika>=0.5.1