fabricatio 0.2.13.dev3__cp312-cp312-win_amd64.whl → 0.3.14__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fabricatio/__init__.py +7 -14
  2. fabricatio/actions/article.py +58 -23
  3. fabricatio/actions/article_rag.py +6 -15
  4. fabricatio/actions/output.py +38 -3
  5. fabricatio/actions/rag.py +4 -4
  6. fabricatio/capabilities/advanced_judge.py +4 -7
  7. fabricatio/capabilities/advanced_rag.py +2 -1
  8. fabricatio/capabilities/censor.py +5 -4
  9. fabricatio/capabilities/check.py +6 -7
  10. fabricatio/capabilities/correct.py +5 -5
  11. fabricatio/capabilities/extract.py +7 -3
  12. fabricatio/capabilities/persist.py +103 -0
  13. fabricatio/capabilities/propose.py +2 -2
  14. fabricatio/capabilities/rag.py +43 -43
  15. fabricatio/capabilities/rating.py +11 -10
  16. fabricatio/capabilities/review.py +8 -6
  17. fabricatio/capabilities/task.py +22 -22
  18. fabricatio/decorators.py +4 -2
  19. fabricatio/{core.py → emitter.py} +35 -39
  20. fabricatio/fs/__init__.py +1 -2
  21. fabricatio/journal.py +2 -11
  22. fabricatio/models/action.py +14 -30
  23. fabricatio/models/extra/aricle_rag.py +14 -8
  24. fabricatio/models/extra/article_base.py +56 -25
  25. fabricatio/models/extra/article_essence.py +2 -1
  26. fabricatio/models/extra/article_main.py +16 -13
  27. fabricatio/models/extra/article_outline.py +2 -1
  28. fabricatio/models/extra/article_proposal.py +1 -1
  29. fabricatio/models/extra/rag.py +2 -2
  30. fabricatio/models/extra/rule.py +2 -1
  31. fabricatio/models/generic.py +56 -166
  32. fabricatio/models/kwargs_types.py +1 -54
  33. fabricatio/models/role.py +49 -26
  34. fabricatio/models/task.py +8 -9
  35. fabricatio/models/tool.py +7 -7
  36. fabricatio/models/usages.py +67 -61
  37. fabricatio/parser.py +60 -100
  38. fabricatio/rust.cp312-win_amd64.pyd +0 -0
  39. fabricatio/rust.pyi +469 -74
  40. fabricatio/utils.py +63 -162
  41. fabricatio-0.3.14.data/scripts/tdown.exe +0 -0
  42. fabricatio-0.3.14.data/scripts/ttm.exe +0 -0
  43. {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/METADATA +10 -15
  44. fabricatio-0.3.14.dist-info/RECORD +64 -0
  45. {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/WHEEL +1 -1
  46. fabricatio/config.py +0 -430
  47. fabricatio/constants.py +0 -20
  48. fabricatio/models/events.py +0 -120
  49. fabricatio/rust_instances.py +0 -10
  50. fabricatio-0.2.13.dev3.data/scripts/tdown.exe +0 -0
  51. fabricatio-0.2.13.dev3.data/scripts/ttm.exe +0 -0
  52. fabricatio-0.2.13.dev3.dist-info/RECORD +0 -67
  53. {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/licenses/LICENSE +0 -0
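
The headline change in 0.3.14 is that the pure-Python configuration and event plumbing (fabricatio/config.py, fabricatio/constants.py, fabricatio/core.py, fabricatio/models/events.py, fabricatio/rust_instances.py) is removed and its surface re-exported from the compiled fabricatio.rust extension plus the new fabricatio.emitter module, as the per-file diffs below show. A minimal import-migration sketch, inferred from those diffs (the exact set of re-exported symbols beyond what the diffs show is an assumption):

# New-style imports used throughout 0.3.14 (old locations shown in comments).
from fabricatio.emitter import env                  # was: from fabricatio.core import env
from fabricatio.rust import (
    CONFIG,              # was: from fabricatio.config import configs
    TEMPLATE_MANAGER,    # was: from fabricatio.rust_instances import TEMPLATE_MANAGER
    Event,               # was: from fabricatio.models.events import Event
    TaskStatus,          # was: from fabricatio.constants import TaskStatus
)

# Settings move from the `configs` object to the CONFIG singleton, e.g.:
briefing_template = CONFIG.templates.task_briefing_template  # was: configs.templates.task_briefing_template
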
fabricatio/models/task.py CHANGED
@@ -4,17 +4,16 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
 """
 
 from asyncio import Queue
-from typing import Any, Dict, List, Optional, Self
+from typing import Any, Dict, List, Optional, Self, Union
 
-from fabricatio.config import configs
-from fabricatio.constants import TaskStatus
-from fabricatio.core import env
+from fabricatio.emitter import env
 from fabricatio.journal import logger
-from fabricatio.models.events import Event, EventLike
 from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
-from fabricatio.rust_instances import TEMPLATE_MANAGER
+from fabricatio.rust import CONFIG, TEMPLATE_MANAGER, Event, TaskStatus
 from pydantic import Field, PrivateAttr
 
+type EventLike = Union[str, Event, List[str]]
+
 
 class Task[T](WithBriefing, ProposedAble, WithDependency):
     """A class representing a task with a status and output.
@@ -33,7 +32,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
     description: str = Field(default="")
     """A detailed explanation of the task that includes all necessary information. Should be clear and answer what, why, when, where, who, and how questions."""
 
-    goals: List[str] = Field(default=[])
+    goals: List[str] = Field(default_factory=list)
     """A list of objectives that the task aims to accomplish. Each goal should be clear and specific. Complex tasks should be broken into multiple smaller goals."""
 
     namespace: List[str] = Field(default_factory=list)
@@ -65,7 +64,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
 
     def model_post_init(self, __context: Any) -> None:
         """Initialize the task with a namespace event."""
-        self._namespace.segments.extend(self.namespace)
+        self._namespace.concat(self.namespace)
 
     def move_to(self, new_namespace: EventLike) -> Self:
         """Move the task to a new namespace.
@@ -266,7 +265,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
             str: The briefing of the task.
         """
         return TEMPLATE_MANAGER.render_template(
-            configs.templates.task_briefing_template,
+            CONFIG.templates.task_briefing_template,
             self.model_dump(),
         )
 
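Event and TaskStatus now come from the Rust extension, and EventLike becomes a local type alias instead of an import from fabricatio.models.events. A small sketch of code written against the new alias (the helper function and the dotted-namespace convention are illustrative assumptions, not part of the diff):

from typing import List, Union

from fabricatio.rust import Event

# Mirrors the module-level alias added to fabricatio/models/task.py.
EventLike = Union[str, Event, List[str]]


def describe_namespace(ns: EventLike) -> str:
    """Illustrative helper: report which EventLike form the caller passed."""
    if isinstance(ns, Event):
        return "Event instance"
    if isinstance(ns, str):
        return "string namespace"          # e.g. "task.review.pending" (assumed convention)
    return f"segment list with {len(ns)} parts"
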
fabricatio/models/tool.py CHANGED
@@ -10,11 +10,11 @@ from inspect import iscoroutinefunction, signature
 from types import CodeType, ModuleType
 from typing import Any, Callable, Dict, List, Optional, Self, cast, overload
 
-from fabricatio.config import configs
 from fabricatio.decorators import logging_execution_info, use_temp_module
 from fabricatio.journal import logger
-from fabricatio.models.generic import WithBriefing
-from pydantic import BaseModel, ConfigDict, Field
+from fabricatio.models.generic import Base, WithBriefing
+from fabricatio.rust import CONFIG
+from pydantic import Field
 
 
 class Tool[**P, R](WithBriefing):
@@ -181,7 +181,7 @@ class ToolBox(WithBriefing):
         return hash(self.briefing)
 
 
-class ToolExecutor(BaseModel):
+class ToolExecutor(Base):
     """A class representing a tool executor with a sequence of tools to execute.
 
     This class manages a sequence of tools and provides methods to inject tools and data into a module, execute the tools,
@@ -191,7 +191,7 @@ class ToolExecutor(BaseModel):
         candidates (List[Tool]): The sequence of tools to execute.
         data (Dict[str, Any]): The data that could be used when invoking the tools.
     """
-    model_config = ConfigDict(use_attribute_docstrings=True)
+
     candidates: List[Tool] = Field(default_factory=list, frozen=True)
     """The sequence of tools to execute."""
 
@@ -210,7 +210,7 @@ class ToolExecutor(BaseModel):
            M: The module with injected tools.
        """
        module = module or cast(
-            "M", module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
+            "M", module_from_spec(spec=ModuleSpec(name=CONFIG.toolbox.tool_module_name, loader=None))
        )
        for tool in self.candidates:
            logger.debug(f"Injecting tool: {tool.name}")
@@ -229,7 +229,7 @@ class ToolExecutor(BaseModel):
            M: The module with injected data.
        """
        module = module or cast(
-            'M', module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None))
+            "M", module_from_spec(spec=ModuleSpec(name=CONFIG.toolbox.data_module_name, loader=None))
        )
        for key, value in self.data.items():
            logger.debug(f"Injecting data: {key}")
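ToolExecutor keeps the same injection technique and only swaps configs.toolbox.* for CONFIG.toolbox.*. For readers unfamiliar with that technique, this is the underlying standard-library idiom of hosting injected callables in a throwaway in-memory module (the module and function names here are placeholders, not fabricatio's actual values):

from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec


def greet(name: str) -> str:
    """Stand-in for a tool's callable."""
    return f"hello {name}"


# Build a loader-less, in-memory module and attach attributes to it, the same
# ModuleSpec/module_from_spec pattern used in the hunks above.
sandbox = module_from_spec(ModuleSpec(name="tool_sandbox", loader=None))
setattr(sandbox, greet.__name__, greet)

print(sandbox.greet("fabricatio"))  # -> hello fabricatio
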
fabricatio/models/usages.py CHANGED
@@ -1,22 +1,25 @@
 """This module contains classes that manage the usage of language models and tools in tasks."""
 
 import traceback
+from abc import ABC
 from asyncio import gather
 from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload
 
 import asyncstdlib
-import litellm
-from fabricatio.config import configs
 from fabricatio.decorators import logging_exec_time
 from fabricatio.journal import logger
 from fabricatio.models.generic import ScopedConfig, WithBriefing
 from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
-from fabricatio.parser import GenericCapture, JsonCapture
-from fabricatio.rust_instances import TEMPLATE_MANAGER
-from fabricatio.utils import ok
-from litellm import RateLimitError, Router, stream_chunk_builder  # pyright: ignore [reportPrivateImportUsage]
+from fabricatio.rust import CONFIG, TEMPLATE_MANAGER
+from fabricatio.utils import first_available, ok
+from litellm import (  # pyright: ignore [reportPrivateImportUsage]
+    RateLimitError,
+    Router,
+    aembedding,
+    stream_chunk_builder,
+)
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
 from litellm.types.utils import (
     Choices,
@@ -29,22 +32,16 @@ from litellm.utils import CustomStreamWrapper, token_counter # pyright: ignore
 from more_itertools import duplicates_everseen
 from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
 
-if configs.cache.enabled and configs.cache.type:
-    litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
-    logger.debug(f"{configs.cache.type.name} Cache enabled")
-
 ROUTER = Router(
     routing_strategy="usage-based-routing-v2",
-    default_max_parallel_requests=configs.routing.max_parallel_requests,
-    allowed_fails=configs.routing.allowed_fails,
-    retry_after=configs.routing.retry_after,
-    cooldown_time=configs.routing.cooldown_time,
-    cache_responses=configs.cache.enabled,
-    cache_kwargs=configs.cache.params,
+    default_max_parallel_requests=CONFIG.routing.max_parallel_requests,
+    allowed_fails=CONFIG.routing.allowed_fails,
+    retry_after=CONFIG.routing.retry_after,
+    cooldown_time=CONFIG.routing.cooldown_time,
 )
 
 
-class LLMUsage(ScopedConfig):
+class LLMUsage(ScopedConfig, ABC):
     """Class that manages LLM (Large Language Model) usage parameters and methods.
 
     This class provides methods to deploy LLMs, query them for responses, and handle various configurations
@@ -86,48 +83,48 @@ class LLMUsage(ScopedConfig):
             Deployment(
                 model_name=(
                     m_name := ok(
-                        kwargs.get("model") or self.llm_model or configs.llm.model, "model name is not set at any place"
+                        kwargs.get("model") or self.llm_model or CONFIG.llm.model, "model name is not set at any place"
                     )
                 ),  # pyright: ignore [reportCallIssue]
                 litellm_params=(
                     p := LiteLLM_Params(
                         api_key=ok(
-                            self.llm_api_key or configs.llm.api_key, "llm api key is not set at any place"
+                            self.llm_api_key or CONFIG.llm.api_key, "llm api key is not set at any place"
                        ).get_secret_value(),
                        api_base=ok(
-                            self.llm_api_endpoint or configs.llm.api_endpoint,
+                            self.llm_api_endpoint or CONFIG.llm.api_endpoint,
                            "llm api endpoint is not set at any place",
-                        ).unicode_string(),
+                        ),
                        model=m_name,
-                        tpm=self.llm_tpm or configs.llm.tpm,
-                        rpm=self.llm_rpm or configs.llm.rpm,
-                        max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
-                        timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
+                        tpm=self.llm_tpm or CONFIG.llm.tpm,
+                        rpm=self.llm_rpm or CONFIG.llm.rpm,
+                        max_retries=kwargs.get("max_retries") or self.llm_max_retries or CONFIG.llm.max_retries,
+                        timeout=kwargs.get("timeout") or self.llm_timeout or CONFIG.llm.timeout,
                    )
                ),
                model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
            )
        ).acompletion(
            messages=messages,  # pyright: ignore [reportArgumentType]
-            n=n or self.llm_generation_count or configs.llm.generation_count,
+            n=n or self.llm_generation_count or CONFIG.llm.generation_count,
            model=m_name,
-            temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
-            stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
-            top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
-            max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
-            stream=ok(kwargs.get("stream") or self.llm_stream or configs.llm.stream, "stream is not set at any place"),
+            temperature=kwargs.get("temperature") or self.llm_temperature or CONFIG.llm.temperature,
+            stop=kwargs.get("stop") or self.llm_stop_sign or CONFIG.llm.stop_sign,
+            top_p=kwargs.get("top_p") or self.llm_top_p or CONFIG.llm.top_p,
+            max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or CONFIG.llm.max_tokens,
+            stream=first_available(
+                (kwargs.get("stream"), self.llm_stream, CONFIG.llm.stream), "stream is not set at any place"
+            ),
            cache={
                "no-cache": kwargs.get("no_cache"),
                "no-store": kwargs.get("no_store"),
                "cache-ttl": kwargs.get("cache_ttl"),
                "s-maxage": kwargs.get("s_maxage"),
            },
-            presence_penalty=kwargs.get("presence_penalty")
-            or self.llm_presence_penalty
-            or configs.llm.presence_penalty,
+            presence_penalty=kwargs.get("presence_penalty") or self.llm_presence_penalty or CONFIG.llm.presence_penalty,
            frequency_penalty=kwargs.get("frequency_penalty")
            or self.llm_frequency_penalty
-            or configs.llm.frequency_penalty,
+            or CONFIG.llm.frequency_penalty,
        )
 
     async def ainvoke(
@@ -155,11 +152,8 @@ class LLMUsage(ScopedConfig):
         )
         if isinstance(resp, ModelResponse):
             return resp.choices
-        if isinstance(resp, CustomStreamWrapper):
-            if not configs.debug.streaming_visible and (pack := stream_chunk_builder(await asyncstdlib.list())):
-                return pack.choices
-            if pack := stream_chunk_builder(await asyncstdlib.list(resp)):
-                return pack.choices
+        if isinstance(resp, CustomStreamWrapper) and (pack := stream_chunk_builder(await asyncstdlib.list(resp))):
+            return pack.choices
         logger.critical(err := f"Unexpected response type: {type(resp)}")
         raise ValueError(err)
 
@@ -170,6 +164,7 @@ class LLMUsage(ScopedConfig):
         system_message: List[str],
         **kwargs: Unpack[LLMKwargs],
     ) -> List[str]: ...
+
     @overload
     async def aask(
         self,
@@ -177,6 +172,7 @@ class LLMUsage(ScopedConfig):
         system_message: List[str],
         **kwargs: Unpack[LLMKwargs],
     ) -> List[str]: ...
+
     @overload
     async def aask(
         self,
@@ -231,7 +227,8 @@ class LLMUsage(ScopedConfig):
             raise RuntimeError("Should not reach here.")
 
         logger.debug(
-            f"Response Token Count: {token_counter(text=out) if isinstance(out, str) else sum(token_counter(text=o) for o in out)}"  # pyright: ignore [reportOptionalIterable]
+            f"Response Token Count: {token_counter(text=out) if isinstance(out, str) else sum(token_counter(text=o) for o in out)}"
+            # pyright: ignore [reportOptionalIterable]
        )
        return out  # pyright: ignore [reportReturnType]
 
@@ -244,6 +241,7 @@ class LLMUsage(ScopedConfig):
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
     ) -> T: ...
+
     @overload
     async def aask_validate[T](
         self,
@@ -253,6 +251,7 @@ class LLMUsage(ScopedConfig):
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
     ) -> List[T]: ...
+
     @overload
     async def aask_validate[T](
         self,
@@ -331,9 +330,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[List[str]]: The validated response as a list of strings.
         """
+        from fabricatio.parser import JsonCapture
+
         return await self.aask_validate(
             TEMPLATE_MANAGER.render_template(
-                configs.templates.liststr_template,
+                CONFIG.templates.liststr_template,
                 {"requirement": requirement, "k": k},
             ),
             lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
@@ -352,7 +353,7 @@ class LLMUsage(ScopedConfig):
         """
         return await self.alist_str(
             TEMPLATE_MANAGER.render_template(
-                configs.templates.pathstr_template,
+                CONFIG.templates.pathstr_template,
                 {"requirement": requirement},
             ),
             **kwargs,
@@ -387,9 +388,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[str]: The generated string.
         """
+        from fabricatio.parser import GenericCapture
+
         return await self.aask_validate(  # pyright: ignore [reportReturnType]
             TEMPLATE_MANAGER.render_template(
-                configs.templates.generic_string_template,
+                CONFIG.templates.generic_string_template,
                 {"requirement": requirement, "language": GenericCapture.capture_type},
             ),
             validator=lambda resp: GenericCapture.capture(resp),
@@ -414,11 +417,13 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[List[T]]: The final validated selection result list, with element types matching the input `choices`.
         """
+        from fabricatio.parser import JsonCapture
+
         if dup := duplicates_everseen(choices, key=lambda x: x.name):
             logger.error(err := f"Redundant choices: {dup}")
             raise ValueError(err)
         prompt = TEMPLATE_MANAGER.render_template(
-            configs.templates.make_choice_template,
+            CONFIG.templates.make_choice_template,
             {
                 "instruction": instruction,
                 "options": [m.model_dump(include={"name", "briefing"}) for m in choices],
@@ -489,9 +494,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             bool: The judgment result (True or False) based on the AI's response.
         """
+        from fabricatio.parser import JsonCapture
+
         return await self.aask_validate(
             question=TEMPLATE_MANAGER.render_template(
-                configs.templates.make_judgment_template,
+                CONFIG.templates.make_judgment_template,
                 {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
             ),
             validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
@@ -499,7 +506,7 @@ class LLMUsage(ScopedConfig):
         )
 
 
-class EmbeddingUsage(LLMUsage):
+class EmbeddingUsage(LLMUsage, ABC):
     """A class representing the embedding model.
 
     This class extends LLMUsage and provides methods to generate embeddings for input text using various models.
@@ -526,37 +533,36 @@ class EmbeddingUsage(LLMUsage):
             EmbeddingResponse: The response containing the embeddings.
         """
         # check seq length
-        max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
+        max_len = self.embedding_max_sequence_length or CONFIG.embedding.max_sequence_length
         if max_len and any(length := (token_counter(text=t)) > max_len for t in input_text):
             logger.error(err := f"Input text exceeds maximum sequence length 62,976, got {length}.")
             raise ValueError(err)
 
-        return await litellm.aembedding(
+        return await aembedding(
            input=input_text,
-            caching=caching or self.embedding_caching or configs.embedding.caching,
-            dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
-            model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
+            caching=caching or self.embedding_caching or CONFIG.embedding.caching,
+            dimensions=dimensions or self.embedding_dimensions or CONFIG.embedding.dimensions,
+            model=model or self.embedding_model or CONFIG.embedding.model or self.llm_model or CONFIG.llm.model,
            timeout=timeout
            or self.embedding_timeout
-            or configs.embedding.timeout
+            or CONFIG.embedding.timeout
            or self.llm_timeout
-            or configs.llm.timeout,
+            or CONFIG.llm.timeout,
            api_key=ok(
-                self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
+                self.embedding_api_key or CONFIG.embedding.api_key or self.llm_api_key or CONFIG.llm.api_key
            ).get_secret_value(),
            api_base=ok(
                self.embedding_api_endpoint
-                or configs.embedding.api_endpoint
+                or CONFIG.embedding.api_endpoint
                or self.llm_api_endpoint
-                or configs.llm.api_endpoint
-            )
-            .unicode_string()
-            .rstrip("/"),
+                or CONFIG.llm.api_endpoint
+            ).rstrip("/"),
            # seems embedding function takes no base_url end with a slash
        )
 
     @overload
     async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
+
     @overload
     async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
 
@@ -578,7 +584,7 @@ class EmbeddingUsage(LLMUsage):
         return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
 
 
-class ToolBoxUsage(LLMUsage):
+class ToolBoxUsage(LLMUsage, ABC):
     """A class representing the usage of tools in a task.
 
     This class extends LLMUsage and provides methods to manage and use toolboxes and tools within tasks.
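
The stream argument now goes through a new first_available helper from fabricatio.utils rather than ok(a or b or c), which would have misread an explicit stream=False as unset. The helper's implementation is not part of this diff; a plausible sketch of the semantics the call site needs (first value that is not None, otherwise raise with the given message) is:

from typing import Iterable, Optional, TypeVar

T = TypeVar("T")


def first_available(values: Iterable[Optional[T]], msg: str = "no value available") -> T:
    """Sketch only: return the first item that is not None.

    Unlike `a or b or c`, falsy-but-set values such as False or 0 are kept,
    which is what a tri-state `stream` flag needs.
    """
    for value in values:
        if value is not None:
            return value
    raise ValueError(msg)


# first_available((None, False, True), "stream is not set at any place") -> False
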
fabricatio/parser.py CHANGED
@@ -1,152 +1,112 @@
-"""A module to parse text using regular expressions."""
+"""A module for capturing patterns in text using regular expressions."""
 
 import re
+from dataclasses import dataclass, field
 from functools import lru_cache
-from re import Pattern, compile
-from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
+from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type, Union
 
 import ujson
 from json_repair import repair_json
-from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
 
-from fabricatio.config import configs
 from fabricatio.journal import logger
+from fabricatio.rust import CONFIG
 
 
-class Capture(BaseModel):
+@dataclass(frozen=True)
+class Capture:
     """A class to capture patterns in text using regular expressions.
 
     Attributes:
-        pattern (str): The regular expression pattern to search for.
-        _compiled (Pattern): The compiled regular expression pattern.
+        target_groups (Tuple[int, ...]): The target groups to extract from the match.
+        pattern (str): The regex pattern to search for.
+        flags (int): Flags to apply when compiling the regex.
+        capture_type (Optional[str]): Optional hint for post-processing (e.g., 'json').
     """
 
-    model_config = ConfigDict(use_attribute_docstrings=True)
-    target_groups: Tuple[int, ...] = Field(default_factory=tuple)
-    """The target groups to capture from the pattern."""
-    pattern: str = Field(frozen=True)
+    pattern: str = field()
     """The regular expression pattern to search for."""
-    flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
-    """The flags to use when compiling the regular expression pattern."""
+    flags: int = re.DOTALL | re.MULTILINE | re.IGNORECASE
+    """Flags to control regex behavior (DOTALL, MULTILINE, IGNORECASE by default)."""
     capture_type: Optional[str] = None
-    """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
-    _compiled: Pattern = PrivateAttr()
+    """Optional type identifier for post-processing (e.g., 'json' for JSON repair)."""
+    target_groups: Tuple[int, ...] = field(default_factory=tuple)
+    """Tuple of group indices to extract from the match (1-based indexing)."""
 
-    def model_post_init(self, __context: Any) -> None:
-        """Initialize the compiled pattern."""
-        self._compiled = compile(self.pattern, self.flags)
-
-    def fix[T](self, text: str | Iterable[str] | T) -> str | List[str] | T:
-        """Fix the text using the pattern.
-
-        Args:
-            text (str | List[str]): The text to fix.
-
-        Returns:
-            str | List[str]: The fixed text with the same type as input.
-        """
+    def fix(self, text: Union[str, Iterable[str], Any]) -> Union[str, List[str], Any]:
+        """Fix the text based on capture_type (e.g., JSON repair)."""
         match self.capture_type:
-            case "json" if configs.general.use_json_repair:
-                logger.debug("Applying json repair to text.")
+            case "json" if CONFIG.general.use_json_repair:
+                logger.debug("Applying JSON repair to text.")
                 if isinstance(text, str):
-                    return repair_json(text, ensure_ascii=False)  # pyright: ignore [reportReturnType]
-                return [repair_json(item, ensure_ascii=False) for item in
-                        text]  # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
+                    return repair_json(text, ensure_ascii=False)
+                return [repair_json(item, ensure_ascii=False) for item in text]
             case _:
-                return text  # pyright: ignore [reportReturnType]
-
-    def capture(self, text: str) -> Tuple[str, ...] | str | None:
-        """Capture the first occurrence of the pattern in the given text.
-
-        Args:
-            text (str): The text to search the pattern in.
-
-        Returns:
-            str | None: The captured text if the pattern is found, otherwise None.
-
-        """
-        if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
-            logger.debug(f"Capture Failed {type(text)}: \n{text}")
+                return text
+
+    def capture(self, text: str) -> Optional[Union[str, Tuple[str, ...]]]:
+        """Capture the first match of the pattern in the text."""
+        compiled = re.compile(self.pattern, self.flags)
+        match = compiled.match(text) or compiled.search(text)
+        if match is None:
+            logger.debug(f"Capture Failed: {text}")
             return None
+
         groups = self.fix(match.groups())
         if self.target_groups:
             cap = tuple(groups[g - 1] for g in self.target_groups)
-            logger.debug(f"Captured text: {'\n\n'.join(cap)}")
+            logger.debug(f"Captured texts: {'\n==\n'.join(cap)}")
             return cap
         cap = groups[0]
         logger.debug(f"Captured text: \n{cap}")
         return cap
 
-    def convert_with[T](self, text: str, convertor: Callable[[Tuple[str, ...]], T] | Callable[[str], T]) -> T | None:
-        """Convert the given text using the pattern.
-
-        Args:
-            text (str): The text to search the pattern in.
-            convertor (Callable[[Tuple[str, ...]], T] | Callable[[str], T]): The function to convert the captured text.
-
-        Returns:
-            str | None: The converted text if the pattern is found, otherwise None.
-        """
+    def convert_with(
+        self,
+        text: str,
+        convertor: Callable[[Union[str, Tuple[str, ...]]], Any],
+    ) -> Optional[Any]:
+        """Convert captured text using a provided function."""
         if (cap := self.capture(text)) is None:
             return None
         try:
-            return convertor(cap)  # pyright: ignore [reportArgumentType]
-        except (ValueError, SyntaxError, ValidationError) as e:
-            logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
+            return convertor(cap)
+        except Exception as e:  # noqa: BLE001
+            logger.error(f"Failed to convert text using {convertor.__name__}: {e}\n{cap}")
             return None
 
-    def validate_with[K, T, E](
-        self,
-        text: str,
-        target_type: Type[T],
-        elements_type: Optional[Type[E]] = None,
-        length: Optional[int] = None,
-        deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
-    ) -> T | None:
-        """Validate the given text using the pattern.
-
-        Args:
-            text (str): The text to search the pattern in.
-            target_type (Type[T]): The expected type of the output, dict or list.
-            elements_type (Optional[Type[E]]): The expected type of the elements in the output dict keys or list elements.
-            length (Optional[int]): The expected length of the output, bool(length)==False means no length validation.
-            deserializer (Callable[[Tuple[str, ...]], K] | Callable[[str], K]): The function to deserialize the captured text.
-
-        Returns:
-            T | None: The validated text if the pattern is found and the output is of the expected type, otherwise None.
-        """
-        judges = [lambda output_obj: isinstance(output_obj, target_type)]
+    def validate_with[T, K, E](
+        self,
+        text: str,
+        target_type: Type[T],
+        elements_type: Optional[Type[E]] = None,
+        length: Optional[int] = None,
+        deserializer: Callable[[Union[str, Tuple[str, ...]]], K] = lambda x: ujson.loads(x) if isinstance(x, str) else ujson.loads(x[0]),
+    ) -> Optional[T]:
+        """Deserialize and validate the captured text against expected types."""
+        judges = [lambda obj: isinstance(obj, target_type)]
         if elements_type:
-            judges.append(lambda output_obj: all(isinstance(e, elements_type) for e in output_obj))
+            judges.append(lambda obj: all(isinstance(e, elements_type) for e in obj))
         if length:
-            judges.append(lambda output_obj: len(output_obj) == length)
+            judges.append(lambda obj: len(obj) == length)
 
         if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
-            return out  # pyright: ignore [reportReturnType]
+            return out  # type: ignore
         return None
 
     @classmethod
     @lru_cache(32)
     def capture_code_block(cls, language: str) -> Self:
-        """Capture the first occurrence of a code block in the given text.
-
-        Args:
-            language (str): The text containing the code block.
-
-        Returns:
-            Self: The instance of the class with the captured code block.
-        """
+        """Capture a code block of the given language."""
         return cls(pattern=f"```{language}(.*?)```", capture_type=language)
 
     @classmethod
     @lru_cache(32)
     def capture_generic_block(cls, language: str) -> Self:
-        """Capture the first occurrence of a generic code block in the given text.
-
-        Returns:
-            Self: The instance of the class with the captured code block.
-        """
-        return cls(pattern=f"--- Start of {language} ---(.*?)--- end of {language} ---", capture_type=language)
+        """Capture a generic block of the given language."""
+        return cls(
+            pattern=f"--- Start of {language} ---(.*?)--- End of {language} ---",
+            capture_type=language,
+        )
 
 
 JsonCapture = Capture.capture_code_block("json")
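
Capture is now a frozen dataclass that compiles its pattern on demand, but the call side is unchanged. A short usage sketch against the API shown above (the sample reply text is illustrative):

from fabricatio.parser import Capture, JsonCapture

reply = """Sure, here is the list:
```json
["alpha", "beta", "gamma"]
```
"""

# Deserializes the fenced JSON and checks target type, element type and length;
# returns None if the block is missing or fails validation.
items = JsonCapture.validate_with(reply, target_type=list, elements_type=str, length=3)
print(items)  # ['alpha', 'beta', 'gamma'] on success

# Captures for other languages reuse the lru_cache-backed factory.
PythonCapture = Capture.capture_code_block("python")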