fabricatio 0.2.0.dev4__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +32 -0
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +1 -0
- fabricatio/actions/__init__.py +5 -0
- fabricatio/actions/communication.py +13 -0
- fabricatio/actions/transmission.py +32 -0
- fabricatio/config.py +206 -0
- fabricatio/core.py +167 -0
- fabricatio/decorators.py +56 -0
- fabricatio/fs/__init__.py +5 -0
- fabricatio/fs/readers.py +5 -0
- fabricatio/journal.py +23 -0
- fabricatio/models/action.py +128 -0
- fabricatio/models/events.py +80 -0
- fabricatio/models/generic.py +388 -0
- fabricatio/models/role.py +26 -0
- fabricatio/models/task.py +283 -0
- fabricatio/models/tool.py +100 -0
- fabricatio/models/utils.py +78 -0
- fabricatio/parser.py +69 -0
- fabricatio/py.typed +0 -0
- fabricatio/templates.py +41 -0
- fabricatio/toolboxes/__init__.py +7 -0
- fabricatio/toolboxes/task.py +4 -0
- fabricatio-0.2.0.dev4.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.0.dev4.dist-info/METADATA +224 -0
- fabricatio-0.2.0.dev4.dist-info/RECORD +29 -0
- fabricatio-0.2.0.dev4.dist-info/WHEEL +4 -0
- fabricatio-0.2.0.dev4.dist-info/licenses/LICENSE +21 -0
fabricatio/models/events.py
@@ -0,0 +1,80 @@
"""The module containing the Event class."""

from typing import List, Self

from fabricatio.config import configs
from pydantic import BaseModel, ConfigDict, Field

type EventLike = str | List[str] | Self


class Event(BaseModel):
    """A class representing an event."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    segments: List[str] = Field(default_factory=list, frozen=True)
    """The segments of the namespaces."""

    @classmethod
    def instantiate_from(cls, event: EventLike) -> Self:
        """Create an Event instance from a string or list of strings or an Event instance.

        Args:
            event (EventLike): The event to instantiate from.

        Returns:
            Event: The Event instance.
        """
        if isinstance(event, Event):
            return event.clone()
        if isinstance(event, str):
            event = event.split(configs.pymitter.delimiter)

        return cls(segments=event)

    def derive(self, event: EventLike) -> Self:
        """Derive a new event from this event and another event or a string."""
        return self.clone().concat(event)

    def collapse(self) -> str:
        """Collapse the event into a string."""
        return configs.pymitter.delimiter.join(self.segments)

    def clone(self) -> Self:
        """Clone the event."""
        return Event(segments=list(self.segments))

    def push(self, segment: str) -> Self:
        """Push a segment to the event."""
        assert segment, "The segment must not be empty."
        assert configs.pymitter.delimiter not in segment, "The segment must not contain the delimiter."

        self.segments.append(segment)
        return self

    def push_wildcard(self) -> Self:
        """Push a wildcard segment to the event."""
        return self.push("*")

    def pop(self) -> str:
        """Pop a segment from the event."""
        return self.segments.pop()

    def clear(self) -> Self:
        """Clear the event."""
        self.segments.clear()
        return self

    def concat(self, event: EventLike) -> Self:
        """Concatenate another event to this event."""
        self.segments.extend(Event.instantiate_from(event).segments)
        return self

    def __hash__(self) -> int:
        """Return the hash of the event, using the collapsed string."""
        return hash(self.collapse())

    def __eq__(self, other: str | List[str] | Self) -> bool:
        """Check if the event is equal to another event or a string."""
        return self.collapse() == Event.instantiate_from(other).collapse()
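For orientation, here is a short usage sketch of the Event API above. It is illustrative only and not part of the wheel; it assumes the pymitter delimiter configured in fabricatio/config.py is ".".

    # Illustrative only -- not shipped in the package. Assumes the configured
    # pymitter delimiter is "." (the real value comes from fabricatio/config.py).
    from fabricatio.models.events import Event

    base = Event.instantiate_from("task.submit")   # split on the configured delimiter
    pending = base.derive("pending")               # clone + concat -> segments [task, submit, pending]
    wildcard = base.clone().push_wildcard()        # appends "*" for pattern listeners

    print(pending.collapse())                      # e.g. "task.submit.pending"
    print(pending == "task.submit.pending")        # __eq__ compares collapsed strings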
fabricatio/models/generic.py
@@ -0,0 +1,388 @@
"""This module defines generic classes for models in the Fabricatio library."""

from pathlib import Path
from typing import Callable, Dict, List, Optional, Self

import litellm
import orjson
from fabricatio.config import configs
from fabricatio.fs.readers import magika
from fabricatio.models.utils import Messages
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    HttpUrl,
    NonNegativeFloat,
    PositiveInt,
    SecretStr,
)


class Base(BaseModel):
    """Base class for all models with Pydantic configuration."""

    model_config = ConfigDict(use_attribute_docstrings=True)


class Named(Base):
    """Class that includes a name attribute."""

    name: str = Field(frozen=True)
    """The name of the object."""


class Described(Base):
    """Class that includes a description attribute."""

    description: str = Field(default="", frozen=True)
    """The description of the object."""


class WithBriefing(Named, Described):
    """Class that provides a briefing based on the name and description."""

    @property
    def briefing(self) -> str:
        """Get the briefing of the object.

        Returns:
            str: The briefing of the object.
        """
        return f"{self.name}: {self.description}" if self.description else self.name


class LLMUsage(Base):
    """Class that manages LLM (Large Language Model) usage parameters and methods."""

    llm_api_endpoint: Optional[HttpUrl] = None
    """The OpenAI API endpoint."""

    llm_api_key: Optional[SecretStr] = None
    """The OpenAI API key."""

    llm_timeout: Optional[PositiveInt] = None
    """The timeout of the LLM model."""

    llm_max_retries: Optional[PositiveInt] = None
    """The maximum number of retries."""

    llm_model: Optional[str] = None
    """The LLM model name."""

    llm_temperature: Optional[NonNegativeFloat] = None
    """The temperature of the LLM model."""

    llm_stop_sign: Optional[str | List[str]] = None
    """The stop sign of the LLM model."""

    llm_top_p: Optional[NonNegativeFloat] = None
    """The top p of the LLM model."""

    llm_generation_count: Optional[PositiveInt] = None
    """The number of generations to generate."""

    llm_stream: Optional[bool] = None
    """Whether to stream the LLM model's response."""

    llm_max_tokens: Optional[PositiveInt] = None
    """The maximum number of tokens to generate."""

    async def aquery(
        self,
        messages: List[Dict[str, str]],
        model: str | None = None,
        temperature: NonNegativeFloat | None = None,
        stop: str | List[str] | None = None,
        top_p: NonNegativeFloat | None = None,
        max_tokens: PositiveInt | None = None,
        n: PositiveInt | None = None,
        stream: bool | None = None,
        timeout: PositiveInt | None = None,
        max_retries: PositiveInt | None = None,
    ) -> ModelResponse:
        """Asynchronously queries the language model to generate a response based on the provided messages and parameters.

        Args:
            messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
            model (str | None): The name of the model to use. If not provided, the default model will be used.
            temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
            stop (str | None): A sequence at which to stop the generation of the response.
            top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
            max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
            n (PositiveInt | None): The number of responses to generate.
            stream (bool | None): Whether to receive the response in a streaming fashion.
            timeout (PositiveInt | None): The timeout duration for the request.
            max_retries (PositiveInt | None): The maximum number of retries in case of failure.

        Returns:
            ModelResponse: An object containing the generated response and other metadata from the model.
        """
        # Call the underlying asynchronous completion function with the provided and default parameters
        return await litellm.acompletion(
            messages=messages,
            model=model or self.llm_model or configs.llm.model,
            temperature=temperature or self.llm_temperature or configs.llm.temperature,
            stop=stop or self.llm_stop_sign or configs.llm.stop_sign,
            top_p=top_p or self.llm_top_p or configs.llm.top_p,
            max_tokens=max_tokens or self.llm_max_tokens or configs.llm.max_tokens,
            n=n or self.llm_generation_count or configs.llm.generation_count,
            stream=stream or self.llm_stream or configs.llm.stream,
            timeout=timeout or self.llm_timeout or configs.llm.timeout,
            max_retries=max_retries or self.llm_max_retries or configs.llm.max_retries,
            api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
            base_url=self.llm_api_endpoint.unicode_string()
            if self.llm_api_endpoint
            else configs.llm.api_endpoint.unicode_string(),
        )

    async def ainvoke(
        self,
        question: str,
        system_message: str = "",
        model: str | None = None,
        temperature: NonNegativeFloat | None = None,
        stop: str | List[str] | None = None,
        top_p: NonNegativeFloat | None = None,
        max_tokens: PositiveInt | None = None,
        n: PositiveInt | None = None,
        stream: bool | None = None,
        timeout: PositiveInt | None = None,
        max_retries: PositiveInt | None = None,
    ) -> List[Choices | StreamingChoices]:
        """Asynchronously invokes the language model with a question and optional system message.

        Args:
            question (str): The question to ask the model.
            system_message (str): The system message to provide context to the model.
            model (str | None): The name of the model to use. If not provided, the default model will be used.
            temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
            stop (str | None): A sequence at which to stop the generation of the response.
            top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
            max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
            n (PositiveInt | None): The number of responses to generate.
            stream (bool | None): Whether to receive the response in a streaming fashion.
            timeout (PositiveInt | None): The timeout duration for the request.
            max_retries (PositiveInt | None): The maximum number of retries in case of failure.

        Returns:
            List[Choices | StreamingChoices]: A list of choices or streaming choices from the model response.
        """
        return (
            await self.aquery(
                messages=Messages().add_system_message(system_message).add_user_message(question),
                model=model,
                temperature=temperature,
                stop=stop,
                top_p=top_p,
                max_tokens=max_tokens,
                n=n,
                stream=stream,
                timeout=timeout,
                max_retries=max_retries,
            )
        ).choices

    async def aask(
        self,
        question: str,
        system_message: str = "",
        model: str | None = None,
        temperature: NonNegativeFloat | None = None,
        stop: str | List[str] | None = None,
        top_p: NonNegativeFloat | None = None,
        max_tokens: PositiveInt | None = None,
        stream: bool | None = None,
        timeout: PositiveInt | None = None,
        max_retries: PositiveInt | None = None,
    ) -> str:
        """Asynchronously asks the language model a question and returns the response content.

        Args:
            question (str): The question to ask the model.
            system_message (str): The system message to provide context to the model.
            model (str | None): The name of the model to use. If not provided, the default model will be used.
            temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
            stop (str | None): A sequence at which to stop the generation of the response.
            top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
            max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
            stream (bool | None): Whether to receive the response in a streaming fashion.
            timeout (PositiveInt | None): The timeout duration for the request.
            max_retries (PositiveInt | None): The maximum number of retries in case of failure.

        Returns:
            str: The content of the model's response message.
        """
        return (
            (
                await self.ainvoke(
                    n=1,
                    question=question,
                    system_message=system_message,
                    model=model,
                    temperature=temperature,
                    stop=stop,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    stream=stream,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )
            .pop()
            .message.content
        )

    async def aask_validate[T](
        self,
        question: str,
        validator: Callable[[str], T | None],
        max_validations: PositiveInt = 2,
        system_message: str = "",
        model: str | None = None,
        temperature: NonNegativeFloat | None = None,
        stop: str | List[str] | None = None,
        top_p: NonNegativeFloat | None = None,
        max_tokens: PositiveInt | None = None,
        stream: bool | None = None,
        timeout: PositiveInt | None = None,
        max_retries: PositiveInt | None = None,
    ) -> T:
        """Asynchronously ask a question and validate the response using a given validator.

        Args:
            question (str): The question to ask.
            validator (Callable[[str], T | None]): A function to validate the response.
            max_validations (PositiveInt): Maximum number of validation attempts.
            system_message (str): System message to include in the request.
            model (str | None): The model to use for the request.
            temperature (NonNegativeFloat | None): Temperature setting for the request.
            stop (str | None): Stop sequence for the request.
            top_p (NonNegativeFloat | None): Top-p sampling parameter.
            max_tokens (PositiveInt | None): Maximum number of tokens in the response.
            stream (bool | None): Whether to stream the response.
            timeout (PositiveInt | None): Timeout for the request.
            max_retries (PositiveInt | None): Maximum number of retries for the request.

        Returns:
            T: The validated response.

        Raises:
            ValueError: If the response fails to validate after the maximum number of attempts.
        """
        for _ in range(max_validations):
            if (
                response := await self.aask(
                    question,
                    system_message,
                    model,
                    temperature,
                    stop,
                    top_p,
                    max_tokens,
                    stream,
                    timeout,
                    max_retries,
                )
            ) and (validated := validator(response)):
                return validated
        raise ValueError("Failed to validate the response.")

    def fallback_to(self, other: "LLMUsage") -> Self:
        """Fallback to another instance's attribute values if the current instance's attributes are None.

        Args:
            other (LLMUsage): Another instance from which to copy attribute values.

        Returns:
            Self: The current instance, allowing for method chaining.
        """
        # Define the list of attribute names to check and potentially copy
        attr_names = [
            "llm_api_endpoint",
            "llm_api_key",
            "llm_model",
            "llm_stop_sign",
            "llm_temperature",
            "llm_top_p",
            "llm_generation_count",
            "llm_stream",
            "llm_max_tokens",
            "llm_timeout",
            "llm_max_retries",
        ]

        # Iterate over the attribute names and copy values from 'other' to 'self' where applicable
        for attr_name in attr_names:
            # Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
            if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
                setattr(self, attr_name, attr)

        # Return the current instance to allow for method chaining
        return self


class WithJsonExample(Base):
    """Class that provides a JSON schema for the model."""

    @classmethod
    def json_example(cls) -> str:
        """Return a JSON example for the model.

        Returns:
            str: A JSON example for the model.
        """
        return orjson.dumps(
            {field_name: field_info.description for field_name, field_info in cls.model_fields.items()},
            option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS,
        ).decode()


class WithDependency(Base):
    """Class that manages file dependencies."""

    dependencies: List[str] = Field(default_factory=list)
    """The file dependencies of the task, a list of file paths."""

    def add_dependency[P: str | Path](self, dependency: P | List[P]) -> Self:
        """Add a file dependency to the task.

        Args:
            dependency (str | Path | List[str | Path]): The file dependency to add to the task.

        Returns:
            Self: The current instance of the task.
        """
        if not isinstance(dependency, list):
            dependency = [dependency]
        self.dependencies.extend(Path(d).as_posix() for d in dependency)
        return self

    def remove_dependency[P: str | Path](self, dependency: P | List[P]) -> Self:
        """Remove a file dependency from the task.

        Args:
            dependency (str | Path | List[str | Path]): The file dependency to remove from the task.

        Returns:
            Self: The current instance of the task.
        """
        if not isinstance(dependency, list):
            dependency = [dependency]
        for d in dependency:
            self.dependencies.remove(Path(d).as_posix())
        return self

    def generate_prompt(self) -> str:
        """Generate a prompt for the task based on the file dependencies.

        Returns:
            str: The generated prompt for the task.
        """
        contents = [Path(d).read_text("utf-8") for d in self.dependencies]
        recognized = [magika.identify_path(c) for c in contents]
        out = ""
        for r, p, c in zip(recognized, self.dependencies, contents, strict=False):
            out += f"---\n\n> {p}\n```{r.dl.ct_label}\n{c}\n```\n\n"
        return out
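A hedged sketch of how the LLMUsage mixin above is meant to be combined and called. The Assistant class, the prompt text, and the validator are illustrative assumptions, not part of the package; each keyword argument falls back from the explicit value to the matching llm_* attribute to the configs.llm default.

    # Illustrative only -- not shipped in the wheel. Assistant and the validator
    # are placeholders; real defaults and credentials come from configs.llm.
    import asyncio

    from fabricatio.models.generic import LLMUsage, WithBriefing


    class Assistant(WithBriefing, LLMUsage):
        """Toy composite: briefing supplies the system prompt, LLMUsage the calls."""


    async def main() -> None:
        bot = Assistant(name="summarizer", description="Summarizes text tersely.")

        # aask: argument -> llm_* attribute -> configs.llm.* fallback for each parameter.
        answer = await bot.aask("Summarize: The quick brown fox...", system_message=bot.briefing)
        print(answer)

        # aask_validate retries (up to max_validations) until the validator returns non-None.
        number = await bot.aask_validate(
            "Reply with a single integer between 1 and 10.",
            validator=lambda s: int(s) if s.strip().isdigit() else None,
            max_validations=3,
        )
        print(number)


    asyncio.run(main())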
fabricatio/models/role.py
@@ -0,0 +1,26 @@
"""Module that contains the Role class."""

from typing import Any

from fabricatio.core import env
from fabricatio.journal import logger
from fabricatio.models.action import WorkFlow
from fabricatio.models.events import Event
from fabricatio.models.task import ProposeTask
from pydantic import Field


class Role(ProposeTask):
    """Class that represents a role with a registry of events and workflows."""

    registry: dict[Event | str, WorkFlow] = Field(...)
    """The registry of events and workflows."""

    def model_post_init(self, __context: Any) -> None:
        """Register the workflows in the role to the event bus."""
        for event, workflow in self.registry.items():
            workflow.fallback_to(self).fallback_to_self().inject_personality(self.briefing)
            logger.debug(
                f"Registering workflow: {workflow.name} for event: {event.collapse() if isinstance(event, Event) else event}"
            )
            env.on(event, workflow.serve)