fabricatio 0.2.3.dev3__cp312-cp312-win_amd64.whl → 0.2.4.dev1__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +6 -2
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/actions/__init__.py +2 -2
- fabricatio/actions/article.py +44 -0
- fabricatio/capabilities/propose.py +55 -0
- fabricatio/capabilities/rag.py +129 -44
- fabricatio/capabilities/rating.py +12 -36
- fabricatio/capabilities/task.py +6 -23
- fabricatio/config.py +37 -2
- fabricatio/models/action.py +3 -3
- fabricatio/models/events.py +36 -0
- fabricatio/models/extra.py +96 -0
- fabricatio/models/generic.py +194 -7
- fabricatio/models/kwargs_types.py +14 -0
- fabricatio/models/task.py +5 -23
- fabricatio/models/usages.py +117 -184
- fabricatio/models/utils.py +19 -0
- fabricatio/parser.py +35 -8
- fabricatio-0.2.4.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/METADATA +66 -178
- fabricatio-0.2.4.dev1.dist-info/RECORD +38 -0
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio-0.2.3.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.3.dev3.dist-info/RECORD +0 -37
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev1.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/usages.py
CHANGED
@@ -1,19 +1,18 @@
|
|
1
1
|
"""This module contains classes that manage the usage of language models and tools in tasks."""
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
|
-
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack, overload
|
4
|
+
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
|
5
5
|
|
6
6
|
import asyncstdlib
|
7
7
|
import litellm
|
8
|
-
import orjson
|
9
8
|
from fabricatio._rust_instances import template_manager
|
10
9
|
from fabricatio.config import configs
|
11
10
|
from fabricatio.journal import logger
|
12
|
-
from fabricatio.models.generic import
|
11
|
+
from fabricatio.models.generic import ScopedConfig, WithBriefing
|
13
12
|
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
|
14
13
|
from fabricatio.models.task import Task
|
15
14
|
from fabricatio.models.tool import Tool, ToolBox
|
16
|
-
from fabricatio.models.utils import Messages
|
15
|
+
from fabricatio.models.utils import Messages
|
17
16
|
from fabricatio.parser import JsonCapture
|
18
17
|
from litellm import stream_chunk_builder
|
19
18
|
from litellm.types.utils import (
|
@@ -23,135 +22,16 @@ from litellm.types.utils import (
|
|
23
22
|
StreamingChoices,
|
24
23
|
)
|
25
24
|
from litellm.utils import CustomStreamWrapper
|
26
|
-
from
|
25
|
+
from more_itertools import duplicates_everseen
|
26
|
+
from pydantic import Field, NonNegativeInt, PositiveInt
|
27
27
|
|
28
28
|
|
29
|
-
class LLMUsage(
|
29
|
+
class LLMUsage(ScopedConfig):
|
30
30
|
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
31
31
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
llm_api_key: Optional[SecretStr] = None
|
36
|
-
"""The OpenAI API key."""
|
37
|
-
|
38
|
-
llm_timeout: Optional[PositiveInt] = None
|
39
|
-
"""The timeout of the LLM model."""
|
40
|
-
|
41
|
-
llm_max_retries: Optional[PositiveInt] = None
|
42
|
-
"""The maximum number of retries."""
|
43
|
-
|
44
|
-
llm_model: Optional[str] = None
|
45
|
-
"""The LLM model name."""
|
46
|
-
|
47
|
-
llm_temperature: Optional[NonNegativeFloat] = None
|
48
|
-
"""The temperature of the LLM model."""
|
49
|
-
|
50
|
-
llm_stop_sign: Optional[str | List[str]] = None
|
51
|
-
"""The stop sign of the LLM model."""
|
52
|
-
|
53
|
-
llm_top_p: Optional[NonNegativeFloat] = None
|
54
|
-
"""The top p of the LLM model."""
|
55
|
-
|
56
|
-
llm_generation_count: Optional[PositiveInt] = None
|
57
|
-
"""The number of generations to generate."""
|
58
|
-
|
59
|
-
llm_stream: Optional[bool] = None
|
60
|
-
"""Whether to stream the LLM model's response."""
|
61
|
-
|
62
|
-
llm_max_tokens: Optional[PositiveInt] = None
|
63
|
-
"""The maximum number of tokens to generate."""
|
64
|
-
|
65
|
-
async def aembedding(
|
66
|
-
self,
|
67
|
-
input_text: List[str],
|
68
|
-
model: Optional[str] = None,
|
69
|
-
dimensions: Optional[int] = None,
|
70
|
-
timeout: Optional[PositiveInt] = None,
|
71
|
-
caching: Optional[bool] = False,
|
72
|
-
) -> EmbeddingResponse:
|
73
|
-
"""Asynchronously generates embeddings for the given input text.
|
74
|
-
|
75
|
-
Args:
|
76
|
-
input_text (List[str]): A list of strings to generate embeddings for.
|
77
|
-
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
78
|
-
dimensions (Optional[int]): The dimensions of the embedding. Defaults to None.
|
79
|
-
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
80
|
-
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
81
|
-
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
EmbeddingResponse: The response containing the embeddings.
|
85
|
-
"""
|
86
|
-
return await litellm.aembedding(
|
87
|
-
input=input_text,
|
88
|
-
caching=caching,
|
89
|
-
dimensions=dimensions,
|
90
|
-
model=model or self.llm_model or configs.llm.model,
|
91
|
-
timeout=timeout or self.llm_timeout or configs.llm.timeout,
|
92
|
-
api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
|
93
|
-
api_base=self.llm_api_endpoint.unicode_string().rstrip(
|
94
|
-
"/"
|
95
|
-
) # seems embedding function takes no base_url end with a slash
|
96
|
-
if self.llm_api_endpoint
|
97
|
-
else configs.llm.api_endpoint.unicode_string().rstrip("/"),
|
98
|
-
)
|
99
|
-
|
100
|
-
@overload
|
101
|
-
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
102
|
-
@overload
|
103
|
-
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
104
|
-
|
105
|
-
async def vectorize(
|
106
|
-
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
107
|
-
) -> List[List[float]] | List[float]:
|
108
|
-
"""Asynchronously generates vector embeddings for the given input text.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
112
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
113
|
-
|
114
|
-
Returns:
|
115
|
-
List[List[float]] | List[float]: The generated embeddings.
|
116
|
-
"""
|
117
|
-
if isinstance(input_text, str):
|
118
|
-
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
119
|
-
|
120
|
-
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
121
|
-
|
122
|
-
@overload
|
123
|
-
async def pack(
|
124
|
-
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
125
|
-
) -> List[MilvusData]: ...
|
126
|
-
@overload
|
127
|
-
async def pack(
|
128
|
-
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
129
|
-
) -> MilvusData: ...
|
130
|
-
|
131
|
-
async def pack(
|
132
|
-
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
133
|
-
) -> List[MilvusData] | MilvusData:
|
134
|
-
"""Asynchronously generates MilvusData objects for the given input text.
|
135
|
-
|
136
|
-
Args:
|
137
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
138
|
-
subject (Optional[str]): The subject of the input text. Defaults to None.
|
139
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
140
|
-
|
141
|
-
Returns:
|
142
|
-
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
143
|
-
"""
|
144
|
-
if isinstance(input_text, str):
|
145
|
-
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
146
|
-
vecs = await self.vectorize(input_text, **kwargs)
|
147
|
-
return [
|
148
|
-
MilvusData(
|
149
|
-
vector=vec,
|
150
|
-
text=text,
|
151
|
-
subject=subject,
|
152
|
-
)
|
153
|
-
for text, vec in zip(input_text, vecs, strict=True)
|
154
|
-
]
|
32
|
+
@classmethod
|
33
|
+
def _scoped_model(cls) -> Type["LLMUsage"]:
|
34
|
+
return LLMUsage
|
155
35
|
|
156
36
|
async def aquery(
|
157
37
|
self,
|
@@ -181,10 +61,8 @@ class LLMUsage(Base):
|
|
181
61
|
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
182
62
|
timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
|
183
63
|
max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
|
184
|
-
api_key=
|
185
|
-
base_url=self.llm_api_endpoint.unicode_string()
|
186
|
-
if self.llm_api_endpoint
|
187
|
-
else configs.llm.api_endpoint.unicode_string(),
|
64
|
+
api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
|
65
|
+
base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
|
188
66
|
)
|
189
67
|
|
190
68
|
async def ainvoke(
|
@@ -213,13 +91,13 @@ class LLMUsage(Base):
|
|
213
91
|
if isinstance(resp, ModelResponse):
|
214
92
|
return resp.choices
|
215
93
|
if isinstance(resp, CustomStreamWrapper):
|
216
|
-
if configs.debug.streaming_visible:
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
return stream_chunk_builder(
|
94
|
+
if not configs.debug.streaming_visible:
|
95
|
+
return stream_chunk_builder(await asyncstdlib.list()).choices
|
96
|
+
chunks = []
|
97
|
+
async for chunk in resp:
|
98
|
+
chunks.append(chunk)
|
99
|
+
print(chunk.choices[0].delta.content or "", end="") # noqa: T201
|
100
|
+
return stream_chunk_builder(chunks).choices
|
223
101
|
logger.critical(err := f"Unexpected response type: {type(resp)}")
|
224
102
|
raise ValueError(err)
|
225
103
|
|
@@ -334,11 +212,10 @@ class LLMUsage(Base):
|
|
334
212
|
**kwargs,
|
335
213
|
)
|
336
214
|
) and (validated := validator(response)):
|
337
|
-
logger.debug(f"Successfully validated the response at {i}th attempt.
|
215
|
+
logger.debug(f"Successfully validated the response at {i}th attempt.")
|
338
216
|
return validated
|
339
|
-
|
340
|
-
|
341
|
-
raise ValueError("Failed to validate the response.")
|
217
|
+
logger.error(err := f"Failed to validate the response after {max_validations} attempts.")
|
218
|
+
raise ValueError(err)
|
342
219
|
|
343
220
|
async def aask_validate_batch[T](
|
344
221
|
self,
|
@@ -361,6 +238,26 @@ class LLMUsage(Base):
|
|
361
238
|
"""
|
362
239
|
return await gather(*[self.aask_validate(question, validator, **kwargs) for question in questions])
|
363
240
|
|
241
|
+
async def aliststr(self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[GenerateKwargs]) -> List[str]:
|
242
|
+
"""Asynchronously generates a list of strings based on a given requirement.
|
243
|
+
|
244
|
+
Args:
|
245
|
+
requirement (str): The requirement for the list of strings.
|
246
|
+
k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
|
247
|
+
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
248
|
+
|
249
|
+
Returns:
|
250
|
+
List[str]: The validated response as a list of strings.
|
251
|
+
"""
|
252
|
+
return await self.aask_validate(
|
253
|
+
template_manager.render_template(
|
254
|
+
configs.templates.liststr_template,
|
255
|
+
{"requirement": requirement, "k": k},
|
256
|
+
),
|
257
|
+
lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
|
258
|
+
**kwargs,
|
259
|
+
)
|
260
|
+
|
364
261
|
async def achoose[T: WithBriefing](
|
365
262
|
self,
|
366
263
|
instruction: str,
|
@@ -384,28 +281,28 @@ class LLMUsage(Base):
|
|
384
281
|
- Ensures response compliance through JSON parsing and format validation.
|
385
282
|
- Relies on `aask_validate` to implement retry mechanisms with validation.
|
386
283
|
"""
|
284
|
+
if dup := duplicates_everseen(choices, key=lambda x: x.name):
|
285
|
+
logger.error(err := f"Redundant choices: {dup}")
|
286
|
+
raise ValueError(err)
|
387
287
|
prompt = template_manager.render_template(
|
388
288
|
configs.templates.make_choice_template,
|
389
289
|
{
|
390
290
|
"instruction": instruction,
|
391
|
-
"options": [{"name"
|
291
|
+
"options": [m.model_dump(include={"name", "briefing"}) for m in choices],
|
392
292
|
"k": k,
|
393
293
|
},
|
394
294
|
)
|
395
295
|
names = {c.name for c in choices}
|
296
|
+
|
396
297
|
logger.debug(f"Start choosing between {names} with prompt: \n{prompt}")
|
397
298
|
|
398
299
|
def _validate(response: str) -> List[T] | None:
|
399
|
-
ret = JsonCapture.
|
400
|
-
|
401
|
-
if not isinstance(ret, List) or (0 < k != len(ret)):
|
402
|
-
logger.error(f"Incorrect Type or length of response: \n{ret}")
|
300
|
+
ret = JsonCapture.validate_with(response, target_type=List, elements_type=str, length=k)
|
301
|
+
if ret is None or set(ret) - names:
|
403
302
|
return None
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
return [next(toolbox for toolbox in choices if toolbox.name == toolbox_str) for toolbox_str in ret]
|
303
|
+
return [
|
304
|
+
next(candidate for candidate in choices if candidate.name == candidate_name) for candidate_name in ret
|
305
|
+
]
|
409
306
|
|
410
307
|
return await self.aask_validate(
|
411
308
|
question=prompt,
|
@@ -459,55 +356,91 @@ class LLMUsage(Base):
|
|
459
356
|
Returns:
|
460
357
|
bool: The judgment result (True or False) based on the AI's response.
|
461
358
|
"""
|
462
|
-
|
463
|
-
def _validate(response: str) -> bool | None:
|
464
|
-
ret = JsonCapture.convert_with(response, orjson.loads)
|
465
|
-
if not isinstance(ret, bool):
|
466
|
-
return None
|
467
|
-
return ret
|
468
|
-
|
469
359
|
return await self.aask_validate(
|
470
360
|
question=template_manager.render_template(
|
471
361
|
configs.templates.make_judgment_template,
|
472
362
|
{"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
|
473
363
|
),
|
474
|
-
validator=
|
364
|
+
validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
|
475
365
|
**kwargs,
|
476
366
|
)
|
477
367
|
|
478
|
-
|
479
|
-
|
368
|
+
|
369
|
+
class EmbeddingUsage(LLMUsage):
|
370
|
+
"""A class representing the embedding model."""
|
371
|
+
|
372
|
+
async def aembedding(
|
373
|
+
self,
|
374
|
+
input_text: List[str],
|
375
|
+
model: Optional[str] = None,
|
376
|
+
dimensions: Optional[int] = None,
|
377
|
+
timeout: Optional[PositiveInt] = None,
|
378
|
+
caching: Optional[bool] = False,
|
379
|
+
) -> EmbeddingResponse:
|
380
|
+
"""Asynchronously generates embeddings for the given input text.
|
480
381
|
|
481
382
|
Args:
|
482
|
-
|
383
|
+
input_text (List[str]): A list of strings to generate embeddings for.
|
384
|
+
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
385
|
+
dimensions (Optional[int]): The dimensions of the embedding output should have, which is used to validate the result. Defaults to None.
|
386
|
+
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
387
|
+
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
388
|
+
|
483
389
|
|
484
390
|
Returns:
|
485
|
-
|
391
|
+
EmbeddingResponse: The response containing the embeddings.
|
486
392
|
"""
|
487
|
-
#
|
488
|
-
|
489
|
-
for
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
393
|
+
# check seq length
|
394
|
+
max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
|
395
|
+
if any(len(t) > max_len for t in input_text):
|
396
|
+
logger.error(err := f"Input text exceeds maximum sequence length {max_len}.")
|
397
|
+
raise ValueError(err)
|
398
|
+
|
399
|
+
return await litellm.aembedding(
|
400
|
+
input=input_text,
|
401
|
+
caching=caching or self.embedding_caching or configs.embedding.caching,
|
402
|
+
dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
|
403
|
+
model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
|
404
|
+
timeout=timeout
|
405
|
+
or self.embedding_timeout
|
406
|
+
or configs.embedding.timeout
|
407
|
+
or self.llm_timeout
|
408
|
+
or configs.llm.timeout,
|
409
|
+
api_key=(
|
410
|
+
self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
|
411
|
+
).get_secret_value(),
|
412
|
+
api_base=(
|
413
|
+
self.embedding_api_endpoint
|
414
|
+
or configs.embedding.api_endpoint
|
415
|
+
or self.llm_api_endpoint
|
416
|
+
or configs.llm.api_endpoint
|
417
|
+
)
|
418
|
+
.unicode_string()
|
419
|
+
.rstrip("/"),
|
420
|
+
# seems embedding function takes no base_url end with a slash
|
421
|
+
)
|
422
|
+
|
423
|
+
@overload
|
424
|
+
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
425
|
+
@overload
|
426
|
+
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
496
427
|
|
497
|
-
def
|
498
|
-
|
428
|
+
async def vectorize(
|
429
|
+
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
430
|
+
) -> List[List[float]] | List[float]:
|
431
|
+
"""Asynchronously generates vector embeddings for the given input text.
|
499
432
|
|
500
433
|
Args:
|
501
|
-
|
434
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
435
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
502
436
|
|
503
437
|
Returns:
|
504
|
-
|
438
|
+
List[List[float]] | List[float]: The generated embeddings.
|
505
439
|
"""
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
setattr(other, attr_name, attr)
|
440
|
+
if isinstance(input_text, str):
|
441
|
+
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
442
|
+
|
443
|
+
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
511
444
|
|
512
445
|
|
513
446
|
class ToolBoxUsage(LLMUsage):
|
fabricatio/models/utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
"""A module containing utility classes for the models."""
|
2
2
|
|
3
|
+
from enum import Enum
|
3
4
|
from typing import Any, Dict, List, Literal, Optional, Self
|
4
5
|
|
5
6
|
from pydantic import BaseModel, ConfigDict, Field
|
@@ -125,3 +126,21 @@ class MilvusData(BaseModel):
|
|
125
126
|
"""
|
126
127
|
self.id = new_id
|
127
128
|
return self
|
129
|
+
|
130
|
+
|
131
|
+
class TaskStatus(Enum):
|
132
|
+
"""An enumeration representing the status of a task.
|
133
|
+
|
134
|
+
Attributes:
|
135
|
+
Pending: The task is pending.
|
136
|
+
Running: The task is currently running.
|
137
|
+
Finished: The task has been successfully completed.
|
138
|
+
Failed: The task has failed.
|
139
|
+
Cancelled: The task has been cancelled.
|
140
|
+
"""
|
141
|
+
|
142
|
+
Pending = "pending"
|
143
|
+
Running = "running"
|
144
|
+
Finished = "finished"
|
145
|
+
Failed = "failed"
|
146
|
+
Cancelled = "cancelled"
|
fabricatio/parser.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
"""A module to parse text using regular expressions."""
|
2
2
|
|
3
|
-
from typing import Any, Callable, Self, Tuple
|
3
|
+
from typing import Any, Callable, Optional, Self, Tuple, Type
|
4
4
|
|
5
|
+
import orjson
|
5
6
|
import regex
|
6
|
-
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
|
7
8
|
from regex import Pattern, compile
|
8
9
|
|
9
10
|
from fabricatio.journal import logger
|
@@ -27,11 +28,7 @@ class Capture(BaseModel):
|
|
27
28
|
_compiled: Pattern = PrivateAttr()
|
28
29
|
|
29
30
|
def model_post_init(self, __context: Any) -> None:
|
30
|
-
"""Initialize the compiled
|
31
|
-
|
32
|
-
Args:
|
33
|
-
__context (Any): The context in which the model is initialized.
|
34
|
-
"""
|
31
|
+
"""Initialize the compiled pattern."""
|
35
32
|
self._compiled = compile(self.pattern, self.flags)
|
36
33
|
|
37
34
|
def capture(self, text: str) -> Tuple[str, ...] | str | None:
|
@@ -70,10 +67,40 @@ class Capture(BaseModel):
|
|
70
67
|
return None
|
71
68
|
try:
|
72
69
|
return convertor(cap)
|
73
|
-
except (ValueError, SyntaxError) as e:
|
70
|
+
except (ValueError, SyntaxError, ValidationError) as e:
|
74
71
|
logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
|
75
72
|
return None
|
76
73
|
|
74
|
+
def validate_with[K, T, E](
|
75
|
+
self,
|
76
|
+
text: str,
|
77
|
+
target_type: Type[T],
|
78
|
+
elements_type: Optional[Type[E]] = None,
|
79
|
+
length: Optional[int] = None,
|
80
|
+
deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = orjson.loads,
|
81
|
+
) -> T | None:
|
82
|
+
"""Validate the given text using the pattern.
|
83
|
+
|
84
|
+
Args:
|
85
|
+
text (str): The text to search the pattern in.
|
86
|
+
target_type (Type[T]): The expected type of the output, dict or list.
|
87
|
+
elements_type (Optional[Type[E]]): The expected type of the elements in the output dict keys or list elements.
|
88
|
+
length (Optional[int]): The expected length of the output, bool(length)==False means no length validation.
|
89
|
+
deserializer (Callable[[Tuple[str, ...]], K] | Callable[[str], K]): The function to deserialize the captured text.
|
90
|
+
|
91
|
+
Returns:
|
92
|
+
T | None: The validated text if the pattern is found and the output is of the expected type, otherwise None.
|
93
|
+
"""
|
94
|
+
judges = [lambda output_obj: isinstance(output_obj, target_type)]
|
95
|
+
if elements_type:
|
96
|
+
judges.append(lambda output_obj: all(isinstance(e, elements_type) for e in output_obj))
|
97
|
+
if length:
|
98
|
+
judges.append(lambda output_obj: len(output_obj) == length)
|
99
|
+
|
100
|
+
if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
|
101
|
+
return out
|
102
|
+
return None
|
103
|
+
|
77
104
|
@classmethod
|
78
105
|
def capture_code_block(cls, language: str) -> Self:
|
79
106
|
"""Capture the first occurrence of a code block in the given text.
|
Binary file
|