fabricatio-0.2.9.dev3-cp312-cp312-win_amd64.whl → fabricatio-0.2.10.dev0-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +13 -113
- fabricatio/actions/article_rag.py +9 -2
- fabricatio/capabilities/check.py +15 -9
- fabricatio/capabilities/correct.py +5 -6
- fabricatio/capabilities/rag.py +39 -232
- fabricatio/capabilities/rating.py +46 -40
- fabricatio/config.py +2 -2
- fabricatio/constants.py +20 -0
- fabricatio/decorators.py +23 -0
- fabricatio/fs/readers.py +20 -1
- fabricatio/models/adv_kwargs_types.py +42 -0
- fabricatio/models/events.py +6 -6
- fabricatio/models/extra/advanced_judge.py +4 -4
- fabricatio/models/extra/article_base.py +25 -211
- fabricatio/models/extra/article_main.py +69 -95
- fabricatio/models/extra/article_proposal.py +15 -14
- fabricatio/models/extra/patches.py +6 -6
- fabricatio/models/extra/problem.py +12 -17
- fabricatio/models/extra/rag.py +72 -0
- fabricatio/models/extra/rule.py +1 -2
- fabricatio/models/generic.py +34 -10
- fabricatio/models/kwargs_types.py +1 -38
- fabricatio/models/task.py +3 -3
- fabricatio/models/usages.py +78 -8
- fabricatio/parser.py +5 -5
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +27 -12
- fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/METADATA +1 -1
- fabricatio-0.2.10.dev0.dist-info/RECORD +62 -0
- fabricatio/models/utils.py +0 -148
- fabricatio-0.2.9.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.9.dev3.dist-info/RECORD +0 -61
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/extra/problem.py
CHANGED
@@ -1,11 +1,12 @@
 """A class representing a problem-solution pair identified during a review process."""
 
 from itertools import chain
-from typing import Any, List,
+from typing import Any, List, Optional, Self, Tuple, Unpack
 
 from fabricatio.journal import logger
 from fabricatio.models.generic import SketchedAble, WithBriefing
 from fabricatio.utils import ask_edit
+from pydantic import Field
 from questionary import Choice, checkbox, text
 from rich import print as r_print
 
@@ -13,36 +14,30 @@ from rich import print as r_print
 class Problem(SketchedAble, WithBriefing):
     """Represents a problem identified during review."""
 
-    description: str
-    """
+    description: str = Field(alias="cause")
+    """The cause of the problem, including the root cause, the context, and the impact, make detailed enough for engineer to understand the problem and its impact."""
 
-
-    """Severity level of the problem."""
-
-    category: str
-    """Category of the problem."""
+    severity_level: int = Field(ge=0, le=10)
+    """Severity level of the problem, which is a number between 0 and 10, 0 means the problem is not severe, 10 means the problem is extremely severe."""
 
     location: str
     """Location where the problem was identified."""
 
-    recommendation: str
-    """Recommended solution or action."""
-
 
 class Solution(SketchedAble, WithBriefing):
     """Represents a proposed solution to a problem."""
 
-    description: str
+    description: str = Field(alias="mechanism")
     """Description of the solution, including a detailed description of the execution steps, and the mechanics, principle or fact."""
 
     execute_steps: List[str]
-    """A list of steps to execute to implement the solution, which is expected to be able to finally solve the corresponding problem."""
+    """A list of steps to execute to implement the solution, which is expected to be able to finally solve the corresponding problem, and which should be an Idiot-proof tutorial."""
 
-
-    """Feasibility level of the solution."""
+    feasibility_level: int = Field(ge=0, le=10)
+    """Feasibility level of the solution, which is a number between 0 and 10, 0 means the solution is not feasible, 10 means the solution is complete feasible."""
 
-
-    """Impact level of the solution."""
+    impact_level: int = Field(ge=0, le=10)
+    """Impact level of the solution, which is a number between 0 and 10, 0 means the solution is not impactful, 10 means the solution is extremely impactful."""
 
 
 class ProblemSolutions(SketchedAble):
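For readers unfamiliar with the `Field(alias=...)` pattern these hunks introduce: the model is now populated through the alias (`cause`, `mechanism`) while code keeps reading the plain attribute name. A minimal self-contained sketch of that mechanism, using a toy model rather than the actual fabricatio classes:

from pydantic import BaseModel, ConfigDict, Field


class Demo(BaseModel):
    """Toy stand-in for the aliased fields added to Problem/Solution."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    description: str = Field(alias="cause")
    """The cause of the problem."""

    severity_level: int = Field(ge=0, le=10)
    """0 (harmless) to 10 (critical)."""


# Input uses the alias; attribute access and by_alias dumps stay consistent.
d = Demo.model_validate({"cause": "slow query on checkout", "severity_level": 7})
print(d.description)                     # -> "slow query on checkout"
print(d.model_dump_json(by_alias=True))  # -> {"cause": "...", "severity_level": 7}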
fabricatio/models/extra/rag.py
ADDED
@@ -0,0 +1,72 @@
+"""A module containing the RAG (Retrieval-Augmented Generation) models."""
+
+from abc import ABCMeta, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Self, Sequence
+
+from fabricatio.decorators import precheck_package
+from pydantic import BaseModel, ConfigDict, JsonValue
+
+if TYPE_CHECKING:
+    from importlib.util import find_spec
+
+    from pydantic.fields import FieldInfo
+
+    if find_spec("pymilvus"):
+        from pymilvus import CollectionSchema
+
+
+class MilvusDataBase(BaseModel, metaclass=ABCMeta):
+    """A base class for Milvus data."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    primary_field_name: ClassVar[str] = "id"
+
+    vector_field_name: ClassVar[str] = "vector"
+
+    def prepare_insertion(self, vector: List[float]) -> Dict[str, Any]:
+        """Prepares the data for insertion into Milvus.
+
+        Returns:
+            dict: A dictionary containing the data to be inserted into Milvus.
+        """
+        return {**self.model_dump(exclude_none=True, by_alias=True), self.vector_field_name: vector}
+
+    @property
+    @abstractmethod
+    def to_vectorize(self) -> str:
+        """The text representation of the data."""
+
+    @classmethod
+    @precheck_package(
+        "pymilvus", "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
+    )
+    def as_milvus_schema(cls, dimension: int = 1024) -> "CollectionSchema":
+        """Generates the schema for Milvus collection."""
+        from pymilvus import CollectionSchema, DataType, FieldSchema
+
+        fields = [
+            FieldSchema(cls.primary_field_name, dtype=DataType.INT64, is_primary=True, auto_id=True),
+            FieldSchema(cls.vector_field_name, dtype=DataType.FLOAT_VECTOR, dim=dimension),
+        ]
+
+        type_mapping = {
+            str: DataType.STRING,
+            int: DataType.INT64,
+            float: DataType.DOUBLE,
+            JsonValue: DataType.JSON,
+            # TODO add more mapping
+        }
+
+        for k, v in cls.model_fields.items():
+            k: str
+            v: FieldInfo
+            fields.append(
+                FieldSchema(k, dtype=type_mapping.get(v.annotation, DataType.UNKNOWN), description=v.description or "")
+            )
+        return CollectionSchema(fields)
+
+    @classmethod
+    def from_sequence(cls, data: Sequence[Dict[str, Any]]) -> List[Self]:
+        """Constructs a list of instances from a sequence of dictionaries."""
+        return [cls(**d) for d in data]
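A hedged sketch of how this new base class is meant to be subclassed, assuming only what the hunk above shows; the subclass name and its fields are illustrative, not part of the package:

from typing import List

from fabricatio.models.extra.rag import MilvusDataBase


class ChunkRecord(MilvusDataBase):
    """Hypothetical record type; only `to_vectorize` is required by the base class."""

    source: str
    """Path or URL the chunk came from."""

    text: str
    """The chunk body that gets embedded."""

    @property
    def to_vectorize(self) -> str:
        # The base class vectorizes whatever text this property returns.
        return self.text


record = ChunkRecord(source="docs/intro.md", text="Fabricatio is an agent framework.")
embedding: List[float] = [0.0] * 1024                   # stand-in for a real embedding call
row = record.prepare_insertion(embedding)               # dict ready for a Milvus insert
schema = ChunkRecord.as_milvus_schema(dimension=1024)   # requires the `fabricatio[rag]` extra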
fabricatio/models/extra/rule.py
CHANGED
@@ -40,12 +40,11 @@ class RuleSet(SketchedAble, PersistentAble, WithBriefing, Language):
     framework for the topic or domain covered by the rule set."""
 
     @classmethod
-    def gather(cls, *rulesets: Unpack[Tuple["RuleSet"
+    def gather(cls, *rulesets: Unpack[Tuple["RuleSet", ...]]) -> Self:
         """Gathers multiple rule sets into a single rule set."""
         if not rulesets:
             raise ValueError("No rulesets provided")
         return cls(
-            language=rulesets[0].language,
             name=";".join(ruleset.name for ruleset in rulesets),
             description=";".join(ruleset.description for ruleset in rulesets),
             rules=list(flatten(r.rules for r in rulesets)),
fabricatio/models/generic.py
CHANGED
@@ -11,7 +11,7 @@ from fabricatio.config import configs
 from fabricatio.fs.readers import MAGIKA, safe_text_read
 from fabricatio.journal import logger
 from fabricatio.parser import JsonCapture
-from fabricatio.rust import blake3_hash
+from fabricatio.rust import blake3_hash, detect_language
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from fabricatio.utils import ok
 from litellm.utils import token_counter
@@ -36,6 +36,7 @@ class Base(BaseModel):
     The `model_config` uses `use_attribute_docstrings=True` to ensure field descriptions are
     pulled from the attribute's docstring instead of the default Pydantic behavior.
     """
+
     model_config = ConfigDict(use_attribute_docstrings=True)
 
 
@@ -45,13 +46,14 @@ class Display(Base):
     Provides methods to generate both pretty-printed and compact JSON representations of the model.
     Used for debugging and logging purposes.
     """
+
     def display(self) -> str:
         """Generate pretty-printed JSON representation.
 
         Returns:
             str: JSON string with 1-level indentation for readability
         """
-        return self.model_dump_json(indent=1)
+        return self.model_dump_json(indent=1,by_alias=True)
 
     def compact(self) -> str:
         """Generate compact JSON representation.
@@ -59,7 +61,7 @@ class Display(Base):
         Returns:
             str: Minified JSON string without whitespace
         """
-        return self.model_dump_json()
+        return self.model_dump_json(by_alias=True)
 
     @staticmethod
     def seq_display(seq: Iterable["Display"], compact: bool = False) -> str:
@@ -102,6 +104,20 @@ class Described(Base):
     this object's intent and application."""
 
 
+class Titled(Base):
+    """Class that includes a title attribute."""
+
+    title: str
+    """The title of this object, make it professional and concise.No prefixed heading number should be included."""
+
+
+class WordCount(Base):
+    """Class that includes a word count attribute."""
+
+    expected_word_count: int
+    """Expected word count of this research component."""
+
+
 class AsPrompt(Base):
     """Class that provides a method to generate a prompt from the model.
 
@@ -194,6 +210,7 @@ class PersistentAble(Base):
     Enables saving model instances to disk with timestamped filenames and loading from persisted files.
     Implements basic versioning through filename hashing and timestamping.
     """
+
     def persist(self, path: str | Path) -> Self:
         """Save model instance to disk with versioned filename.
 
@@ -208,7 +225,7 @@ class PersistentAble(Base):
            - Hash generated from JSON content ensures uniqueness
         """
         p = Path(path)
-        out = self.model_dump_json()
+        out = self.model_dump_json(indent=1,by_alias=True)
 
         # Generate a timestamp in the format YYYYMMDD_HHMMSS
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -281,11 +298,17 @@
 class Language(Base):
     """Class that provides a language attribute."""
 
-
-
-
-
-
+    @property
+    def language(self)->str:
+        """Get the language of the object."""
+        if isinstance(self,Described):
+            return detect_language(self.description)
+        if isinstance(self,Titled):
+            return detect_language(self.title)
+        if isinstance(self,Named):
+            return detect_language(self.name)
+
+        return detect_language(self.model_dump_json(by_alias=True))
 class ModelHash(Base):
     """Class that provides a hash value for the object.
 
@@ -527,7 +550,7 @@ class FinalizedDumpAble(Base):
         Returns:
             str: The finalized dump of the object.
         """
-        return self.model_dump_json()
+        return self.model_dump_json(indent=1,by_alias=True)
 
     def finalized_dump_to(self, path: str | Path) -> Self:
         """Finalize the dump of the object to a file.
@@ -670,6 +693,7 @@ class ScopedConfig(Base):
     Manages LLM, embedding, and vector database configurations with fallback logic.
     Allows configuration values to be overridden in a hierarchical manner.
     """
+
     llm_api_endpoint: Optional[HttpUrl] = None
     """The OpenAI API endpoint."""
 
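The stored `language` field that the removed lines held is now derived on access: the property tries `description`, then `title`, then `name`, and finally the full JSON dump, handing the text to the Rust-backed `detect_language`. A rough illustration with a hypothetical model mixing in the relevant bases:

from fabricatio.models.generic import Described, Language


class Note(Language, Described):
    """Hypothetical model: Described supplies `description`; Language derives `language` from it."""


note = Note(description="Ein kurzer Hinweis zur Architektur dieses Moduls.")
print(note.language)  # resolved via fabricatio.rust.detect_language on the description text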
fabricatio/models/kwargs_types.py
CHANGED
@@ -1,48 +1,10 @@
 """This module contains the types for the keyword arguments of the methods in the models module."""
 
-from importlib.util import find_spec
 from typing import Any, Dict, List, Optional, Required, TypedDict
 
 from litellm.caching.caching import CacheMode
 from litellm.types.caching import CachingSupportedCallTypes
 
-if find_spec("pymilvus"):
-    from pymilvus import CollectionSchema
-    from pymilvus.milvus_client import IndexParams
-
-    class CollectionConfigKwargs(TypedDict, total=False):
-        """Configuration parameters for a vector collection.
-
-        These arguments are typically used when configuring connections to vector databases.
-        """
-
-        dimension: int | None
-        primary_field_name: str
-        id_type: str
-        vector_field_name: str
-        metric_type: str
-        timeout: float | None
-        schema: CollectionSchema | None
-        index_params: IndexParams | None
-
-
-    class FetchKwargs(TypedDict, total=False):
-        """Arguments for fetching data from vector collections.
-
-        Controls how data is retrieved from vector databases, including filtering
-        and result limiting parameters.
-        """
-
-        collection_name: str | None
-        similarity_threshold: float
-        result_per_query: int
-
-
-    class RetrievalKwargs(FetchKwargs, total=False):
-        """Arguments for retrieval operations."""
-
-        final_limit: int
-
 
 class EmbeddingKwargs(TypedDict, total=False):
     """Configuration parameters for text embedding operations.
@@ -139,6 +101,7 @@ class ReviewKwargs[T](ReviewInnerKwargs[T], total=False):
 
 class ReferencedKwargs[T](ValidateKwargs[T], total=False):
     """Arguments for content review operations."""
+
     reference: str
 
 
fabricatio/models/task.py
CHANGED
@@ -7,11 +7,11 @@ from asyncio import Queue
 from typing import Any, List, Optional, Self
 
 from fabricatio.config import configs
+from fabricatio.constants import TaskStatus
 from fabricatio.core import env
 from fabricatio.journal import logger
 from fabricatio.models.events import Event, EventLike
 from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
-from fabricatio.models.utils import TaskStatus
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from pydantic import Field, PrivateAttr
 
@@ -112,12 +112,12 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
         """Return a formatted status label for the task.
 
         Args:
-            status (TaskStatus): The status of the task.
+            status (fabricatio.constants.TaskStatus): The status of the task.
 
         Returns:
             str: The formatted status label.
         """
-        return self._namespace.derive(self.name).push(status
+        return self._namespace.derive(self.name).push(status).collapse()
 
     @property
     def pending_label(self) -> str:
fabricatio/models/usages.py
CHANGED
@@ -2,7 +2,7 @@
 
 import traceback
 from asyncio import gather
-from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Union, Unpack, overload
+from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload
 
 import asyncstdlib
 import litellm
@@ -13,7 +13,6 @@ from fabricatio.models.generic import ScopedConfig, WithBriefing
 from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
-from fabricatio.models.utils import Messages
 from fabricatio.parser import GenericCapture, JsonCapture
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from fabricatio.utils import ok
@@ -28,7 +27,7 @@ from litellm.types.utils import (
 )
 from litellm.utils import CustomStreamWrapper, token_counter  # pyright: ignore [reportPrivateImportUsage]
 from more_itertools import duplicates_everseen
-from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
 
 if configs.cache.enabled and configs.cache.type:
     litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
@@ -299,10 +298,11 @@ class LLMUsage(ScopedConfig):
         for lap in range(max_validations):
             try:
                 if ((validated := validator(response := await self.aask(question=q, **kwargs))) is not None) or (
-                    co_extractor
+                    co_extractor is not None
+                    and logger.debug("Co-extraction is enabled.") is None
                     and (
                         validated := validator(
-                            await self.aask(
+                            response := await self.aask(
                                 question=(
                                     TEMPLATE_MANAGER.render_template(
                                         configs.templates.co_validation_template,
@@ -319,12 +319,13 @@ class LLMUsage(ScopedConfig):
                     return validated
 
             except RateLimitError as e:
-                logger.warning(f"Rate limit error
+                logger.warning(f"Rate limit error:\n{e}")
                 continue
             except Exception as e:  # noqa: BLE001
-                logger.error(f"Error during validation
+                logger.error(f"Error during validation:\n{e}")
                 logger.debug(traceback.format_exc())
                 break
+            logger.error(f"Failed to validate the response at {lap}th attempt:\n{response}")
             if not kwargs.get("no_cache"):
                 kwargs["no_cache"] = True
                 logger.debug("Closed the cache for the next attempt")
@@ -493,7 +494,7 @@ class LLMUsage(ScopedConfig):
         affirm_case: str = "",
         deny_case: str = "",
         **kwargs: Unpack[ValidateKwargs[bool]],
-    ) -> bool:
+    ) -> Optional[bool]:
         """Asynchronously judges a prompt using AI validation.
 
         Args:
@@ -730,3 +731,72 @@ class ToolBoxUsage(LLMUsage):
         for other in (x for x in others if isinstance(x, ToolBoxUsage)):
             other.toolboxes.update(self.toolboxes)
         return self
+
+
+class Message(BaseModel):
+    """A class representing a message."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+    role: Literal["user", "system", "assistant"]
+    """The role of the message sender."""
+    content: str
+    """The content of the message."""
+
+
+class Messages(list):
+    """A list of messages."""
+
+    def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
+        """Adds a message to the list with the specified role and content.
+
+        Args:
+            role (Literal["user", "system", "assistant"]): The role of the message sender.
+            content (str): The content of the message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        if content:
+            self.append(Message(role=role, content=content))
+        return self
+
+    def add_user_message(self, content: str) -> Self:
+        """Adds a user message to the list with the specified content.
+
+        Args:
+            content (str): The content of the user message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("user", content)
+
+    def add_system_message(self, content: str) -> Self:
+        """Adds a system message to the list with the specified content.
+
+        Args:
+            content (str): The content of the system message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("system", content)
+
+    def add_assistant_message(self, content: str) -> Self:
+        """Adds an assistant message to the list with the specified content.
+
+        Args:
+            content (str): The content of the assistant message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("assistant", content)
+
+    def as_list(self) -> List[Dict[str, str]]:
+        """Converts the messages to a list of dictionaries.
+
+        Returns:
+            list[dict]: A list of dictionaries representing the messages.
+        """
+        return [message.model_dump() for message in self]
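`Message`/`Messages` moved here from the removed `fabricatio/models/utils.py`; a quick sketch of the chaining API they expose (the message text is illustrative):

from fabricatio.models.usages import Messages

history = (
    Messages()
    .add_system_message("You are a terse release-notes assistant.")
    .add_user_message("Summarize the 0.2.10.dev0 changes.")
)
payload = history.as_list()  # [{"role": "system", ...}, {"role": "user", ...}] — plain dicts ready for litellm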
fabricatio/parser.py
CHANGED
@@ -48,10 +48,10 @@ class Capture(BaseModel):
            case "json" if configs.general.use_json_repair:
                logger.debug("Applying json repair to text.")
                if isinstance(text, str):
-                    return repair_json(text, ensure_ascii=False)
-                return [repair_json(item, ensure_ascii=False) for item in text]
+                    return repair_json(text, ensure_ascii=False)  # pyright: ignore [reportReturnType]
+                return [repair_json(item, ensure_ascii=False) for item in text]  # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
            case _:
-                return text
+                return text  # pyright: ignore [reportReturnType]
 
     def capture(self, text: str) -> Tuple[str, ...] | str | None:
         """Capture the first occurrence of the pattern in the given text.
@@ -88,7 +88,7 @@ class Capture(BaseModel):
         if (cap := self.capture(text)) is None:
             return None
         try:
-            return convertor(cap)
+            return convertor(cap)  # pyright: ignore [reportArgumentType]
         except (ValueError, SyntaxError, ValidationError) as e:
             logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
             return None
@@ -120,7 +120,7 @@ class Capture(BaseModel):
             judges.append(lambda output_obj: len(output_obj) == length)
 
         if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
-            return out
+            return out  # pyright: ignore [reportReturnType]
         return None
 
     @classmethod
fabricatio/rust.cp312-win_amd64.pyd
Binary file (no text diff shown)
fabricatio/rust.pyi
CHANGED
@@ -1,5 +1,4 @@
-"""
-Python interface definitions for Rust-based functionality.
+"""Python interface definitions for Rust-based functionality.
 
 This module provides type stubs and documentation for Rust-implemented utilities,
 including template rendering, cryptographic hashing, language detection, and
@@ -12,12 +11,8 @@ Key Features:
 - Text utilities: Word boundary splitting and word counting.
 """
 
-
 from pathlib import Path
-from typing import List, Optional
-
-from pydantic import JsonValue
-
+from typing import Any, Dict, List, Optional
 
 class TemplateManager:
     """Template rendering engine using Handlebars templates.
@@ -59,7 +54,7 @@ class TemplateManager:
         This refreshes the template cache, finding any new or modified templates.
         """
 
-    def render_template(self, name: str, data:
+    def render_template(self, name: str, data: Dict[str, Any]) -> str:
         """Render a template with context data.
 
         Args:
@@ -73,7 +68,7 @@ class TemplateManager:
             RuntimeError: If template rendering fails
         """
 
-    def render_template_raw(self, template: str, data:
+    def render_template_raw(self, template: str, data: Dict[str, Any]) -> str:
         """Render a template with context data.
 
         Args:
@@ -97,7 +92,6 @@ def blake3_hash(content: bytes) -> str:
 def detect_language(string: str) -> str:
     """Detect the language of a given string."""
 
-
 def split_word_bounds(string: str) -> List[str]:
     """Split the string into words based on word boundaries.
 
@@ -107,6 +101,29 @@ def split_word_bounds(string: str) -> List[str]:
     Returns:
         A list of words extracted from the string.
     """
+
+def split_sentence_bounds(string: str) -> List[str]:
+    """Split the string into sentences based on sentence boundaries.
+
+    Args:
+        string: The input string to be split.
+
+    Returns:
+        A list of sentences extracted from the string.
+    """
+
+def split_into_chunks(string: str, max_chunk_size: int, max_overlapping_rate: float = 0.3) -> List[str]:
+    """Split the string into chunks of a specified size.
+
+    Args:
+        string: The input string to be split.
+        max_chunk_size: The maximum size of each chunk.
+        max_overlapping_rate: The minimum overlapping rate between chunks.
+
+    Returns:
+        A list of chunks extracted from the string.
+    """
+
 def word_count(string: str) -> int:
     """Count the number of words in the string.
 
@@ -117,8 +134,6 @@ def word_count(string: str) -> int:
         The number of words in the string.
     """
 
-
-
 class BibManager:
     """BibTeX bibliography manager for parsing and querying citation data."""
 
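The stubs above only declare signatures; a short sketch of how the newly exposed helpers compose for RAG-style preprocessing (sample text and sizes are arbitrary, and running it requires the compiled `fabricatio.rust` extension):

from fabricatio.rust import detect_language, split_into_chunks, split_sentence_bounds, word_count

text = "Fabricatio ships Rust-backed text utilities. They are exposed through rust.pyi."
print(detect_language(text))           # language identifier as returned by the Rust side
print(word_count(text))                # total number of words
print(split_sentence_bounds(text))     # e.g. the two sentences split on sentence boundaries
chunks = split_into_chunks(text, 32, 0.3)  # chunks of at most 32 units with ~30% overlap
print(len(chunks))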
fabricatio-0.2.10.dev0.data/scripts/tdown.exe
Binary file (no text diff shown)
|