fabricatio 0.2.1.dev0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,515 @@
1
+ """This module contains classes that manage the usage of language models and tools in tasks."""
2
+
3
+ from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack
4
+
5
+ import asyncstdlib
6
+ import litellm
7
+ import orjson
8
+ from fabricatio._rust_instances import template_manager
9
+ from fabricatio.config import configs
10
+ from fabricatio.journal import logger
11
+ from fabricatio.models.generic import Base, WithBriefing
12
+ from fabricatio.models.kwargs_types import ChooseKwargs, LLMKwargs
13
+ from fabricatio.models.task import Task
14
+ from fabricatio.models.tool import Tool, ToolBox
15
+ from fabricatio.models.utils import Messages
16
+ from fabricatio.parser import JsonCapture
17
+ from litellm import stream_chunk_builder
18
+ from litellm.types.utils import (
19
+ Choices,
20
+ ModelResponse,
21
+ StreamingChoices,
22
+ )
23
+ from litellm.utils import CustomStreamWrapper
24
+ from pydantic import Field, HttpUrl, NonNegativeFloat, NonNegativeInt, PositiveInt, SecretStr
25
+
26
+
27
+ class LLMUsage(Base):
28
+ """Class that manages LLM (Large Language Model) usage parameters and methods."""
29
+
30
+ llm_api_endpoint: Optional[HttpUrl] = None
31
+ """The OpenAI API endpoint."""
32
+
33
+ llm_api_key: Optional[SecretStr] = None
34
+ """The OpenAI API key."""
35
+
36
+ llm_timeout: Optional[PositiveInt] = None
37
+ """The timeout of the LLM model."""
38
+
39
+ llm_max_retries: Optional[PositiveInt] = None
40
+ """The maximum number of retries."""
41
+
42
+ llm_model: Optional[str] = None
43
+ """The LLM model name."""
44
+
45
+ llm_temperature: Optional[NonNegativeFloat] = None
46
+ """The temperature of the LLM model."""
47
+
48
+ llm_stop_sign: Optional[str | List[str]] = None
49
+ """The stop sign of the LLM model."""
50
+
51
+ llm_top_p: Optional[NonNegativeFloat] = None
52
+ """The top p of the LLM model."""
53
+
54
+ llm_generation_count: Optional[PositiveInt] = None
55
+ """The number of generations to generate."""
56
+
57
+ llm_stream: Optional[bool] = None
58
+ """Whether to stream the LLM model's response."""
59
+
60
+ llm_max_tokens: Optional[PositiveInt] = None
61
+ """The maximum number of tokens to generate."""
62
+
63
+ async def aquery(
64
+ self,
65
+ messages: List[Dict[str, str]],
66
+ n: PositiveInt | None = None,
67
+ **kwargs: Unpack[LLMKwargs],
68
+ ) -> ModelResponse | CustomStreamWrapper:
69
+ """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
70
+
71
+ Args:
72
+ messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
73
+ n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
74
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
75
+
76
+ Returns:
77
+ ModelResponse: An object containing the generated response and other metadata from the model.
78
+ """
79
+ # Call the underlying asynchronous completion function with the provided and default parameters
80
+ return await litellm.acompletion(
81
+ messages=messages,
82
+ n=n or self.llm_generation_count or configs.llm.generation_count,
83
+ model=kwargs.get("model") or self.llm_model or configs.llm.model,
84
+ temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
85
+ stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
86
+ top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
87
+ max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
88
+ stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
89
+ timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
90
+ max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
91
+ api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
92
+ base_url=self.llm_api_endpoint.unicode_string()
93
+ if self.llm_api_endpoint
94
+ else configs.llm.api_endpoint.unicode_string(),
95
+ )
96
+
97
+ async def ainvoke(
98
+ self,
99
+ question: str,
100
+ system_message: str = "",
101
+ n: PositiveInt | None = None,
102
+ **kwargs: Unpack[LLMKwargs],
103
+ ) -> List[Choices | StreamingChoices]:
104
+ """Asynchronously invokes the language model with a question and optional system message.
105
+
106
+ Args:
107
+ question (str): The question to ask the model.
108
+ system_message (str): The system message to provide context to the model. Defaults to an empty string.
109
+ n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
110
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
111
+
112
+ Returns:
113
+ List[Choices | StreamingChoices]: A list of choices or streaming choices from the model response.
114
+ """
115
+ resp = await self.aquery(
116
+ messages=Messages().add_system_message(system_message).add_user_message(question),
117
+ n=n,
118
+ **kwargs,
119
+ )
120
+ if isinstance(resp, ModelResponse):
121
+ return resp.choices
122
+ if isinstance(resp, CustomStreamWrapper):
123
+ if configs.debug.streaming_visible:
124
+ chunks = []
125
+ async for chunk in resp:
126
+ chunks.append(chunk)
127
+ print(chunk.choices[0].delta.content or "", end="") # noqa: T201
128
+ return stream_chunk_builder(chunks).choices
129
+ return stream_chunk_builder(await asyncstdlib.list()).choices
130
+ logger.critical(err := f"Unexpected response type: {type(resp)}")
131
+ raise ValueError(err)
132
+
133
+ async def aask(
134
+ self,
135
+ question: str,
136
+ system_message: str = "",
137
+ **kwargs: Unpack[LLMKwargs],
138
+ ) -> str:
139
+ """Asynchronously asks the language model a question and returns the response content.
140
+
141
+ Args:
142
+ question (str): The question to ask the model.
143
+ system_message (str): The system message to provide context to the model. Defaults to an empty string.
144
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
145
+
146
+ Returns:
147
+ str: The content of the model's response message.
148
+ """
149
+ return (
150
+ (
151
+ await self.ainvoke(
152
+ n=1,
153
+ question=question,
154
+ system_message=system_message,
155
+ **kwargs,
156
+ )
157
+ ).pop()
158
+ ).message.content
159
+
160
+ async def aask_validate[T](
161
+ self,
162
+ question: str,
163
+ validator: Callable[[str], T | None],
164
+ max_validations: PositiveInt = 2,
165
+ system_message: str = "",
166
+ **kwargs: Unpack[LLMKwargs],
167
+ ) -> T:
168
+ """Asynchronously asks a question and validates the response using a given validator.
169
+
170
+ Args:
171
+ question (str): The question to ask.
172
+ validator (Callable[[str], T | None]): A function to validate the response.
173
+ max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
174
+ system_message (str): System message to include in the request. Defaults to an empty string.
175
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
176
+
177
+ Returns:
178
+ T: The validated response.
179
+
180
+ Raises:
181
+ ValueError: If the response fails to validate after the maximum number of attempts.
182
+ """
183
+ for i in range(max_validations):
184
+ if (
185
+ response := await self.aask(
186
+ question=question,
187
+ system_message=system_message,
188
+ **kwargs,
189
+ )
190
+ ) and (validated := validator(response)):
191
+ logger.debug(f"Successfully validated the response at {i}th attempt. response: \n{response}")
192
+ return validated
193
+ logger.debug(f"Failed to validate the response at {i}th attempt. response: \n{response}")
194
+ logger.error(f"Failed to validate the response after {max_validations} attempts.")
195
+ raise ValueError("Failed to validate the response.")
196
+
197
+ async def achoose[T: WithBriefing](
198
+ self,
199
+ instruction: str,
200
+ choices: List[T],
201
+ k: NonNegativeInt = 0,
202
+ max_validations: PositiveInt = 2,
203
+ system_message: str = "",
204
+ **kwargs: Unpack[LLMKwargs],
205
+ ) -> List[T]:
206
+ """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
207
+
208
+ Args:
209
+ instruction (str): The user-provided instruction/question description.
210
+ choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
211
+ k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
212
+ max_validations (PositiveInt): Maximum number of validation failures, default is 2.
213
+ system_message (str): Custom system-level prompt, defaults to an empty string.
214
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
215
+
216
+ Returns:
217
+ List[T]: The final validated selection result list, with element types matching the input `choices`.
218
+
219
+ Important:
220
+ - Uses a template engine to generate structured prompts.
221
+ - Ensures response compliance through JSON parsing and format validation.
222
+ - Relies on `aask_validate` to implement retry mechanisms with validation.
223
+ """
224
+ prompt = template_manager.render_template(
225
+ configs.templates.make_choice_template,
226
+ {
227
+ "instruction": instruction,
228
+ "options": [{"name": m.name, "briefing": m.briefing} for m in choices],
229
+ "k": k,
230
+ },
231
+ )
232
+ names = {c.name for c in choices}
233
+ logger.debug(f"Start choosing between {names} with prompt: \n{prompt}")
234
+
235
+ def _validate(response: str) -> List[T] | None:
236
+ ret = JsonCapture.convert_with(response, orjson.loads)
237
+
238
+ if not isinstance(ret, List) or (0 < k != len(ret)):
239
+ logger.error(f"Incorrect Type or length of response: \n{ret}")
240
+ return None
241
+ if any(n not in names for n in ret):
242
+ logger.error(f"Invalid choice in response: \n{ret}")
243
+ return None
244
+
245
+ return [next(toolbox for toolbox in choices if toolbox.name == toolbox_str) for toolbox_str in ret]
246
+
247
+ return await self.aask_validate(
248
+ question=prompt,
249
+ validator=_validate,
250
+ max_validations=max_validations,
251
+ system_message=system_message,
252
+ **kwargs,
253
+ )
254
+
255
+ async def apick[T: WithBriefing](
256
+ self,
257
+ instruction: str,
258
+ choices: List[T],
259
+ max_validations: PositiveInt = 2,
260
+ system_message: str = "",
261
+ **kwargs: Unpack[LLMKwargs],
262
+ ) -> T:
263
+ """Asynchronously picks a single choice from a list of options using AI validation.
264
+
265
+ This method is a convenience wrapper around `achoose` that always selects exactly one item.
266
+
267
+ Args:
268
+ instruction (str): The user-provided instruction/question description.
269
+ choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
270
+ max_validations (PositiveInt): Maximum number of validation failures, default is 2.
271
+ system_message (str): Custom system-level prompt, defaults to an empty string.
272
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`,
273
+ `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
274
+
275
+ Returns:
276
+ T: The single selected item from the choices list.
277
+
278
+ Raises:
279
+ ValueError: If validation fails after maximum attempts or if no valid selection is made.
280
+ """
281
+ return await self.achoose(
282
+ instruction=instruction,
283
+ choices=choices,
284
+ k=1,
285
+ max_validations=max_validations,
286
+ system_message=system_message,
287
+ **kwargs,
288
+ )[0]
289
+
290
+ async def ajudge(
291
+ self,
292
+ prompt: str,
293
+ affirm_case: str = "",
294
+ deny_case: str = "",
295
+ max_validations: PositiveInt = 2,
296
+ system_message: str = "",
297
+ **kwargs: Unpack[LLMKwargs],
298
+ ) -> bool:
299
+ """Asynchronously judges a prompt using AI validation.
300
+
301
+ Args:
302
+ prompt (str): The input prompt to be judged.
303
+ affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
304
+ deny_case (str): The negative case for the AI model. Defaults to an empty string.
305
+ max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
306
+ system_message (str): System message for the AI model. Defaults to an empty string.
307
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
308
+
309
+ Returns:
310
+ bool: The judgment result (True or False) based on the AI's response.
311
+
312
+ Notes:
313
+ The method uses an internal validator to ensure the response is a boolean value.
314
+ If the response cannot be converted to a boolean, it will return None.
315
+ """
316
+
317
+ def _validate(response: str) -> bool | None:
318
+ ret = JsonCapture.convert_with(response, orjson.loads)
319
+ if not isinstance(ret, bool):
320
+ return None
321
+ return ret
322
+
323
+ return await self.aask_validate(
324
+ question=template_manager.render_template(
325
+ configs.templates.make_judgment_template,
326
+ {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
327
+ ),
328
+ validator=_validate,
329
+ max_validations=max_validations,
330
+ system_message=system_message,
331
+ **kwargs,
332
+ )
333
+
334
+ def fallback_to(self, other: "LLMUsage") -> Self:
335
+ """Fallback to another instance's attribute values if the current instance's attributes are None.
336
+
337
+ Args:
338
+ other (LLMUsage): Another instance from which to copy attribute values.
339
+
340
+ Returns:
341
+ Self: The current instance, allowing for method chaining.
342
+ """
343
+ # Iterate over the attribute names and copy values from 'other' to 'self' where applicable
344
+ # noinspection PydanticTypeChecker,PyTypeChecker
345
+ for attr_name in LLMUsage.model_fields:
346
+ # Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
347
+ if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
348
+ setattr(self, attr_name, attr)
349
+
350
+ # Return the current instance to allow for method chaining
351
+ return self
352
+
353
+ def hold_to(self, others: Union["LLMUsage", Iterable["LLMUsage"]]) -> Self:
354
+ """Hold to another instance's attribute values if the current instance's attributes are None.
355
+
356
+ Args:
357
+ others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
358
+
359
+ Returns:
360
+ Self: The current instance, allowing for method chaining.
361
+ """
362
+ for other in others:
363
+ # noinspection PyTypeChecker,PydanticTypeChecker
364
+ for attr_name in LLMUsage.model_fields:
365
+ if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
366
+ setattr(other, attr_name, attr)
367
+
368
+
369
+ class ToolBoxUsage(LLMUsage):
370
+ """A class representing the usage of tools in a task."""
371
+
372
+ toolboxes: Set[ToolBox] = Field(default_factory=set)
373
+ """A set of toolboxes used by the instance."""
374
+
375
+ @property
376
+ def available_toolbox_names(self) -> List[str]:
377
+ """Return a list of available toolbox names."""
378
+ return [toolbox.name for toolbox in self.toolboxes]
379
+
380
+ async def choose_toolboxes(
381
+ self,
382
+ task: Task,
383
+ system_message: str = "",
384
+ k: NonNegativeInt = 0,
385
+ max_validations: PositiveInt = 2,
386
+ **kwargs: Unpack[LLMKwargs],
387
+ ) -> List[ToolBox]:
388
+ """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
389
+
390
+ Args:
391
+ task (Task): The task for which to choose toolboxes.
392
+ system_message (str): Custom system-level prompt, defaults to an empty string.
393
+ k (NonNegativeInt): The number of toolboxes to select, 0 means infinite. Defaults to 0.
394
+ max_validations (PositiveInt): Maximum number of validation failures, default is 2.
395
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
396
+
397
+ Returns:
398
+ List[ToolBox]: The selected toolboxes.
399
+ """
400
+ if not self.toolboxes:
401
+ logger.warning("No toolboxes available.")
402
+ return []
403
+ return await self.achoose(
404
+ instruction=task.briefing, # TODO write a template to build a more robust instruction
405
+ choices=list(self.toolboxes),
406
+ k=k,
407
+ max_validations=max_validations,
408
+ system_message=system_message,
409
+ **kwargs,
410
+ )
411
+
412
+ async def choose_tools(
413
+ self,
414
+ task: Task,
415
+ toolbox: ToolBox,
416
+ system_message: str = "",
417
+ k: NonNegativeInt = 0,
418
+ max_validations: PositiveInt = 2,
419
+ **kwargs: Unpack[LLMKwargs],
420
+ ) -> List[Tool]:
421
+ """Asynchronously executes a multi-choice decision-making process to choose tools.
422
+
423
+ Args:
424
+ task (Task): The task for which to choose tools.
425
+ toolbox (ToolBox): The toolbox from which to choose tools.
426
+ system_message (str): Custom system-level prompt, defaults to an empty string.
427
+ k (NonNegativeInt): The number of tools to select, 0 means infinite. Defaults to 0.
428
+ max_validations (PositiveInt): Maximum number of validation failures, default is 2.
429
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage, such as `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
430
+
431
+ Returns:
432
+ List[Tool]: The selected tools.
433
+ """
434
+ if not toolbox.tools:
435
+ logger.warning(f"No tools available in toolbox {toolbox.name}.")
436
+ return []
437
+ return await self.achoose(
438
+ instruction=task.briefing, # TODO write a template to build a more robust instruction
439
+ choices=toolbox.tools,
440
+ k=k,
441
+ max_validations=max_validations,
442
+ system_message=system_message,
443
+ **kwargs,
444
+ )
445
+
446
+ async def gather_tools_fine_grind(
447
+ self,
448
+ task: Task,
449
+ box_choose_kwargs: Optional[ChooseKwargs] = None,
450
+ tool_choose_kwargs: Optional[ChooseKwargs] = None,
451
+ ) -> List[Tool]:
452
+ """Asynchronously gathers tools based on the provided task and toolbox and tool selection criteria.
453
+
454
+ Args:
455
+ task (Task): The task for which to gather tools.
456
+ box_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing toolboxes, such as `system_message`, `k`, `max_validations`, `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
457
+ tool_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing tools, such as `system_message`, `k`, `max_validations`, `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
458
+
459
+ Returns:
460
+ List[Tool]: A list of tools gathered based on the provided task and toolbox and tool selection criteria.
461
+ """
462
+ box_choose_kwargs = box_choose_kwargs or {}
463
+ tool_choose_kwargs = tool_choose_kwargs or {}
464
+
465
+ # Choose the toolboxes
466
+ chosen_toolboxes = await self.choose_toolboxes(task, **box_choose_kwargs)
467
+ # Choose the tools
468
+ chosen_tools = []
469
+ for toolbox in chosen_toolboxes:
470
+ chosen_tools.extend(await self.choose_tools(task, toolbox, **tool_choose_kwargs))
471
+ return chosen_tools
472
+
473
+ async def gather_tools(self, task: Task, **kwargs: Unpack[ChooseKwargs]) -> List[Tool]:
474
+ """Asynchronously gathers tools based on the provided task.
475
+
476
+ Args:
477
+ task (Task): The task for which to gather tools.
478
+ **kwargs (Unpack[ChooseKwargs]): Keyword arguments for choosing tools, such as `system_message`, `k`, `max_validations`, `model`, `temperature`, `stop`, `top_p`, `max_tokens`, `stream`, `timeout`, and `max_retries`.
479
+
480
+ Returns:
481
+ List[Tool]: A list of tools gathered based on the provided task.
482
+ """
483
+ return await self.gather_tools_fine_grind(task, kwargs, kwargs)
484
+
485
+ def supply_tools_from[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
486
+ """Supplies tools from other ToolUsage instances to this instance.
487
+
488
+ Args:
489
+ others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolUsage instance or an iterable of ToolUsage instances
490
+ from which to take tools.
491
+
492
+ Returns:
493
+ Self: The current ToolUsage instance with updated tools.
494
+ """
495
+ if isinstance(others, ToolBoxUsage):
496
+ others = [others]
497
+ for other in others:
498
+ self.toolboxes.update(other.toolboxes)
499
+ return self
500
+
501
+ def provide_tools_to[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
502
+ """Provides tools from this instance to other ToolUsage instances.
503
+
504
+ Args:
505
+ others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolUsage instance or an iterable of ToolUsage instances
506
+ to which to provide tools.
507
+
508
+ Returns:
509
+ Self: The current ToolUsage instance.
510
+ """
511
+ if isinstance(others, ToolBoxUsage):
512
+ others = [others]
513
+ for other in others:
514
+ other.toolboxes.update(self.toolboxes)
515
+ return self
@@ -0,0 +1,78 @@
1
+ """A module containing utility classes for the models."""
2
+
3
+ from typing import Dict, List, Literal, Self
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+
7
+
8
class Message(BaseModel):
    """A class representing a message."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    role: Literal["user", "system", "assistant"] = "user"
    """Who is sending the message."""

    content: str = ""
    """The content of the message."""
20
+
21
+
22
class Messages(list):
    """A list of messages."""

    def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
        """Append a message with the given role, skipping empty content.

        Args:
            role (Literal["user", "system", "assistant"]): The role of the message sender.
            content (str): The content of the message; nothing is appended when it is empty.

        Returns:
            Self: The current instance of Messages to allow method chaining.
        """
        if not content:
            return self
        self.append(Message(role=role, content=content))
        return self

    def add_user_message(self, content: str) -> Self:
        """Append a message with the "user" role.

        Args:
            content (str): The content of the user message.

        Returns:
            Self: The current instance of Messages to allow method chaining.
        """
        return self.add_message(role="user", content=content)

    def add_system_message(self, content: str) -> Self:
        """Append a message with the "system" role.

        Args:
            content (str): The content of the system message.

        Returns:
            Self: The current instance of Messages to allow method chaining.
        """
        return self.add_message(role="system", content=content)

    def add_assistant_message(self, content: str) -> Self:
        """Append a message with the "assistant" role.

        Args:
            content (str): The content of the assistant message.

        Returns:
            Self: The current instance of Messages to allow method chaining.
        """
        return self.add_message(role="assistant", content=content)

    def as_list(self) -> List[Dict[str, str]]:
        """Dump every stored message into a plain dictionary.

        Returns:
            list[dict]: One `model_dump()` dictionary per message, in insertion order.
        """
        dumped = []
        for entry in self:
            dumped.append(entry.model_dump())
        return dumped
fabricatio/parser.py ADDED
@@ -0,0 +1,93 @@
1
+ """A module to parse text using regular expressions."""
2
+
3
+ from typing import Any, Callable, Self, Tuple
4
+
5
+ import regex
6
+ from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
7
+ from regex import Pattern, compile
8
+
9
+ from fabricatio.journal import logger
10
+
11
+
12
+ class Capture(BaseModel):
13
+ """A class to capture patterns in text using regular expressions.
14
+
15
+ Attributes:
16
+ pattern (str): The regular expression pattern to search for.
17
+ _compiled (Pattern): The compiled regular expression pattern.
18
+ """
19
+
20
+ model_config = ConfigDict(use_attribute_docstrings=True)
21
+ target_groups: Tuple[int, ...] = Field(default_factory=tuple)
22
+ """The target groups to capture from the pattern."""
23
+ pattern: str = Field(frozen=True)
24
+ """The regular expression pattern to search for."""
25
+ flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
26
+ """The flags to use when compiling the regular expression pattern."""
27
+ _compiled: Pattern = PrivateAttr()
28
+
29
+ def model_post_init(self, __context: Any) -> None:
30
+ """Initialize the compiled regular expression pattern after the model is initialized.
31
+
32
+ Args:
33
+ __context (Any): The context in which the model is initialized.
34
+ """
35
+ self._compiled = compile(self.pattern, self.flags)
36
+
37
+ def capture(self, text: str) -> Tuple[str, ...] | str | None:
38
+ """Capture the first occurrence of the pattern in the given text.
39
+
40
+ Args:
41
+ text (str): The text to search the pattern in.
42
+
43
+ Returns:
44
+ str | None: The captured text if the pattern is found, otherwise None.
45
+
46
+ """
47
+ match = self._compiled.search(text)
48
+ if match is None:
49
+ return None
50
+
51
+ if self.target_groups:
52
+ cap = tuple(match.group(g) for g in self.target_groups)
53
+ logger.debug(f"Captured text: {'\n\n'.join(cap)}")
54
+ return cap
55
+ cap = match.group(1)
56
+ logger.debug(f"Captured text: \n{cap}")
57
+ return cap
58
+
59
+ def convert_with[T](self, text: str, convertor: Callable[[Tuple[str, ...]], T] | Callable[[str], T]) -> T | None:
60
+ """Convert the given text using the pattern.
61
+
62
+ Args:
63
+ text (str): The text to search the pattern in.
64
+ convertor (Callable[[Tuple[str, ...]], T] | Callable[[str], T]): The function to convert the captured text.
65
+
66
+ Returns:
67
+ str | None: The converted text if the pattern is found, otherwise None.
68
+ """
69
+ if (cap := self.capture(text)) is None:
70
+ return None
71
+ try:
72
+ return convertor(cap)
73
+ except (ValueError, SyntaxError) as e:
74
+ logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
75
+ return None
76
+
77
+ @classmethod
78
+ def capture_code_block(cls, language: str) -> Self:
79
+ """Capture the first occurrence of a code block in the given text.
80
+
81
+ Args:
82
+ language (str): The text containing the code block.
83
+
84
+ Returns:
85
+ Self: The instance of the class with the captured code block.
86
+ """
87
+ return cls(pattern=f"```{language}\n(.*?)\n```")
88
+
89
+
90
# Ready-made captures for fenced code blocks tagged with a specific language.
JsonCapture = Capture.capture_code_block(language="json")
PythonCapture = Capture.capture_code_block(language="python")
MarkdownCapture = Capture.capture_code_block(language="markdown")
# Captures the body of a fenced code block whatever follows the opening fence on its line.
CodeBlockCapture = Capture(pattern="```.*?\n(.*?)\n```")
fabricatio/py.typed ADDED
File without changes