fabricatio 0.2.0.dev4__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +32 -0
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +1 -0
- fabricatio/actions/__init__.py +5 -0
- fabricatio/actions/communication.py +13 -0
- fabricatio/actions/transmission.py +32 -0
- fabricatio/config.py +206 -0
- fabricatio/core.py +167 -0
- fabricatio/decorators.py +56 -0
- fabricatio/fs/__init__.py +5 -0
- fabricatio/fs/readers.py +5 -0
- fabricatio/journal.py +23 -0
- fabricatio/models/action.py +128 -0
- fabricatio/models/events.py +80 -0
- fabricatio/models/generic.py +388 -0
- fabricatio/models/role.py +26 -0
- fabricatio/models/task.py +283 -0
- fabricatio/models/tool.py +100 -0
- fabricatio/models/utils.py +78 -0
- fabricatio/parser.py +69 -0
- fabricatio/py.typed +0 -0
- fabricatio/templates.py +41 -0
- fabricatio/toolboxes/__init__.py +7 -0
- fabricatio/toolboxes/task.py +4 -0
- fabricatio-0.2.0.dev4.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.0.dev4.dist-info/METADATA +224 -0
- fabricatio-0.2.0.dev4.dist-info/RECORD +29 -0
- fabricatio-0.2.0.dev4.dist-info/WHEEL +4 -0
- fabricatio-0.2.0.dev4.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,283 @@
|
|
1
|
+
"""This module defines the `Task` class, which represents a task with a status and output.
|
2
|
+
|
3
|
+
It includes methods to manage the task's lifecycle, such as starting, finishing, cancelling, and failing the task.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from asyncio import Queue
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Any, List, Optional, Self
|
9
|
+
|
10
|
+
from fabricatio.core import env
|
11
|
+
from fabricatio.journal import logger
|
12
|
+
from fabricatio.models.events import Event, EventLike
|
13
|
+
from fabricatio.models.generic import LLMUsage, WithBriefing, WithDependency, WithJsonExample
|
14
|
+
from fabricatio.parser import JsonCapture
|
15
|
+
from pydantic import Field, PrivateAttr, ValidationError
|
16
|
+
|
17
|
+
|
18
|
+
class TaskStatus(Enum):
    """Lifecycle states a task can be in.

    Each member's value is a lowercase string; in this module it is pushed onto
    the task's namespace event by `Task.status_label` to form status labels.

    Attributes:
        Pending: The task is pending.
        Running: The task is currently running.
        Finished: The task has been successfully completed.
        Failed: The task has failed.
        Cancelled: The task has been cancelled.
    """

    # Initial state, before any worker has picked the task up.
    Pending = "pending"
    # Set by `start()`.
    Running = "running"
    # Set by `finish()` once an output is available.
    Finished = "finished"
    # Set by `fail()`.
    Failed = "failed"
    # Set by `cancel()`.
    Cancelled = "cancelled"
|
34
|
+
|
35
|
+
|
36
|
+
class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
37
|
+
"""A class representing a task with a status and output.
|
38
|
+
|
39
|
+
Attributes:
|
40
|
+
name (str): The name of the task.
|
41
|
+
description (str): The description of the task.
|
42
|
+
goal (str): The goal of the task.
|
43
|
+
dependencies (List[str]): The file dependencies of the task, a list of file paths.
|
44
|
+
namespace (List[str]): The namespace of the task, a list of namespace segment, as string.
|
45
|
+
"""
|
46
|
+
|
47
|
+
name: str = Field(...)
|
48
|
+
"""The name of the task."""
|
49
|
+
|
50
|
+
description: str = Field(default="")
|
51
|
+
"""The description of the task."""
|
52
|
+
|
53
|
+
goal: str = Field(default="")
|
54
|
+
"""The goal of the task."""
|
55
|
+
|
56
|
+
namespace: List[str] = Field(default_factory=list)
|
57
|
+
"""The namespace of the task, a list of namespace segment, as string."""
|
58
|
+
|
59
|
+
_output: Queue = PrivateAttr(default_factory=lambda: Queue(maxsize=1))
|
60
|
+
"""The output queue of the task."""
|
61
|
+
|
62
|
+
_status: TaskStatus = PrivateAttr(default=TaskStatus.Pending)
|
63
|
+
"""The status of the task."""
|
64
|
+
|
65
|
+
_namespace: Event = PrivateAttr(default_factory=Event)
|
66
|
+
"""The namespace of the task as an event, which is generated from the namespace list."""
|
67
|
+
|
68
|
+
def model_post_init(self, __context: Any) -> None:
|
69
|
+
"""Initialize the task with a namespace event."""
|
70
|
+
self._namespace.segments.extend(self.namespace)
|
71
|
+
|
72
|
+
def move_to(self, new_namespace: EventLike) -> Self:
|
73
|
+
"""Move the task to a new namespace.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
new_namespace (List[str]): The new namespace to move the task to.
|
77
|
+
|
78
|
+
Returns:
|
79
|
+
Task: The moved instance of the `Task` class.
|
80
|
+
"""
|
81
|
+
self.namespace = new_namespace
|
82
|
+
self._namespace.clear().concat(new_namespace)
|
83
|
+
return self
|
84
|
+
|
85
|
+
@classmethod
|
86
|
+
def simple_task(cls, name: str, goal: str, description: str) -> Self:
|
87
|
+
"""Create a simple task with a name, goal, and description.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
name (str): The name of the task.
|
91
|
+
goal (str): The goal of the task.
|
92
|
+
description (str): The description of the task.
|
93
|
+
|
94
|
+
Returns:
|
95
|
+
Task: A new instance of the `Task` class.
|
96
|
+
"""
|
97
|
+
return cls(name=name, goal=goal, description=description)
|
98
|
+
|
99
|
+
def update_task(self, goal: Optional[str] = None, description: Optional[str] = None) -> Self:
|
100
|
+
"""Update the goal and description of the task.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
goal (str, optional): The new goal of the task.
|
104
|
+
description (str, optional): The new description of the task.
|
105
|
+
|
106
|
+
Returns:
|
107
|
+
Task: The updated instance of the `Task` class.
|
108
|
+
"""
|
109
|
+
if goal:
|
110
|
+
self.goal = goal
|
111
|
+
if description:
|
112
|
+
self.description = description
|
113
|
+
return self
|
114
|
+
|
115
|
+
async def get_output(self) -> T:
|
116
|
+
"""Get the output of the task.
|
117
|
+
|
118
|
+
Returns:
|
119
|
+
T: The output of the task.
|
120
|
+
"""
|
121
|
+
logger.debug(f"Getting output for task {self.name}")
|
122
|
+
return await self._output.get()
|
123
|
+
|
124
|
+
def status_label(self, status: TaskStatus) -> str:
|
125
|
+
"""Return a formatted status label for the task.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
status (TaskStatus): The status of the task.
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
str: The formatted status label.
|
132
|
+
"""
|
133
|
+
return self._namespace.derive(self.name).push(status.value).collapse()
|
134
|
+
|
135
|
+
@property
|
136
|
+
def pending_label(self) -> str:
|
137
|
+
"""Return the pending status label for the task.
|
138
|
+
|
139
|
+
Returns:
|
140
|
+
str: The pending status label.
|
141
|
+
"""
|
142
|
+
return self.status_label(TaskStatus.Pending)
|
143
|
+
|
144
|
+
@property
|
145
|
+
def running_label(self) -> str:
|
146
|
+
"""Return the running status label for the task.
|
147
|
+
|
148
|
+
Returns:
|
149
|
+
str: The running status label.
|
150
|
+
"""
|
151
|
+
return self.status_label(TaskStatus.Running)
|
152
|
+
|
153
|
+
@property
|
154
|
+
def finished_label(self) -> str:
|
155
|
+
"""Return the finished status label for the task.
|
156
|
+
|
157
|
+
Returns:
|
158
|
+
str: The finished status label.
|
159
|
+
"""
|
160
|
+
return self.status_label(TaskStatus.Finished)
|
161
|
+
|
162
|
+
@property
|
163
|
+
def failed_label(self) -> str:
|
164
|
+
"""Return the failed status label for the task.
|
165
|
+
|
166
|
+
Returns:
|
167
|
+
str: The failed status label.
|
168
|
+
"""
|
169
|
+
return self.status_label(TaskStatus.Failed)
|
170
|
+
|
171
|
+
@property
|
172
|
+
def cancelled_label(self) -> str:
|
173
|
+
"""Return the cancelled status label for the task.
|
174
|
+
|
175
|
+
Returns:
|
176
|
+
str: The cancelled status label.
|
177
|
+
"""
|
178
|
+
return self.status_label(TaskStatus.Cancelled)
|
179
|
+
|
180
|
+
async def finish(self, output: T) -> Self:
|
181
|
+
"""Mark the task as finished and set the output.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
output (T): The output of the task.
|
185
|
+
|
186
|
+
Returns:
|
187
|
+
Task: The finished instance of the `Task` class.
|
188
|
+
"""
|
189
|
+
logger.info(f"Finishing task {self.name}")
|
190
|
+
self._status = TaskStatus.Finished
|
191
|
+
await self._output.put(output)
|
192
|
+
logger.debug(f"Output set for task {self.name}")
|
193
|
+
await env.emit_async(self.finished_label, self)
|
194
|
+
logger.debug(f"Emitted finished event for task {self.name}")
|
195
|
+
return self
|
196
|
+
|
197
|
+
async def start(self) -> Self:
|
198
|
+
"""Mark the task as running.
|
199
|
+
|
200
|
+
Returns:
|
201
|
+
Task: The running instance of the `Task` class.
|
202
|
+
"""
|
203
|
+
logger.info(f"Starting task {self.name}")
|
204
|
+
self._status = TaskStatus.Running
|
205
|
+
await env.emit_async(self.running_label, self)
|
206
|
+
return self
|
207
|
+
|
208
|
+
async def cancel(self) -> Self:
|
209
|
+
"""Mark the task as cancelled.
|
210
|
+
|
211
|
+
Returns:
|
212
|
+
Task: The cancelled instance of the `Task` class.
|
213
|
+
"""
|
214
|
+
self._status = TaskStatus.Cancelled
|
215
|
+
await env.emit_async(self.cancelled_label, self)
|
216
|
+
return self
|
217
|
+
|
218
|
+
async def fail(self) -> Self:
|
219
|
+
"""Mark the task as failed.
|
220
|
+
|
221
|
+
Returns:
|
222
|
+
Task: The failed instance of the `Task` class.
|
223
|
+
"""
|
224
|
+
logger.error(f"Task {self.name} failed")
|
225
|
+
self._status = TaskStatus.Failed
|
226
|
+
await env.emit_async(self.failed_label, self)
|
227
|
+
return self
|
228
|
+
|
229
|
+
async def publish(self) -> Self:
|
230
|
+
"""Publish the task to the event bus.
|
231
|
+
|
232
|
+
Returns:
|
233
|
+
Task: The published instance of the `Task` class
|
234
|
+
"""
|
235
|
+
logger.info(f"Publishing task {self.name}")
|
236
|
+
await env.emit_async(self.pending_label, self)
|
237
|
+
return self
|
238
|
+
|
239
|
+
async def delegate(self) -> T:
|
240
|
+
"""Delegate the task to the event bus and wait for the output.
|
241
|
+
|
242
|
+
Returns:
|
243
|
+
T: The output of the task
|
244
|
+
"""
|
245
|
+
logger.info(f"Delegating task {self.name}")
|
246
|
+
await env.emit_async(self.pending_label, self)
|
247
|
+
return await self.get_output()
|
248
|
+
|
249
|
+
@property
|
250
|
+
def briefing(self) -> str:
|
251
|
+
"""Return a briefing of the task including its goal.
|
252
|
+
|
253
|
+
Returns:
|
254
|
+
str: The briefing of the task.
|
255
|
+
"""
|
256
|
+
return f"{super().briefing}\n{self.goal}"
|
257
|
+
|
258
|
+
|
259
|
+
class ProposeTask(LLMUsage, WithBriefing):
    """A class that proposes a task based on a prompt."""

    async def propose(self, prompt: str) -> Task:
        """Propose a task based on the provided prompt.

        The prompt is extended with instructions asking for a JSON object shaped
        like `Task.json_example()`. The reply is checked by `_validate_json`
        through `aask_validate`.

        Args:
            prompt (str): The requirement the proposed task should satisfy; must be non-empty.

        Returns:
            Task: The task accepted by `_validate_json`, as returned by `aask_validate`.
        """
        assert prompt, "Prompt must be provided."

        def _validate_json(response: str) -> None | Task:
            """Parse `response` into a `Task`, returning None when it is not a valid task."""
            logger.debug(f"Response: \n{response}")
            cap = JsonCapture.capture(response)
            # `capture` returns None when there is no ```json fenced block; in that case the
            # raw response is parsed directly. `cap` must be checked before indexing it.
            json_text = cap[0] if cap else response
            logger.info(f"Captured JSON: \n{json_text}")
            try:
                return Task.model_validate_json(json_text)
            except ValidationError as e:
                logger.error(f"Failed to parse task from JSON: {e}")
                return None

        return await self.aask_validate(
            f"{prompt} \n\nBased on requirement above, "
            f"you need to construct a task to satisfy that requirement in JSON format "
            f"written like this: \n\n```json\n{Task.json_example()}\n```\n\n"
            f"No extra explanation needed. ",
            _validate_json,
            system_message=f"# your personal briefing: \n{self.briefing}",
        )
|
@@ -0,0 +1,100 @@
|
|
1
|
+
"""A module for defining tools and toolboxes."""
|
2
|
+
|
3
|
+
from inspect import getfullargspec, signature
|
4
|
+
from typing import Any, Callable, List, Self
|
5
|
+
|
6
|
+
from fabricatio.models.generic import WithBriefing
|
7
|
+
from pydantic import Field
|
8
|
+
|
9
|
+
|
10
|
+
class Tool[**P, R](WithBriefing):
|
11
|
+
"""A class representing a tool with a callable source function."""
|
12
|
+
|
13
|
+
name: str = Field(default="")
|
14
|
+
"""The name of the tool."""
|
15
|
+
|
16
|
+
description: str = Field(default="")
|
17
|
+
"""The description of the tool."""
|
18
|
+
|
19
|
+
source: Callable[P, R]
|
20
|
+
"""The source function of the tool."""
|
21
|
+
|
22
|
+
def model_post_init(self, __context: Any) -> None:
|
23
|
+
"""Initialize the tool with a name and a source function."""
|
24
|
+
self.name = self.name or self.source.__name__
|
25
|
+
assert self.name, "The tool must have a name."
|
26
|
+
self.description = self.description or self.source.__doc__ or ""
|
27
|
+
|
28
|
+
def invoke(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
29
|
+
"""Invoke the tool's source function with the provided arguments."""
|
30
|
+
return self.source(*args, **kwargs)
|
31
|
+
|
32
|
+
@property
|
33
|
+
def briefing(self) -> str:
|
34
|
+
"""Return a brief description of the tool.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
str: A brief description of the tool.
|
38
|
+
"""
|
39
|
+
source_signature = str(signature(self.source))
|
40
|
+
# 获取源函数的返回类型
|
41
|
+
return_annotation = getfullargspec(self.source).annotations.get("return", "None")
|
42
|
+
return f"{self.name}{source_signature} -> {return_annotation}\n{self.description}"
|
43
|
+
|
44
|
+
|
45
|
+
class ToolBox(WithBriefing):
|
46
|
+
"""A class representing a collection of tools."""
|
47
|
+
|
48
|
+
tools: List[Tool] = Field(default_factory=list)
|
49
|
+
"""A list of tools in the toolbox."""
|
50
|
+
|
51
|
+
def collect_tool[**P, R](self, func: Callable[P, R]) -> Callable[P, R]:
|
52
|
+
"""Add a callable function to the toolbox as a tool.
|
53
|
+
|
54
|
+
Args:
|
55
|
+
func (Callable[P, R]): The function to be added as a tool.
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
Callable[P, R]: The added function.
|
59
|
+
"""
|
60
|
+
self.tools.append(Tool(source=func))
|
61
|
+
return func
|
62
|
+
|
63
|
+
def add_tool[**P, R](self, func: Callable[P, R]) -> Self:
|
64
|
+
"""Add a callable function to the toolbox as a tool.
|
65
|
+
|
66
|
+
Args:
|
67
|
+
func (Callable): The function to be added as a tool.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
Self: The current instance of the toolbox.
|
71
|
+
"""
|
72
|
+
self.tools.append(Tool(source=func))
|
73
|
+
return self
|
74
|
+
|
75
|
+
@property
|
76
|
+
def briefing(self) -> str:
|
77
|
+
"""Return a brief description of the toolbox.
|
78
|
+
|
79
|
+
Returns:
|
80
|
+
str: A brief description of the toolbox.
|
81
|
+
"""
|
82
|
+
list_out = "\n\n".join([f"- {tool.briefing}" for tool in self.tools])
|
83
|
+
toc = f"## {self.name}: {self.description}\n## {len(self.tools)} tools available:"
|
84
|
+
return f"{toc}\n\n{list_out}"
|
85
|
+
|
86
|
+
def get[**P, R](self, name: str) -> Tool[P, R]:
|
87
|
+
"""Invoke a tool by name with the provided arguments.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
name (str): The name of the tool to invoke.
|
91
|
+
|
92
|
+
Returns:
|
93
|
+
Tool: The tool instance with the specified name.
|
94
|
+
|
95
|
+
Raises:
|
96
|
+
AssertionError: If no tool with the specified name is found.
|
97
|
+
"""
|
98
|
+
tool = next((tool for tool in self.tools if tool.name == name), None)
|
99
|
+
assert tool, f"No tool named {name} found."
|
100
|
+
return tool
|
@@ -0,0 +1,78 @@
|
|
1
|
+
"""A module containing utility classes for the models."""
|
2
|
+
|
3
|
+
from typing import Dict, List, Literal, Self
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field
|
6
|
+
|
7
|
+
|
8
|
+
class Message(BaseModel):
    """A single chat message: the sender's role and the text it carries."""

    model_config = ConfigDict(use_attribute_docstrings=True)
    role: Literal["user", "system", "assistant"] = "user"
    """
    Who is sending the message.
    """
    content: str = ""
    """
    The content of the message.
    """
