fabricatio 0.2.13.dev3__cp312-cp312-win_amd64.whl → 0.3.14__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +7 -14
- fabricatio/actions/article.py +58 -23
- fabricatio/actions/article_rag.py +6 -15
- fabricatio/actions/output.py +38 -3
- fabricatio/actions/rag.py +4 -4
- fabricatio/capabilities/advanced_judge.py +4 -7
- fabricatio/capabilities/advanced_rag.py +2 -1
- fabricatio/capabilities/censor.py +5 -4
- fabricatio/capabilities/check.py +6 -7
- fabricatio/capabilities/correct.py +5 -5
- fabricatio/capabilities/extract.py +7 -3
- fabricatio/capabilities/persist.py +103 -0
- fabricatio/capabilities/propose.py +2 -2
- fabricatio/capabilities/rag.py +43 -43
- fabricatio/capabilities/rating.py +11 -10
- fabricatio/capabilities/review.py +8 -6
- fabricatio/capabilities/task.py +22 -22
- fabricatio/decorators.py +4 -2
- fabricatio/{core.py → emitter.py} +35 -39
- fabricatio/fs/__init__.py +1 -2
- fabricatio/journal.py +2 -11
- fabricatio/models/action.py +14 -30
- fabricatio/models/extra/aricle_rag.py +14 -8
- fabricatio/models/extra/article_base.py +56 -25
- fabricatio/models/extra/article_essence.py +2 -1
- fabricatio/models/extra/article_main.py +16 -13
- fabricatio/models/extra/article_outline.py +2 -1
- fabricatio/models/extra/article_proposal.py +1 -1
- fabricatio/models/extra/rag.py +2 -2
- fabricatio/models/extra/rule.py +2 -1
- fabricatio/models/generic.py +56 -166
- fabricatio/models/kwargs_types.py +1 -54
- fabricatio/models/role.py +49 -26
- fabricatio/models/task.py +8 -9
- fabricatio/models/tool.py +7 -7
- fabricatio/models/usages.py +67 -61
- fabricatio/parser.py +60 -100
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +469 -74
- fabricatio/utils.py +63 -162
- fabricatio-0.3.14.data/scripts/tdown.exe +0 -0
- fabricatio-0.3.14.data/scripts/ttm.exe +0 -0
- {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/METADATA +10 -15
- fabricatio-0.3.14.dist-info/RECORD +64 -0
- {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/WHEEL +1 -1
- fabricatio/config.py +0 -430
- fabricatio/constants.py +0 -20
- fabricatio/models/events.py +0 -120
- fabricatio/rust_instances.py +0 -10
- fabricatio-0.2.13.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.13.dev3.data/scripts/ttm.exe +0 -0
- fabricatio-0.2.13.dev3.dist-info/RECORD +0 -67
- {fabricatio-0.2.13.dev3.dist-info → fabricatio-0.3.14.dist-info}/licenses/LICENSE +0 -0
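The headline change in the jump from 0.2.13.dev3 to 0.3.14 is that the pure-Python configuration and event plumbing (config.py, constants.py, models/events.py, rust_instances.py, all deleted in the list above) is folded into the compiled fabricatio.rust extension, whose stub grows by roughly 400 lines. The sketch below illustrates the import migration implied by the hunks that follow; it is an inference from this diff, not official upgrade guidance, and the only CONFIG attributes shown are ones that actually appear in the hunks.

```python
# Hedged sketch of the 0.2.x -> 0.3.x import migration implied by this diff.
# Deleted modules (left) and their 0.3.14 replacements (right), as seen below.

# 0.2.13.dev3                                      0.3.14
# from fabricatio.config import configs        ->  from fabricatio.rust import CONFIG
# from fabricatio.constants import TaskStatus  ->  from fabricatio.rust import TaskStatus
# from fabricatio.core import env              ->  from fabricatio.emitter import env
# from fabricatio.models.events import Event   ->  from fabricatio.rust import Event
from fabricatio.rust import CONFIG, TEMPLATE_MANAGER, Event, TaskStatus  # 0.3.14 style

print(CONFIG.llm.model)  # was configs.llm.model in 0.2.x
```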
fabricatio/models/task.py
CHANGED
@@ -4,17 +4,16 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
 """
 
 from asyncio import Queue
-from typing import Any, Dict, List, Optional, Self
+from typing import Any, Dict, List, Optional, Self, Union
 
-from fabricatio.
-from fabricatio.constants import TaskStatus
-from fabricatio.core import env
+from fabricatio.emitter import env
 from fabricatio.journal import logger
-from fabricatio.models.events import Event, EventLike
 from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
-from fabricatio.
+from fabricatio.rust import CONFIG, TEMPLATE_MANAGER, Event, TaskStatus
 from pydantic import Field, PrivateAttr
 
+type EventLike = Union[str, Event, List[str]]
+
 
 class Task[T](WithBriefing, ProposedAble, WithDependency):
     """A class representing a task with a status and output.
@@ -33,7 +32,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
     description: str = Field(default="")
     """A detailed explanation of the task that includes all necessary information. Should be clear and answer what, why, when, where, who, and how questions."""
 
-    goals: List[str] = Field(
+    goals: List[str] = Field(default_factory=list)
     """A list of objectives that the task aims to accomplish. Each goal should be clear and specific. Complex tasks should be broken into multiple smaller goals."""
 
     namespace: List[str] = Field(default_factory=list)
@@ -65,7 +64,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
 
     def model_post_init(self, __context: Any) -> None:
         """Initialize the task with a namespace event."""
-        self._namespace.
+        self._namespace.concat(self.namespace)
 
     def move_to(self, new_namespace: EventLike) -> Self:
         """Move the task to a new namespace.
@@ -266,7 +265,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
             str: The briefing of the task.
         """
         return TEMPLATE_MANAGER.render_template(
-
+            CONFIG.templates.task_briefing_template,
            self.model_dump(),
        )
 
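task.py now defines EventLike locally as Union[str, Event, List[str]] instead of importing it from the deleted models/events.py, and goals gains a default factory so a Task can be constructed without explicit goals. A small usage sketch follows, assuming the pydantic keyword constructor and only the field names visible in the hunks above; it is illustrative, not a verified call.

```python
# Illustrative only: field names (name, description, goals) and move_to() are
# taken from the hunks above; constructor defaults are assumed, not verified.
from fabricatio.models.task import Task

task = Task(name="compile-report", description="Compile the weekly report")  # goals now defaults to []
task.move_to("report.weekly")       # EventLike as a plain string
task.move_to(["report", "weekly"])  # EventLike as a list of segments
```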
fabricatio/models/tool.py
CHANGED
@@ -10,11 +10,11 @@ from inspect import iscoroutinefunction, signature
 from types import CodeType, ModuleType
 from typing import Any, Callable, Dict, List, Optional, Self, cast, overload
 
-from fabricatio.config import configs
 from fabricatio.decorators import logging_execution_info, use_temp_module
 from fabricatio.journal import logger
-from fabricatio.models.generic import WithBriefing
-from
+from fabricatio.models.generic import Base, WithBriefing
+from fabricatio.rust import CONFIG
+from pydantic import Field
 
 
 class Tool[**P, R](WithBriefing):
@@ -181,7 +181,7 @@ class ToolBox(WithBriefing):
         return hash(self.briefing)
 
 
-class ToolExecutor(
+class ToolExecutor(Base):
     """A class representing a tool executor with a sequence of tools to execute.
 
     This class manages a sequence of tools and provides methods to inject tools and data into a module, execute the tools,
@@ -191,7 +191,7 @@ class ToolExecutor(BaseModel):
         candidates (List[Tool]): The sequence of tools to execute.
         data (Dict[str, Any]): The data that could be used when invoking the tools.
     """
-
+
     candidates: List[Tool] = Field(default_factory=list, frozen=True)
     """The sequence of tools to execute."""
 
@@ -210,7 +210,7 @@ class ToolExecutor(BaseModel):
             M: The module with injected tools.
         """
         module = module or cast(
-            "M", module_from_spec(spec=ModuleSpec(name=
+            "M", module_from_spec(spec=ModuleSpec(name=CONFIG.toolbox.tool_module_name, loader=None))
         )
         for tool in self.candidates:
             logger.debug(f"Injecting tool: {tool.name}")
@@ -229,7 +229,7 @@ class ToolExecutor(BaseModel):
             M: The module with injected data.
         """
         module = module or cast(
-
+            "M", module_from_spec(spec=ModuleSpec(name=CONFIG.toolbox.data_module_name, loader=None))
         )
         for key, value in self.data.items():
             logger.debug(f"Injecting data: {key}")
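Both injection helpers now take their synthetic module names from CONFIG.toolbox instead of the removed configs object. The underlying mechanism is plain importlib: build an empty in-memory module and attach tools and data as attributes. A minimal, self-contained sketch of that pattern is below; the module name and injected attributes are made-up placeholders, not the values CONFIG.toolbox actually provides.

```python
# Standard-library sketch of the injection pattern used by ToolExecutor above.
# The module name and injected attributes are illustrative placeholders.
from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec

scratch = module_from_spec(ModuleSpec(name="toolbox_scratch", loader=None))
setattr(scratch, "greet", lambda who: f"hello {who}")  # inject a "tool"
setattr(scratch, "payload", {"who": "world"})          # inject "data"

print(scratch.greet(scratch.payload["who"]))           # -> hello world
```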
fabricatio/models/usages.py
CHANGED
@@ -1,22 +1,25 @@
 """This module contains classes that manage the usage of language models and tools in tasks."""
 
 import traceback
+from abc import ABC
 from asyncio import gather
 from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload
 
 import asyncstdlib
-import litellm
-from fabricatio.config import configs
 from fabricatio.decorators import logging_exec_time
 from fabricatio.journal import logger
 from fabricatio.models.generic import ScopedConfig, WithBriefing
 from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
-from fabricatio.
-from fabricatio.
-from
-
+from fabricatio.rust import CONFIG, TEMPLATE_MANAGER
+from fabricatio.utils import first_available, ok
+from litellm import (  # pyright: ignore [reportPrivateImportUsage]
+    RateLimitError,
+    Router,
+    aembedding,
+    stream_chunk_builder,
+)
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
 from litellm.types.utils import (
     Choices,
@@ -29,22 +32,16 @@ from litellm.utils import CustomStreamWrapper, token_counter  # pyright: ignore
 from more_itertools import duplicates_everseen
 from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
 
-if configs.cache.enabled and configs.cache.type:
-    litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
-    logger.debug(f"{configs.cache.type.name} Cache enabled")
-
 ROUTER = Router(
     routing_strategy="usage-based-routing-v2",
-    default_max_parallel_requests=
-    allowed_fails=
-    retry_after=
-    cooldown_time=
-    cache_responses=configs.cache.enabled,
-    cache_kwargs=configs.cache.params,
+    default_max_parallel_requests=CONFIG.routing.max_parallel_requests,
+    allowed_fails=CONFIG.routing.allowed_fails,
+    retry_after=CONFIG.routing.retry_after,
+    cooldown_time=CONFIG.routing.cooldown_time,
 )
 
 
-class LLMUsage(ScopedConfig):
+class LLMUsage(ScopedConfig, ABC):
     """Class that manages LLM (Large Language Model) usage parameters and methods.
 
     This class provides methods to deploy LLMs, query them for responses, and handle various configurations
@@ -86,48 +83,48 @@ class LLMUsage(ScopedConfig):
             Deployment(
                 model_name=(
                     m_name := ok(
-                        kwargs.get("model") or self.llm_model or
+                        kwargs.get("model") or self.llm_model or CONFIG.llm.model, "model name is not set at any place"
                     )
                 ),  # pyright: ignore [reportCallIssue]
                 litellm_params=(
                     p := LiteLLM_Params(
                         api_key=ok(
-                            self.llm_api_key or
+                            self.llm_api_key or CONFIG.llm.api_key, "llm api key is not set at any place"
                         ).get_secret_value(),
                         api_base=ok(
-                            self.llm_api_endpoint or
+                            self.llm_api_endpoint or CONFIG.llm.api_endpoint,
                             "llm api endpoint is not set at any place",
-                        )
+                        ),
                         model=m_name,
-                        tpm=self.llm_tpm or
-                        rpm=self.llm_rpm or
-                        max_retries=kwargs.get("max_retries") or self.llm_max_retries or
-                        timeout=kwargs.get("timeout") or self.llm_timeout or
+                        tpm=self.llm_tpm or CONFIG.llm.tpm,
+                        rpm=self.llm_rpm or CONFIG.llm.rpm,
+                        max_retries=kwargs.get("max_retries") or self.llm_max_retries or CONFIG.llm.max_retries,
+                        timeout=kwargs.get("timeout") or self.llm_timeout or CONFIG.llm.timeout,
                     )
                 ),
                 model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
             )
         ).acompletion(
             messages=messages,  # pyright: ignore [reportArgumentType]
-            n=n or self.llm_generation_count or
+            n=n or self.llm_generation_count or CONFIG.llm.generation_count,
             model=m_name,
-            temperature=kwargs.get("temperature") or self.llm_temperature or
-            stop=kwargs.get("stop") or self.llm_stop_sign or
-            top_p=kwargs.get("top_p") or self.llm_top_p or
-            max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or
-            stream=
+            temperature=kwargs.get("temperature") or self.llm_temperature or CONFIG.llm.temperature,
+            stop=kwargs.get("stop") or self.llm_stop_sign or CONFIG.llm.stop_sign,
+            top_p=kwargs.get("top_p") or self.llm_top_p or CONFIG.llm.top_p,
+            max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or CONFIG.llm.max_tokens,
+            stream=first_available(
+                (kwargs.get("stream"), self.llm_stream, CONFIG.llm.stream), "stream is not set at any place"
+            ),
             cache={
                 "no-cache": kwargs.get("no_cache"),
                 "no-store": kwargs.get("no_store"),
                 "cache-ttl": kwargs.get("cache_ttl"),
                 "s-maxage": kwargs.get("s_maxage"),
             },
-            presence_penalty=kwargs.get("presence_penalty")
-            or self.llm_presence_penalty
-            or configs.llm.presence_penalty,
+            presence_penalty=kwargs.get("presence_penalty") or self.llm_presence_penalty or CONFIG.llm.presence_penalty,
             frequency_penalty=kwargs.get("frequency_penalty")
             or self.llm_frequency_penalty
-            or
+            or CONFIG.llm.frequency_penalty,
         )
 
     async def ainvoke(
@@ -155,11 +152,8 @@ class LLMUsage(ScopedConfig):
         )
         if isinstance(resp, ModelResponse):
             return resp.choices
-        if isinstance(resp, CustomStreamWrapper):
-
-                return pack.choices
-            if pack := stream_chunk_builder(await asyncstdlib.list(resp)):
-                return pack.choices
+        if isinstance(resp, CustomStreamWrapper) and (pack := stream_chunk_builder(await asyncstdlib.list(resp))):
+            return pack.choices
         logger.critical(err := f"Unexpected response type: {type(resp)}")
         raise ValueError(err)
 
@@ -170,6 +164,7 @@ class LLMUsage(ScopedConfig):
         system_message: List[str],
         **kwargs: Unpack[LLMKwargs],
     ) -> List[str]: ...
+
     @overload
     async def aask(
         self,
@@ -177,6 +172,7 @@ class LLMUsage(ScopedConfig):
         system_message: List[str],
         **kwargs: Unpack[LLMKwargs],
     ) -> List[str]: ...
+
     @overload
     async def aask(
         self,
@@ -231,7 +227,8 @@ class LLMUsage(ScopedConfig):
             raise RuntimeError("Should not reach here.")
 
         logger.debug(
-            f"Response Token Count: {token_counter(text=out) if isinstance(out, str) else sum(token_counter(text=o) for o in out)}"
+            f"Response Token Count: {token_counter(text=out) if isinstance(out, str) else sum(token_counter(text=o) for o in out)}"
+            # pyright: ignore [reportOptionalIterable]
         )
         return out  # pyright: ignore [reportReturnType]
 
@@ -244,6 +241,7 @@ class LLMUsage(ScopedConfig):
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
     ) -> T: ...
+
     @overload
     async def aask_validate[T](
         self,
@@ -253,6 +251,7 @@ class LLMUsage(ScopedConfig):
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
     ) -> List[T]: ...
+
     @overload
     async def aask_validate[T](
         self,
@@ -331,9 +330,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[List[str]]: The validated response as a list of strings.
         """
+        from fabricatio.parser import JsonCapture
+
         return await self.aask_validate(
             TEMPLATE_MANAGER.render_template(
-
+                CONFIG.templates.liststr_template,
                 {"requirement": requirement, "k": k},
             ),
             lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
@@ -352,7 +353,7 @@ class LLMUsage(ScopedConfig):
         """
         return await self.alist_str(
             TEMPLATE_MANAGER.render_template(
-
+                CONFIG.templates.pathstr_template,
                 {"requirement": requirement},
             ),
             **kwargs,
@@ -387,9 +388,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[str]: The generated string.
         """
+        from fabricatio.parser import GenericCapture
+
         return await self.aask_validate(  # pyright: ignore [reportReturnType]
             TEMPLATE_MANAGER.render_template(
-
+                CONFIG.templates.generic_string_template,
                 {"requirement": requirement, "language": GenericCapture.capture_type},
             ),
             validator=lambda resp: GenericCapture.capture(resp),
@@ -414,11 +417,13 @@ class LLMUsage(ScopedConfig):
         Returns:
             Optional[List[T]]: The final validated selection result list, with element types matching the input `choices`.
         """
+        from fabricatio.parser import JsonCapture
+
         if dup := duplicates_everseen(choices, key=lambda x: x.name):
             logger.error(err := f"Redundant choices: {dup}")
             raise ValueError(err)
         prompt = TEMPLATE_MANAGER.render_template(
-
+            CONFIG.templates.make_choice_template,
             {
                 "instruction": instruction,
                 "options": [m.model_dump(include={"name", "briefing"}) for m in choices],
@@ -489,9 +494,11 @@ class LLMUsage(ScopedConfig):
         Returns:
             bool: The judgment result (True or False) based on the AI's response.
         """
+        from fabricatio.parser import JsonCapture
+
         return await self.aask_validate(
             question=TEMPLATE_MANAGER.render_template(
-
+                CONFIG.templates.make_judgment_template,
                 {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
             ),
             validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
@@ -499,7 +506,7 @@ class LLMUsage(ScopedConfig):
         )
 
 
-class EmbeddingUsage(LLMUsage):
+class EmbeddingUsage(LLMUsage, ABC):
     """A class representing the embedding model.
 
     This class extends LLMUsage and provides methods to generate embeddings for input text using various models.
@@ -526,37 +533,36 @@ class EmbeddingUsage(LLMUsage):
             EmbeddingResponse: The response containing the embeddings.
         """
         # check seq length
-        max_len = self.embedding_max_sequence_length or
+        max_len = self.embedding_max_sequence_length or CONFIG.embedding.max_sequence_length
         if max_len and any(length := (token_counter(text=t)) > max_len for t in input_text):
             logger.error(err := f"Input text exceeds maximum sequence length 58,804, got {length}.")
             raise ValueError(err)
 
-        return await
+        return await aembedding(
             input=input_text,
-            caching=caching or self.embedding_caching or
-            dimensions=dimensions or self.embedding_dimensions or
-            model=model or self.embedding_model or
+            caching=caching or self.embedding_caching or CONFIG.embedding.caching,
+            dimensions=dimensions or self.embedding_dimensions or CONFIG.embedding.dimensions,
+            model=model or self.embedding_model or CONFIG.embedding.model or self.llm_model or CONFIG.llm.model,
             timeout=timeout
             or self.embedding_timeout
-            or
+            or CONFIG.embedding.timeout
             or self.llm_timeout
-            or
+            or CONFIG.llm.timeout,
             api_key=ok(
-                self.embedding_api_key or
+                self.embedding_api_key or CONFIG.embedding.api_key or self.llm_api_key or CONFIG.llm.api_key
             ).get_secret_value(),
             api_base=ok(
                 self.embedding_api_endpoint
-                or
+                or CONFIG.embedding.api_endpoint
                 or self.llm_api_endpoint
-                or
-            )
-            .unicode_string()
-            .rstrip("/"),
+                or CONFIG.llm.api_endpoint
+            ).rstrip("/"),
             # seems embedding function takes no base_url end with a slash
         )
 
     @overload
     async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
+
     @overload
     async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
 
@@ -578,7 +584,7 @@ class EmbeddingUsage(LLMUsage):
         return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
 
 
-class ToolBoxUsage(LLMUsage):
+class ToolBoxUsage(LLMUsage, ABC):
     """A class representing the usage of tools in a task.
 
     This class extends LLMUsage and provides methods to manage and use toolboxes and tools within tasks.
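The recurring pattern in these hunks is a three-level fallback: an explicit kwarg, then the instance-scoped value, then the global CONFIG. For most parameters a plain `or` chain is enough, but `stream` switches to `first_available`, presumably because `False` is a legitimate explicit value that an `or` chain would skip past. The helpers below are one plausible reading of `ok` and `first_available` consistent with the call sites above, not the package's actual implementations.

```python
# Plausible stand-ins for fabricatio.utils.ok / first_available, inferred from
# the call sites above; NOT the package's actual code.
from typing import Iterable, Optional, TypeVar

T = TypeVar("T")


def ok(value: Optional[T], message: str = "value is not set") -> T:
    """Return the value, or raise if it is None."""
    if value is None:
        raise ValueError(message)
    return value


def first_available(candidates: Iterable[Optional[T]], message: str = "no value available") -> T:
    """Return the first candidate that is not None, so an explicit False still counts."""
    for candidate in candidates:
        if candidate is not None:
            return candidate
    raise ValueError(message)


# Stream resolution as in the hunk above: kwarg unset, instance says False, CONFIG says True.
print(first_available((None, False, True), "stream is not set at any place"))  # -> False
```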
fabricatio/parser.py
CHANGED
@@ -1,152 +1,112 @@
-"""A module
+"""A module for capturing patterns in text using regular expressions."""
 
 import re
+from dataclasses import dataclass, field
 from functools import lru_cache
-from
-from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
+from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type, Union
 
 import ujson
 from json_repair import repair_json
-from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
 
-from fabricatio.config import configs
 from fabricatio.journal import logger
+from fabricatio.rust import CONFIG
 
 
-
+@dataclass(frozen=True)
+class Capture:
     """A class to capture patterns in text using regular expressions.
 
     Attributes:
-
-
+        target_groups (Tuple[int, ...]): The target groups to extract from the match.
+        pattern (str): The regex pattern to search for.
+        flags (int): Flags to apply when compiling the regex.
+        capture_type (Optional[str]): Optional hint for post-processing (e.g., 'json').
     """
 
-
-    target_groups: Tuple[int, ...] = Field(default_factory=tuple)
-    """The target groups to capture from the pattern."""
-    pattern: str = Field(frozen=True)
+    pattern: str = field()
     """The regular expression pattern to search for."""
-    flags:
-    """
+    flags: int = re.DOTALL | re.MULTILINE | re.IGNORECASE
+    """Flags to control regex behavior (DOTALL, MULTILINE, IGNORECASE by default)."""
     capture_type: Optional[str] = None
-    """
-
+    """Optional type identifier for post-processing (e.g., 'json' for JSON repair)."""
+    target_groups: Tuple[int, ...] = field(default_factory=tuple)
+    """Tuple of group indices to extract from the match (1-based indexing)."""
 
-    def
-        """
-        self._compiled = compile(self.pattern, self.flags)
-
-    def fix[T](self, text: str | Iterable[str] | T) -> str | List[str] | T:
-        """Fix the text using the pattern.
-
-        Args:
-            text (str | List[str]): The text to fix.
-
-        Returns:
-            str | List[str]: The fixed text with the same type as input.
-        """
+    def fix(self, text: Union[str, Iterable[str], Any]) -> Union[str, List[str], Any]:
+        """Fix the text based on capture_type (e.g., JSON repair)."""
         match self.capture_type:
-            case "json" if
-                logger.debug("Applying
+            case "json" if CONFIG.general.use_json_repair:
+                logger.debug("Applying JSON repair to text.")
                 if isinstance(text, str):
-                    return repair_json(text, ensure_ascii=False)
-                return [repair_json(item, ensure_ascii=False) for item in
-                        text]  # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
+                    return repair_json(text, ensure_ascii=False)
+                return [repair_json(item, ensure_ascii=False) for item in text]
             case _:
-                return text
-
-    def capture(self, text: str) -> Tuple[str, ...]
-        """Capture the first
-
-
-
-
-        Returns:
-            str | None: The captured text if the pattern is found, otherwise None.
-
-        """
-        if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
-            logger.debug(f"Capture Failed {type(text)}: \n{text}")
+                return text
+
+    def capture(self, text: str) -> Optional[Union[str, Tuple[str, ...]]]:
+        """Capture the first match of the pattern in the text."""
+        compiled = re.compile(self.pattern, self.flags)
+        match = compiled.match(text) or compiled.search(text)
+        if match is None:
+            logger.debug(f"Capture Failed: {text}")
             return None
+
         groups = self.fix(match.groups())
         if self.target_groups:
             cap = tuple(groups[g - 1] for g in self.target_groups)
-            logger.debug(f"Captured
+            logger.debug(f"Captured texts: {'\n==\n'.join(cap)}")
             return cap
         cap = groups[0]
         logger.debug(f"Captured text: \n{cap}")
         return cap
 
-    def convert_with
-
-
-
-
-
-
-        Returns:
-            str | None: The converted text if the pattern is found, otherwise None.
-        """
+    def convert_with(
+        self,
+        text: str,
+        convertor: Callable[[Union[str, Tuple[str, ...]]], Any],
+    ) -> Optional[Any]:
+        """Convert captured text using a provided function."""
         if (cap := self.capture(text)) is None:
             return None
         try:
-            return convertor(cap)
-        except
-            logger.error(f"Failed to convert text using {convertor.__name__}
+            return convertor(cap)
+        except Exception as e:  # noqa: BLE001
+            logger.error(f"Failed to convert text using {convertor.__name__}: {e}\n{cap}")
             return None
 
-    def validate_with[
-
-
-
-
-
-
-    ) -> T
-        """
-
-        Args:
-            text (str): The text to search the pattern in.
-            target_type (Type[T]): The expected type of the output, dict or list.
-            elements_type (Optional[Type[E]]): The expected type of the elements in the output dict keys or list elements.
-            length (Optional[int]): The expected length of the output, bool(length)==False means no length validation.
-            deserializer (Callable[[Tuple[str, ...]], K] | Callable[[str], K]): The function to deserialize the captured text.
-
-        Returns:
-            T | None: The validated text if the pattern is found and the output is of the expected type, otherwise None.
-        """
-        judges = [lambda output_obj: isinstance(output_obj, target_type)]
+    def validate_with[T, K, E](
+        self,
+        text: str,
+        target_type: Type[T],
+        elements_type: Optional[Type[E]] = None,
+        length: Optional[int] = None,
+        deserializer: Callable[[Union[str, Tuple[str, ...]]], K] = lambda x: ujson.loads(x) if isinstance(x, str) else ujson.loads(x[0]),
+    ) -> Optional[T]:
+        """Deserialize and validate the captured text against expected types."""
+        judges = [lambda obj: isinstance(obj, target_type)]
         if elements_type:
-            judges.append(lambda
+            judges.append(lambda obj: all(isinstance(e, elements_type) for e in obj))
         if length:
-            judges.append(lambda
+            judges.append(lambda obj: len(obj) == length)
 
         if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
-            return out  #
+            return out  # type: ignore
         return None
 
     @classmethod
     @lru_cache(32)
     def capture_code_block(cls, language: str) -> Self:
-        """Capture
-
-        Args:
-            language (str): The text containing the code block.
-
-        Returns:
-            Self: The instance of the class with the captured code block.
-        """
+        """Capture a code block of the given language."""
         return cls(pattern=f"```{language}(.*?)```", capture_type=language)
 
     @classmethod
     @lru_cache(32)
     def capture_generic_block(cls, language: str) -> Self:
-        """Capture
-
-
-
-
-        return cls(pattern=f"--- Start of {language} ---(.*?)--- end of {language} ---", capture_type=language)
+        """Capture a generic block of the given language."""
+        return cls(
+            pattern=f"--- Start of {language} ---(.*?)--- End of {language} ---",
+            capture_type=language,
+        )
 
 
 JsonCapture = Capture.capture_code_block("json")
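Capture drops its pydantic base and cached compiled pattern in favour of a frozen dataclass that recompiles the pattern on every capture call (cheap in practice, since the re module keeps its own internal compile cache). A hedged usage sketch based on the fields and the module-level JsonCapture shown above; output comments are expected values, not captured from a run.

```python
# Usage sketch for the reworked Capture dataclass; expected outputs are noted
# in comments but not verified against an actual run.
from fabricatio.parser import Capture, JsonCapture

# Validate the body of a fenced JSON block as a list of exactly three strings.
text = 'Here you go:\n```json\n["a", "b", "c"]\n```'
print(JsonCapture.validate_with(text, target_type=list, elements_type=str, length=3))
# expected: ['a', 'b', 'c']

# A custom capture pulling two groups (1-based indices, per target_groups).
pair = Capture(pattern=r"(\w+)=(\w+)", target_groups=(1, 2))
print(pair.capture("mode=fast"))  # expected: ('mode', 'fast')
```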