fabricatio 0.2.5.dev5__cp312-cp312-win_amd64.whl → 0.2.6.dev1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/config.py CHANGED
@@ -19,8 +19,6 @@ from pydantic import (
19
19
  )
20
20
  from pydantic_settings import (
21
21
  BaseSettings,
22
- DotEnvSettingsSource,
23
- EnvSettingsSource,
24
22
  PydanticBaseSettingsSource,
25
23
  PyprojectTomlConfigSettingsSource,
26
24
  SettingsConfigDict,
@@ -68,7 +66,7 @@ class LLMConfig(BaseModel):
68
66
  temperature: NonNegativeFloat = Field(default=1.0)
69
67
  """The temperature of the LLM model. Controls randomness in generation. Set to 1.0 as per request."""
70
68
 
71
- stop_sign: str | List[str] = Field(default_factory=lambda :["\n\n\n", "User:"])
69
+ stop_sign: str | List[str] = Field(default_factory=lambda: ["\n\n\n", "User:"])
72
70
  """The stop sign of the LLM model. No default stop sign specified."""
73
71
 
74
72
  top_p: NonNegativeFloat = Field(default=0.35)
@@ -83,6 +81,12 @@ class LLMConfig(BaseModel):
83
81
  max_tokens: PositiveInt = Field(default=8192)
84
82
  """The maximum number of tokens to generate. Set to 8192 as per request."""
85
83
 
84
+ rpm: Optional[PositiveInt] = Field(default=100)
85
+ """The rate limit of the LLM model in requests per minute. None means not checked."""
86
+
87
+ tpm: Optional[PositiveInt] = Field(default=1000000)
88
+ """The rate limit of the LLM model in tokens per minute. None means not checked."""
89
+
86
90
 
87
91
  class EmbeddingConfig(BaseModel):
88
92
  """Embedding configuration class."""
@@ -222,6 +226,12 @@ class TemplateConfig(BaseModel):
222
226
  review_string_template: str = Field(default="review_string")
223
227
  """The name of the review string template which will be used to review a string."""
224
228
 
229
+ generic_string_template: str = Field(default="generic_string")
230
+ """The name of the generic string template which will be used to review a string."""
231
+
232
+ correct_template: str = Field(default="correct")
233
+ """The name of the correct template which will be used to correct a string."""
234
+
225
235
 
226
236
  class MagikaConfig(BaseModel):
227
237
  """Magika configuration class."""
@@ -285,6 +295,19 @@ class CacheConfig(BaseModel):
285
295
  """Whether to enable cache."""
286
296
 
287
297
 
298
+ class RoutingConfig(BaseModel):
299
+ """Routing configuration class."""
300
+
301
+ model_config = ConfigDict(use_attribute_docstrings=True)
302
+
303
+ allowed_fails: Optional[int] = 1
304
+ """The number of allowed fails before the routing is considered failed."""
305
+ retry_after: int = 15
306
+ """The time in seconds to wait before retrying the routing after a fail."""
307
+ cooldown_time: Optional[int] = 120
308
+ """The time in seconds to wait before retrying the routing after a cooldown."""
309
+
310
+
288
311
  class Settings(BaseSettings):
289
312
  """Application settings class.
290
313
 
@@ -310,6 +333,9 @@ class Settings(BaseSettings):
310
333
  llm: LLMConfig = Field(default_factory=LLMConfig)
311
334
  """LLM Configuration"""
312
335
 
336
+ routing: RoutingConfig = Field(default_factory=RoutingConfig)
337
+ """Routing Configuration"""
338
+
313
339
  embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
314
340
  """Embedding Configuration"""
315
341
 
@@ -348,6 +374,9 @@ class Settings(BaseSettings):
348
374
  ) -> tuple[PydanticBaseSettingsSource, ...]:
349
375
  """Customize settings sources.
350
376
 
377
+ This method customizes the settings sources used by the application. It returns a tuple of settings sources, including
378
+ the dotenv settings source, environment settings source, a custom TomlConfigSettingsSource, and a custom
379
+
351
380
  Args:
352
381
  settings_cls (type[BaseSettings]): The settings class.
353
382
  init_settings (PydanticBaseSettingsSource): Initial settings source.
@@ -359,10 +388,12 @@ class Settings(BaseSettings):
359
388
  tuple[PydanticBaseSettingsSource, ...]: A tuple of settings sources.
360
389
  """
361
390
  return (
362
- DotEnvSettingsSource(settings_cls),
363
- EnvSettingsSource(settings_cls),
364
- TomlConfigSettingsSource(settings_cls),
391
+ init_settings,
392
+ dotenv_settings,
393
+ env_settings,
394
+ file_secret_settings,
365
395
  PyprojectTomlConfigSettingsSource(settings_cls),
396
+ TomlConfigSettingsSource(settings_cls),
366
397
  )
367
398
 
368
399
 
fabricatio/fs/__init__.py CHANGED
@@ -11,9 +11,10 @@ from fabricatio.fs.curd import (
11
11
  move_file,
12
12
  tree,
13
13
  )
14
- from fabricatio.fs.readers import magika, safe_json_read, safe_text_read
14
+ from fabricatio.fs.readers import MAGIKA, safe_json_read, safe_text_read
15
15
 
16
16
  __all__ = [
17
+ "MAGIKA",
17
18
  "absolute_path",
18
19
  "copy_file",
19
20
  "create_directory",
@@ -21,7 +22,6 @@ __all__ = [
21
22
  "delete_file",
22
23
  "dump_text",
23
24
  "gather_files",
24
- "magika",
25
25
  "move_file",
26
26
  "safe_json_read",
27
27
  "safe_text_read",
fabricatio/fs/readers.py CHANGED
@@ -9,7 +9,7 @@ from magika import Magika
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.journal import logger
11
11
 
12
- magika = Magika(model_dir=configs.magika.model_dir)
12
+ MAGIKA = Magika(model_dir=configs.magika.model_dir)
13
13
 
14
14
 
15
15
  def safe_text_read(path: Path | str) -> str:
fabricatio/journal.py CHANGED
@@ -19,10 +19,3 @@ logger.add(
19
19
  logger.add(sys.stderr, level=configs.debug.log_level)
20
20
 
21
21
  __all__ = ["logger"]
22
- if __name__ == "__main__":
23
- logger.debug("This is a trace message.")
24
- logger.info("This is an information message.")
25
- logger.success("This is a success message.")
26
- logger.warning("This is a warning message.")
27
- logger.error("This is an error message.")
28
- logger.critical("This is a critical message.")
@@ -3,9 +3,9 @@
3
3
  import traceback
4
4
  from abc import abstractmethod
5
5
  from asyncio import Queue, create_task
6
- from typing import Any, Dict, Self, Tuple, Type, Union, Unpack, final
6
+ from typing import Any, Dict, Self, Tuple, Type, Union, final
7
7
 
8
- from fabricatio.capabilities.review import Review
8
+ from fabricatio.capabilities.correct import Correct
9
9
  from fabricatio.capabilities.task import HandleTask, ProposeTask
10
10
  from fabricatio.journal import logger
11
11
  from fabricatio.models.generic import WithBriefing
@@ -14,7 +14,7 @@ from fabricatio.models.usages import ToolBoxUsage
14
14
  from pydantic import Field, PrivateAttr
15
15
 
16
16
 
17
- class Action(HandleTask, ProposeTask, Review):
17
+ class Action(HandleTask, ProposeTask, Correct):
18
18
  """Class that represents an action to be executed in a workflow."""
19
19
 
20
20
  name: str = Field(default="")
@@ -37,7 +37,7 @@ class Action(HandleTask, ProposeTask, Review):
37
37
  self.description = self.description or self.__class__.__doc__ or ""
38
38
 
39
39
  @abstractmethod
40
- async def _execute(self, **cxt: Unpack) -> Any:
40
+ async def _execute(self, **cxt) -> Any:
41
41
  """Execute the action with the provided arguments.
42
42
 
43
43
  Args:
@@ -2,13 +2,13 @@
2
2
 
3
3
  from abc import abstractmethod
4
4
  from pathlib import Path
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Union, final
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Union, final, overload
6
6
 
7
7
  import orjson
8
8
  from fabricatio._rust import blake3_hash
9
- from fabricatio._rust_instances import template_manager
9
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
10
10
  from fabricatio.config import configs
11
- from fabricatio.fs.readers import magika, safe_text_read
11
+ from fabricatio.fs.readers import MAGIKA, safe_text_read
12
12
  from fabricatio.journal import logger
13
13
  from fabricatio.parser import JsonCapture
14
14
  from pydantic import (
@@ -40,6 +40,14 @@ class Display(Base):
40
40
  """
41
41
  return self.model_dump_json(indent=1)
42
42
 
43
+ def compact(self) -> str:
44
+ """Display the model in a compact JSON string.
45
+
46
+ Returns:
47
+ str: The compact JSON string of the model.
48
+ """
49
+ return self.model_dump_json()
50
+
43
51
 
44
52
  class Named(Base):
45
53
  """Class that includes a name attribute."""
@@ -100,7 +108,15 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
100
108
  """Class that provides a prompt for creating a JSON object."""
101
109
 
102
110
  @classmethod
103
- def create_json_prompt(cls, requirement: str) -> str:
111
+ @overload
112
+ def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
113
+
114
+ @classmethod
115
+ @overload
116
+ def create_json_prompt(cls, requirement: str) -> str: ...
117
+
118
+ @classmethod
119
+ def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
104
120
  """Create the prompt for creating a JSON object with given requirement.
105
121
 
106
122
  Args:
@@ -109,10 +125,18 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
109
125
  Returns:
110
126
  str: The prompt for creating a JSON object with given requirement.
111
127
  """
112
- return template_manager.render_template(
113
- configs.templates.create_json_obj_template,
114
- {"requirement": requirement, "json_schema": cls.formated_json_schema()},
115
- )
128
+ if isinstance(requirement, str):
129
+ return TEMPLATE_MANAGER.render_template(
130
+ configs.templates.create_json_obj_template,
131
+ {"requirement": requirement, "json_schema": cls.formated_json_schema()},
132
+ )
133
+ return [
134
+ TEMPLATE_MANAGER.render_template(
135
+ configs.templates.create_json_obj_template,
136
+ {"requirement": r, "json_schema": cls.formated_json_schema()},
137
+ )
138
+ for r in requirement
139
+ ]
116
140
 
117
141
 
118
142
  class InstantiateFromString(Base):
@@ -231,13 +255,13 @@ class WithDependency(Base):
231
255
  Returns:
232
256
  str: The generated prompt for the task.
233
257
  """
234
- return template_manager.render_template(
258
+ return TEMPLATE_MANAGER.render_template(
235
259
  configs.templates.dependencies_template,
236
260
  {
237
261
  (pth := Path(p)).name: {
238
262
  "path": pth.as_posix(),
239
263
  "exists": pth.exists(),
240
- "description": (identity := magika.identify_path(pth)).output.description,
264
+ "description": (identity := MAGIKA.identify_path(pth)).output.description,
241
265
  "size": f"{pth.stat().st_size / (1024 * 1024) if pth.exists() and pth.is_file() else 0:.3f} MB",
242
266
  "content": (text := safe_text_read(pth)),
243
267
  "lines": len(text.splitlines()),
@@ -307,6 +331,12 @@ class ScopedConfig(Base):
307
331
  llm_max_tokens: Optional[PositiveInt] = None
308
332
  """The maximum number of tokens to generate."""
309
333
 
334
+ llm_tpm: Optional[PositiveInt] = None
335
+ """The tokens per minute of the LLM model."""
336
+
337
+ llm_rpm: Optional[PositiveInt] = None
338
+ """The requests per minute of the LLM model."""
339
+
310
340
  embedding_api_endpoint: Optional[HttpUrl] = None
311
341
  """The OpenAI API endpoint."""
312
342
 
@@ -12,7 +12,7 @@ class CollectionSimpleConfigKwargs(TypedDict, total=False):
12
12
  These arguments are typically used when configuring connections to vector databases.
13
13
  """
14
14
 
15
- dimension: int
15
+ dimension: int | None
16
16
  timeout: float
17
17
 
18
18
 
@@ -23,7 +23,7 @@ class FetchKwargs(TypedDict, total=False):
23
23
  and result limiting parameters.
24
24
  """
25
25
 
26
- collection_name: str
26
+ collection_name: str | None
27
27
  similarity_threshold: float
28
28
  result_per_query: int
29
29
 
fabricatio/models/role.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from typing import Any, Self, Set
4
4
 
5
- from fabricatio.capabilities.review import Review
5
+ from fabricatio.capabilities.correct import Correct
6
6
  from fabricatio.capabilities.task import HandleTask, ProposeTask
7
7
  from fabricatio.core import env
8
8
  from fabricatio.journal import logger
@@ -12,7 +12,7 @@ from fabricatio.models.tool import ToolBox
12
12
  from pydantic import Field
13
13
 
14
14
 
15
- class Role(ProposeTask, HandleTask, Review):
15
+ class Role(ProposeTask, HandleTask, Correct):
16
16
  """Class that represents a role with a registry of events and workflows."""
17
17
 
18
18
  registry: dict[Event | str, WorkFlow] = Field(default_factory=dict)
fabricatio/models/task.py CHANGED
@@ -6,7 +6,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
6
6
  from asyncio import Queue
7
7
  from typing import Any, List, Optional, Self
8
8
 
9
- from fabricatio._rust_instances import template_manager
9
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
10
10
  from fabricatio.config import configs
11
11
  from fabricatio.core import env
12
12
  from fabricatio.journal import logger
@@ -253,7 +253,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
253
253
  Returns:
254
254
  str: The briefing of the task.
255
255
  """
256
- return template_manager.render_template(
256
+ return TEMPLATE_MANAGER.render_template(
257
257
  configs.templates.task_briefing_template,
258
258
  self.model_dump(),
259
259
  )
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set
5
5
 
6
6
  import asyncstdlib
7
7
  import litellm
8
- from fabricatio._rust_instances import template_manager
8
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.journal import logger
11
11
  from fabricatio.models.generic import ScopedConfig, WithBriefing
@@ -13,8 +13,9 @@ from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, Genera
13
13
  from fabricatio.models.task import Task
14
14
  from fabricatio.models.tool import Tool, ToolBox
15
15
  from fabricatio.models.utils import Messages
16
- from fabricatio.parser import JsonCapture
17
- from litellm import stream_chunk_builder
16
+ from fabricatio.parser import GenericCapture, JsonCapture
17
+ from litellm import Router, stream_chunk_builder
18
+ from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
18
19
  from litellm.types.utils import (
19
20
  Choices,
20
21
  EmbeddingResponse,
@@ -22,7 +23,7 @@ from litellm.types.utils import (
22
23
  StreamingChoices,
23
24
  TextChoices,
24
25
  )
25
- from litellm.utils import CustomStreamWrapper
26
+ from litellm.utils import CustomStreamWrapper # pyright: ignore [reportPrivateImportUsage]
26
27
  from more_itertools import duplicates_everseen
27
28
  from pydantic import Field, NonNegativeInt, PositiveInt
28
29
 
@@ -30,20 +31,33 @@ if configs.cache.enabled and configs.cache.type:
30
31
  litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
31
32
  logger.success(f"{configs.cache.type.name} Cache enabled")
32
33
 
34
+ ROUTER = Router(
35
+ routing_strategy="usage-based-routing-v2",
36
+ allowed_fails=configs.routing.allowed_fails,
37
+ retry_after=configs.routing.retry_after,
38
+ cooldown_time=configs.routing.cooldown_time,
39
+ )
40
+
33
41
 
34
42
  class LLMUsage(ScopedConfig):
35
43
  """Class that manages LLM (Large Language Model) usage parameters and methods."""
36
44
 
45
+ def _deploy(self, deployment: Deployment) -> Router:
46
+ """Add a deployment to the router."""
47
+ self._added_deployment = ROUTER.upsert_deployment(deployment)
48
+ return ROUTER
49
+
37
50
  @classmethod
38
51
  def _scoped_model(cls) -> Type["LLMUsage"]:
39
52
  return LLMUsage
40
53
 
54
+ # noinspection PyTypeChecker,PydanticTypeChecker
41
55
  async def aquery(
42
56
  self,
43
57
  messages: List[Dict[str, str]],
44
58
  n: PositiveInt | None = None,
45
59
  **kwargs: Unpack[LLMKwargs],
46
- ) -> ModelResponse:
60
+ ) -> ModelResponse | CustomStreamWrapper:
47
61
  """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
48
62
 
49
63
  Args:
@@ -55,19 +69,33 @@ class LLMUsage(ScopedConfig):
55
69
  ModelResponse | CustomStreamWrapper: An object containing the generated response and other metadata from the model.
56
70
  """
57
71
  # Call the underlying asynchronous completion function with the provided and default parameters
58
- return await litellm.acompletion(
72
+ # noinspection PyTypeChecker,PydanticTypeChecker
73
+
74
+ return await self._deploy(
75
+ Deployment(
76
+ model_name=(m_name := kwargs.get("model") or self.llm_model or configs.llm.model),
77
+ litellm_params=(
78
+ p := LiteLLM_Params(
79
+ api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
80
+ api_base=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
81
+ model=m_name,
82
+ tpm=self.llm_tpm or configs.llm.tpm,
83
+ rpm=self.llm_rpm or configs.llm.rpm,
84
+ max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
85
+ timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
86
+ )
87
+ ),
88
+ model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
89
+ )
90
+ ).acompletion(
59
91
  messages=messages,
60
92
  n=n or self.llm_generation_count or configs.llm.generation_count,
61
- model=kwargs.get("model") or self.llm_model or configs.llm.model,
93
+ model=m_name,
62
94
  temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
63
95
  stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
64
96
  top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
65
97
  max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
66
98
  stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
67
- timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
68
- max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
69
- api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
70
- base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
71
99
  cache={
72
100
  "no-cache": kwargs.get("no_cache"),
73
101
  "no-store": kwargs.get("no_store"),
@@ -192,31 +220,31 @@ class LLMUsage(ScopedConfig):
192
220
  @overload
193
221
  async def aask_validate[T](
194
222
  self,
195
- question: str,
223
+ question: List[str],
196
224
  validator: Callable[[str], T | None],
197
- default: None = None,
225
+ default: T,
198
226
  max_validations: PositiveInt = 2,
199
227
  **kwargs: Unpack[GenerateKwargs],
200
- ) -> Optional[T]: ...
201
-
228
+ ) -> List[T]: ...
202
229
  @overload
203
230
  async def aask_validate[T](
204
231
  self,
205
- question: List[str],
232
+ question: str,
206
233
  validator: Callable[[str], T | None],
207
234
  default: None = None,
208
235
  max_validations: PositiveInt = 2,
209
236
  **kwargs: Unpack[GenerateKwargs],
210
- ) -> List[Optional[T]]: ...
237
+ ) -> Optional[T]: ...
238
+
211
239
  @overload
212
240
  async def aask_validate[T](
213
241
  self,
214
242
  question: List[str],
215
243
  validator: Callable[[str], T | None],
216
- default: T,
244
+ default: None = None,
217
245
  max_validations: PositiveInt = 2,
218
246
  **kwargs: Unpack[GenerateKwargs],
219
- ) -> List[T]: ...
247
+ ) -> List[Optional[T]]: ...
220
248
 
221
249
  async def aask_validate[T](
222
250
  self,
@@ -274,7 +302,7 @@ class LLMUsage(ScopedConfig):
274
302
  List[str]: The validated response as a list of strings.
275
303
  """
276
304
  return await self.aask_validate(
277
- template_manager.render_template(
305
+ TEMPLATE_MANAGER.render_template(
278
306
  configs.templates.liststr_template,
279
307
  {"requirement": requirement, "k": k},
280
308
  ),
@@ -293,7 +321,7 @@ class LLMUsage(ScopedConfig):
293
321
  List[str]: The validated response as a list of strings.
294
322
  """
295
323
  return await self.aliststr(
296
- template_manager.render_template(
324
+ TEMPLATE_MANAGER.render_template(
297
325
  configs.templates.pathstr_template,
298
326
  {"requirement": requirement},
299
327
  ),
@@ -318,6 +346,25 @@ class LLMUsage(ScopedConfig):
318
346
  )
319
347
  ).pop()
320
348
 
349
+ async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> str:
350
+ """Asynchronously generates a generic string based on a given requirement.
351
+
352
+ Args:
353
+ requirement (str): The requirement for the string.
354
+ **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
355
+
356
+ Returns:
357
+ str: The generated string.
358
+ """
359
+ return await self.aask_validate(
360
+ TEMPLATE_MANAGER.render_template(
361
+ configs.templates.generic_string_template,
362
+ {"requirement": requirement, "language": GenericCapture.capture_type},
363
+ ),
364
+ validator=lambda resp: GenericCapture.capture(resp),
365
+ **kwargs,
366
+ )
367
+
321
368
  async def achoose[T: WithBriefing](
322
369
  self,
323
370
  instruction: str,
@@ -344,7 +391,7 @@ class LLMUsage(ScopedConfig):
344
391
  if dup := duplicates_everseen(choices, key=lambda x: x.name):
345
392
  logger.error(err := f"Redundant choices: {dup}")
346
393
  raise ValueError(err)
347
- prompt = template_manager.render_template(
394
+ prompt = TEMPLATE_MANAGER.render_template(
348
395
  configs.templates.make_choice_template,
349
396
  {
350
397
  "instruction": instruction,
@@ -417,7 +464,7 @@ class LLMUsage(ScopedConfig):
417
464
  bool: The judgment result (True or False) based on the AI's response.
418
465
  """
419
466
  return await self.aask_validate(
420
- question=template_manager.render_template(
467
+ question=TEMPLATE_MANAGER.render_template(
421
468
  configs.templates.make_judgment_template,
422
469
  {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
423
470
  ),
@@ -4,6 +4,7 @@ from enum import Enum
4
4
  from typing import Any, Dict, List, Literal, Optional, Self
5
5
 
6
6
  from pydantic import BaseModel, ConfigDict, Field
7
+ from questionary import text
7
8
 
8
9
 
9
10
  class Message(BaseModel):
@@ -144,3 +145,23 @@ class TaskStatus(Enum):
144
145
  Finished = "finished"
145
146
  Failed = "failed"
146
147
  Cancelled = "cancelled"
148
+
149
+
150
+ async def ask_edit(
151
+ text_seq: List[str],
152
+ ) -> List[str]:
153
+ """Asks the user to edit a list of texts.
154
+
155
+ Args:
156
+ text_seq (List[str]): A list of texts to be edited.
157
+
158
+ Returns:
159
+ List[str]: A list of edited texts.
160
+ If the user does not edit a text, it will not be included in the returned list.
161
+ """
162
+ res = []
163
+ for i, t in enumerate(text_seq):
164
+ edited = await text(f"[{i}] ", default=t).ask_async()
165
+ if edited:
166
+ res.append(edited)
167
+ return res
fabricatio/parser.py CHANGED
@@ -35,7 +35,7 @@ class Capture(BaseModel):
35
35
  """Initialize the compiled pattern."""
36
36
  self._compiled = compile(self.pattern, self.flags)
37
37
 
38
- def fix[T](self, text: str | Iterable[str]|T) -> str | List[str]|T:
38
+ def fix[T](self, text: str | Iterable[str] | T) -> str | List[str] | T:
39
39
  """Fix the text using the pattern.
40
40
 
41
41
  Args:
@@ -47,8 +47,8 @@ class Capture(BaseModel):
47
47
  match self.capture_type:
48
48
  case "json":
49
49
  if isinstance(text, str):
50
- return repair_json(text,ensure_ascii=False)
51
- return [repair_json(item) for item in text]
50
+ return repair_json(text, ensure_ascii=False)
51
+ return [repair_json(item, ensure_ascii=False) for item in text]
52
52
  case _:
53
53
  return text
54
54
 
@@ -134,8 +134,16 @@ class Capture(BaseModel):
134
134
  """
135
135
  return cls(pattern=f"```{language}\n(.*?)\n```", capture_type=language)
136
136
 
137
+ @classmethod
138
+ def capture_generic_block(cls, language: str) -> Self:
139
+ """Capture the first occurrence of a generic code block in the given text.
140
+
141
+ Returns:
142
+ Self: The instance of the class with the captured code block.
143
+ """
144
+ return cls(pattern=f"--- Start of {language} ---\n(.*?)\n--- end of {language} ---", capture_type=language)
145
+
137
146
 
138
147
  JsonCapture = Capture.capture_code_block("json")
139
148
  PythonCapture = Capture.capture_code_block("python")
140
- MarkdownCapture = Capture.capture_code_block("markdown")
141
- CodeBlockCapture = Capture(pattern="```.*?\n(.*?)\n```")
149
+ GenericCapture = Capture.capture_generic_block("String")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fabricatio
3
- Version: 0.2.5.dev5
3
+ Version: 0.2.6.dev1
4
4
  Classifier: License :: OSI Approved :: MIT License
5
5
  Classifier: Programming Language :: Rust
6
6
  Classifier: Programming Language :: Python :: 3.12
@@ -176,7 +176,7 @@ if __name__ == "__main__":
176
176
  ### Template Management and Rendering
177
177
 
178
178
  ```python
179
- from fabricatio._rust_instances import template_manager
179
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
180
180
 
181
181
  template_name = "claude-xml.hbs"
182
182
  data = {
@@ -185,7 +185,7 @@ data = {
185
185
  "files": [{"path": "file1.py", "code": "print('Hello')"}],
186
186
  }
187
187
 
188
- rendered_template = template_manager.render_template(template_name, data)
188
+ rendered_template = TEMPLATE_MANAGER.render_template(template_name, data)
189
189
  print(rendered_template)
190
190
  ```
191
191