|
20
|
+
|
21
|
+
|
22
|
+
class Messages(list):
|
23
|
+
"""A list of messages."""
|
24
|
+
|
25
|
+
def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
|
26
|
+
"""Adds a message to the list with the specified role and content.
|
27
|
+
|
28
|
+
Args:
|
29
|
+
role (Literal["user", "system", "assistant"]): The role of the message sender.
|
30
|
+
content (str): The content of the message.
|
31
|
+
|
32
|
+
Returns:
|
33
|
+
Self: The current instance of Messages to allow method chaining.
|
34
|
+
"""
|
35
|
+
if content:
|
36
|
+
self.append(Message(role=role, content=content))
|
37
|
+
return self
|
38
|
+
|
39
|
+
def add_user_message(self, content: str) -> Self:
|
40
|
+
"""Adds a user message to the list with the specified content.
|
41
|
+
|
42
|
+
Args:
|
43
|
+
content (str): The content of the user message.
|
44
|
+
|
45
|
+
Returns:
|
46
|
+
Self: The current instance of Messages to allow method chaining.
|
47
|
+
"""
|
48
|
+
return self.add_message("user", content)
|
49
|
+
|
50
|
+
def add_system_message(self, content: str) -> Self:
|
51
|
+
"""Adds a system message to the list with the specified content.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
content (str): The content of the system message.
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
Self: The current instance of Messages to allow method chaining.
|
58
|
+
"""
|
59
|
+
return self.add_message("system", content)
|
60
|
+
|
61
|
+
def add_assistant_message(self, content: str) -> Self:
|
62
|
+
"""Adds an assistant message to the list with the specified content.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
content (str): The content of the assistant message.
|
66
|
+
|
67
|
+
Returns:
|
68
|
+
Self: The current instance of Messages to allow method chaining.
|
69
|
+
"""
|
70
|
+
return self.add_message("assistant", content)
|
71
|
+
|
72
|
+
def as_list(self) -> List[Dict[str, str]]:
|
73
|
+
"""Converts the messages to a list of dictionaries.
|
74
|
+
|
75
|
+
Returns:
|
76
|
+
list[dict]: A list of dictionaries representing the messages.
|
77
|
+
"""
|
78
|
+
return [message.model_dump() for message in self]
|
fabricatio/parser.py
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
"""A module to parse text using regular expressions."""
|
2
|
+
|
3
|
+
from typing import Any, Self, Tuple
|
4
|
+
|
5
|
+
import regex
|
6
|
+
from pydantic import Field, PositiveInt, PrivateAttr
|
7
|
+
from regex import Pattern, compile
|
8
|
+
|
9
|
+
from fabricatio.models.generic import Base
|
10
|
+
|
11
|
+
|
12
|
+
class Capture(Base):
    """A class to capture patterns in text using regular expressions.

    Attributes:
        target_groups (Tuple[int, ...]): Group indices to return; empty means the whole match.
        pattern (str): The regular expression pattern to search for.
        flags (int): The flags used to compile `pattern`.
        _compiled (Pattern): The compiled regular expression pattern.
    """

    target_groups: Tuple[int, ...] = Field(default_factory=tuple)
    """The target groups to capture from the pattern."""
    pattern: str = Field(frozen=True)
    """The regular expression pattern to search for."""
    flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
    """The flags to use when compiling the regular expression pattern."""
    _compiled: Pattern = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Compile the regular expression pattern after the model is initialized.

        Args:
            __context (Any): The context in which the model is initialized.
        """
        self._compiled = compile(self.pattern, self.flags)

    def capture(self, text: str) -> Tuple[str, ...] | None:
        """Capture the first occurrence of the pattern in the given text.

        Args:
            text (str): The text to search the pattern in.

        Returns:
            Tuple[str, ...] | None: If the pattern is found, the groups listed in
            `target_groups`, or a 1-tuple with the whole match when
            `target_groups` is empty. An element is None if its group did not
            participate in the match. Returns None when the pattern is not found.
        """
        match = self._compiled.search(text)
        if match is None:
            return None

        if self.target_groups:
            return tuple(match.group(g) for g in self.target_groups)
        return (match.group(),)

    @classmethod
    def capture_code_block(cls, language: str) -> Self:
        """Build a capture for the body of the first fenced code block tagged with `language`.

        Args:
            language (str): The language tag after the opening fence. An empty
                string matches untagged fences. The tag is matched literally
                (regex-escaped), so tags such as "c++" work.

        Returns:
            Self: A capture whose single target group is the code between the fences.
        """
        # Escape the tag so regex metacharacters in it are not interpreted, and accept both
        # LF and CRLF line endings around the fenced body.
        return cls(pattern=rf"```{regex.escape(language)}\r?\n(.*?)\r?\n```", target_groups=(1,))
|
65
|
+
|
66
|
+
|
67
|
+
# Ready-made captures for the body of the first ```json, ```python and untagged ``` block,
# built in that order.
JsonCapture, PythonCapture, CodeBlockCapture = (
    Capture.capture_code_block(fence_tag) for fence_tag in ("json", "python", "")
)
|
fabricatio/py.typed
ADDED
File without changes
|
fabricatio/templates.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
"""A module that manages templates for code generation."""
|
2
|
+
|
3
|
+
from typing import Any, Dict, List, Self
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, DirectoryPath, Field, FilePath, PrivateAttr
|
6
|
+
|
7
|
+
from fabricatio.config import configs
|
8
|
+
from fabricatio.journal import logger
|
9
|
+
|
10
|
+
|
11
|
+
class TemplateManager(BaseModel):
    """Index template files on disk by name for code generation.

    Every directory in `templates_dir` is searched recursively. When two
    directories provide a template with the same file stem, the one from the
    directory listed first wins.
    """

    model_config = ConfigDict(use_attribute_docstrings=True)
    templates_dir: List[DirectoryPath] = Field(default_factory=lambda: list(configs.code2prompt.template_dir))
    """The directories containing the templates. first element has the highest override priority."""
    _discovered_templates: Dict[str, FilePath] = PrivateAttr(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Build the template index right after the model is constructed."""
        self.discover_templates()

    def discover_templates(self) -> Self:
        """Rescan the template directories and rebuild the stem -> path index.

        Directories are walked from lowest to highest priority, so entries from
        higher-priority directories overwrite earlier ones in the index.

        Returns:
            Self: This manager, for chaining.
        """
        index: Dict[str, FilePath] = {}
        for directory in reversed(self.templates_dir):
            glob_expr = f"*{configs.code2prompt.template_suffix}"
            for candidate in directory.rglob(glob_expr, case_sensitive=False):
                if not candidate.is_file():
                    continue
                index[candidate.stem] = candidate

        self._discovered_templates = index
        logger.info(f"Discovered {len(self._discovered_templates)} templates.")
        return self

    def get_template(self, name: str) -> FilePath | None:
        """Return the path of the template called `name`, or None if it is unknown."""
        return self._discovered_templates.get(name)
|
39
|
+
|
40
|
+
|
41
|
+
# Module-level shared instance. Constructing it runs `discover_templates()` (via
# `model_post_init`), so importing this module scans the configured template directories.
templates_manager = TemplateManager()
|
Binary file
|