fabricatio 0.2.9.dev3__cp312-cp312-win_amd64.whl → 0.2.10__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +24 -114
- fabricatio/actions/article_rag.py +156 -18
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +40 -18
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/check.py +15 -9
- fabricatio/capabilities/correct.py +5 -6
- fabricatio/capabilities/rag.py +41 -231
- fabricatio/capabilities/rating.py +46 -40
- fabricatio/config.py +6 -4
- fabricatio/constants.py +20 -0
- fabricatio/decorators.py +23 -0
- fabricatio/fs/readers.py +20 -1
- fabricatio/models/adv_kwargs_types.py +35 -0
- fabricatio/models/events.py +6 -6
- fabricatio/models/extra/advanced_judge.py +4 -4
- fabricatio/models/extra/aricle_rag.py +170 -0
- fabricatio/models/extra/article_base.py +25 -211
- fabricatio/models/extra/article_essence.py +8 -7
- fabricatio/models/extra/article_main.py +98 -97
- fabricatio/models/extra/article_proposal.py +15 -14
- fabricatio/models/extra/patches.py +6 -6
- fabricatio/models/extra/problem.py +12 -17
- fabricatio/models/extra/rag.py +98 -0
- fabricatio/models/extra/rule.py +1 -2
- fabricatio/models/generic.py +53 -13
- fabricatio/models/kwargs_types.py +8 -36
- fabricatio/models/task.py +3 -3
- fabricatio/models/usages.py +85 -9
- fabricatio/parser.py +5 -5
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +137 -10
- fabricatio/utils.py +62 -4
- fabricatio-0.2.10.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/METADATA +1 -4
- fabricatio-0.2.10.dist-info/RECORD +64 -0
- fabricatio/models/utils.py +0 -148
- fabricatio-0.2.9.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.9.dev3.dist-info/RECORD +0 -61
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/generic.py
CHANGED
@@ -3,15 +3,14 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from datetime import datetime
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
|
6
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
|
7
7
|
|
8
8
|
import orjson
|
9
|
-
import rtoml
|
10
9
|
from fabricatio.config import configs
|
11
10
|
from fabricatio.fs.readers import MAGIKA, safe_text_read
|
12
11
|
from fabricatio.journal import logger
|
13
12
|
from fabricatio.parser import JsonCapture
|
14
|
-
from fabricatio.rust import blake3_hash
|
13
|
+
from fabricatio.rust import blake3_hash, detect_language
|
15
14
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
16
15
|
from fabricatio.utils import ok
|
17
16
|
from litellm.utils import token_counter
|
@@ -36,6 +35,7 @@ class Base(BaseModel):
|
|
36
35
|
The `model_config` uses `use_attribute_docstrings=True` to ensure field descriptions are
|
37
36
|
pulled from the attribute's docstring instead of the default Pydantic behavior.
|
38
37
|
"""
|
38
|
+
|
39
39
|
model_config = ConfigDict(use_attribute_docstrings=True)
|
40
40
|
|
41
41
|
|
@@ -45,13 +45,14 @@ class Display(Base):
|
|
45
45
|
Provides methods to generate both pretty-printed and compact JSON representations of the model.
|
46
46
|
Used for debugging and logging purposes.
|
47
47
|
"""
|
48
|
+
|
48
49
|
def display(self) -> str:
|
49
50
|
"""Generate pretty-printed JSON representation.
|
50
51
|
|
51
52
|
Returns:
|
52
53
|
str: JSON string with 1-level indentation for readability
|
53
54
|
"""
|
54
|
-
return self.model_dump_json(indent=1)
|
55
|
+
return self.model_dump_json(indent=1, by_alias=True)
|
55
56
|
|
56
57
|
def compact(self) -> str:
|
57
58
|
"""Generate compact JSON representation.
|
@@ -59,7 +60,7 @@ class Display(Base):
|
|
59
60
|
Returns:
|
60
61
|
str: Minified JSON string without whitespace
|
61
62
|
"""
|
62
|
-
return self.model_dump_json()
|
63
|
+
return self.model_dump_json(by_alias=True)
|
63
64
|
|
64
65
|
@staticmethod
|
65
66
|
def seq_display(seq: Iterable["Display"], compact: bool = False) -> str:
|
@@ -102,6 +103,29 @@ class Described(Base):
|
|
102
103
|
this object's intent and application."""
|
103
104
|
|
104
105
|
|
106
|
+
class Titled(Base):
|
107
|
+
"""Class that includes a title attribute."""
|
108
|
+
|
109
|
+
title: str
|
110
|
+
"""The title of this object, make it professional and concise.No prefixed heading number should be included."""
|
111
|
+
|
112
|
+
|
113
|
+
class WordCount(Base):
|
114
|
+
"""Class that includes a word count attribute."""
|
115
|
+
|
116
|
+
expected_word_count: int
|
117
|
+
"""Expected word count of this research component."""
|
118
|
+
|
119
|
+
|
120
|
+
class FromMapping(Base):
|
121
|
+
"""Class that provides a method to generate a list of objects from a mapping."""
|
122
|
+
|
123
|
+
@classmethod
|
124
|
+
@abstractmethod
|
125
|
+
def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
|
126
|
+
"""Generate a list of objects from a mapping."""
|
127
|
+
|
128
|
+
|
105
129
|
class AsPrompt(Base):
|
106
130
|
"""Class that provides a method to generate a prompt from the model.
|
107
131
|
|
@@ -194,6 +218,7 @@ class PersistentAble(Base):
|
|
194
218
|
Enables saving model instances to disk with timestamped filenames and loading from persisted files.
|
195
219
|
Implements basic versioning through filename hashing and timestamping.
|
196
220
|
"""
|
221
|
+
|
197
222
|
def persist(self, path: str | Path) -> Self:
|
198
223
|
"""Save model instance to disk with versioned filename.
|
199
224
|
|
@@ -208,7 +233,7 @@ class PersistentAble(Base):
|
|
208
233
|
- Hash generated from JSON content ensures uniqueness
|
209
234
|
"""
|
210
235
|
p = Path(path)
|
211
|
-
out = self.model_dump_json()
|
236
|
+
out = self.model_dump_json(indent=1, by_alias=True)
|
212
237
|
|
213
238
|
# Generate a timestamp in the format YYYYMMDD_HHMMSS
|
214
239
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
@@ -281,9 +306,17 @@ class PersistentAble(Base):
|
|
281
306
|
class Language(Base):
|
282
307
|
"""Class that provides a language attribute."""
|
283
308
|
|
284
|
-
|
285
|
-
|
286
|
-
|
309
|
+
@property
|
310
|
+
def language(self) -> str:
|
311
|
+
"""Get the language of the object."""
|
312
|
+
if isinstance(self, Described):
|
313
|
+
return detect_language(self.description)
|
314
|
+
if isinstance(self, Titled):
|
315
|
+
return detect_language(self.title)
|
316
|
+
if isinstance(self, Named):
|
317
|
+
return detect_language(self.name)
|
318
|
+
|
319
|
+
return detect_language(self.model_dump_json(by_alias=True))
|
287
320
|
|
288
321
|
|
289
322
|
class ModelHash(Base):
|
@@ -527,7 +560,7 @@ class FinalizedDumpAble(Base):
|
|
527
560
|
Returns:
|
528
561
|
str: The finalized dump of the object.
|
529
562
|
"""
|
530
|
-
return self.model_dump_json()
|
563
|
+
return self.model_dump_json(indent=1, by_alias=True)
|
531
564
|
|
532
565
|
def finalized_dump_to(self, path: str | Path) -> Self:
|
533
566
|
"""Finalize the dump of the object to a file.
|
@@ -639,8 +672,9 @@ class Vectorizable(Base):
|
|
639
672
|
This class includes methods to prepare the model for vectorization, ensuring it fits within a specified token length.
|
640
673
|
"""
|
641
674
|
|
675
|
+
@abstractmethod
|
642
676
|
def _prepare_vectorization_inner(self) -> str:
|
643
|
-
|
677
|
+
"""Prepare the model for vectorization."""
|
644
678
|
|
645
679
|
@final
|
646
680
|
def prepare_vectorization(self, max_length: Optional[int] = None) -> str:
|
@@ -658,8 +692,7 @@ class Vectorizable(Base):
|
|
658
692
|
max_length = max_length or configs.embedding.max_sequence_length
|
659
693
|
chunk = self._prepare_vectorization_inner()
|
660
694
|
if max_length and (length := token_counter(text=chunk)) > max_length:
|
661
|
-
|
662
|
-
raise ValueError(err)
|
695
|
+
raise ValueError(f"Chunk exceeds maximum sequence length {max_length}, got {length}, see \n{chunk}")
|
663
696
|
|
664
697
|
return chunk
|
665
698
|
|
@@ -670,6 +703,7 @@ class ScopedConfig(Base):
|
|
670
703
|
Manages LLM, embedding, and vector database configurations with fallback logic.
|
671
704
|
Allows configuration values to be overridden in a hierarchical manner.
|
672
705
|
"""
|
706
|
+
|
673
707
|
llm_api_endpoint: Optional[HttpUrl] = None
|
674
708
|
"""The OpenAI API endpoint."""
|
675
709
|
|
@@ -709,6 +743,12 @@ class ScopedConfig(Base):
|
|
709
743
|
llm_rpm: Optional[PositiveInt] = None
|
710
744
|
"""The requests per minute of the LLM model."""
|
711
745
|
|
746
|
+
llm_presence_penalty: Optional[PositiveFloat] = None
|
747
|
+
"""The presence penalty of the LLM model."""
|
748
|
+
|
749
|
+
llm_frequency_penalty: Optional[PositiveFloat] = None
|
750
|
+
"""The frequency penalty of the LLM model."""
|
751
|
+
|
712
752
|
embedding_api_endpoint: Optional[HttpUrl] = None
|
713
753
|
"""The OpenAI API endpoint."""
|
714
754
|
|
@@ -1,47 +1,16 @@
|
|
1
1
|
"""This module contains the types for the keyword arguments of the methods in the models module."""
|
2
2
|
|
3
|
-
from
|
4
|
-
from typing import Any, Dict, List, Optional, Required, TypedDict
|
3
|
+
from typing import Any, Dict, List, NotRequired, Optional, Required, TypedDict
|
5
4
|
|
6
5
|
from litellm.caching.caching import CacheMode
|
7
6
|
from litellm.types.caching import CachingSupportedCallTypes
|
8
7
|
|
9
|
-
if find_spec("pymilvus"):
|
10
|
-
from pymilvus import CollectionSchema
|
11
|
-
from pymilvus.milvus_client import IndexParams
|
12
8
|
|
13
|
-
|
14
|
-
|
9
|
+
class ChunkKwargs(TypedDict):
|
10
|
+
"""Configuration parameters for chunking operations."""
|
15
11
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
dimension: int | None
|
20
|
-
primary_field_name: str
|
21
|
-
id_type: str
|
22
|
-
vector_field_name: str
|
23
|
-
metric_type: str
|
24
|
-
timeout: float | None
|
25
|
-
schema: CollectionSchema | None
|
26
|
-
index_params: IndexParams | None
|
27
|
-
|
28
|
-
|
29
|
-
class FetchKwargs(TypedDict, total=False):
|
30
|
-
"""Arguments for fetching data from vector collections.
|
31
|
-
|
32
|
-
Controls how data is retrieved from vector databases, including filtering
|
33
|
-
and result limiting parameters.
|
34
|
-
"""
|
35
|
-
|
36
|
-
collection_name: str | None
|
37
|
-
similarity_threshold: float
|
38
|
-
result_per_query: int
|
39
|
-
|
40
|
-
|
41
|
-
class RetrievalKwargs(FetchKwargs, total=False):
|
42
|
-
"""Arguments for retrieval operations."""
|
43
|
-
|
44
|
-
final_limit: int
|
12
|
+
max_chunk_size: int
|
13
|
+
max_overlapping_rate: NotRequired[float]
|
45
14
|
|
46
15
|
|
47
16
|
class EmbeddingKwargs(TypedDict, total=False):
|
@@ -76,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
|
|
76
45
|
no_store: bool # If store the response of this call to cache
|
77
46
|
cache_ttl: int # how long the stored cache is alive, in seconds
|
78
47
|
s_maxage: int # max accepted age of cached response, in seconds
|
48
|
+
presence_penalty: float
|
49
|
+
frequency_penalty: float
|
79
50
|
|
80
51
|
|
81
52
|
class GenerateKwargs(LLMKwargs, total=False):
|
@@ -139,6 +110,7 @@ class ReviewKwargs[T](ReviewInnerKwargs[T], total=False):
|
|
139
110
|
|
140
111
|
class ReferencedKwargs[T](ValidateKwargs[T], total=False):
|
141
112
|
"""Arguments for content review operations."""
|
113
|
+
|
142
114
|
reference: str
|
143
115
|
|
144
116
|
|
fabricatio/models/task.py
CHANGED
@@ -7,11 +7,11 @@ from asyncio import Queue
|
|
7
7
|
from typing import Any, List, Optional, Self
|
8
8
|
|
9
9
|
from fabricatio.config import configs
|
10
|
+
from fabricatio.constants import TaskStatus
|
10
11
|
from fabricatio.core import env
|
11
12
|
from fabricatio.journal import logger
|
12
13
|
from fabricatio.models.events import Event, EventLike
|
13
14
|
from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
|
14
|
-
from fabricatio.models.utils import TaskStatus
|
15
15
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
16
16
|
from pydantic import Field, PrivateAttr
|
17
17
|
|
@@ -112,12 +112,12 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
|
|
112
112
|
"""Return a formatted status label for the task.
|
113
113
|
|
114
114
|
Args:
|
115
|
-
status (TaskStatus): The status of the task.
|
115
|
+
status (fabricatio.constants.TaskStatus): The status of the task.
|
116
116
|
|
117
117
|
Returns:
|
118
118
|
str: The formatted status label.
|
119
119
|
"""
|
120
|
-
return self._namespace.derive(self.name).push(status
|
120
|
+
return self._namespace.derive(self.name).push(status).collapse()
|
121
121
|
|
122
122
|
@property
|
123
123
|
def pending_label(self) -> str:
|
fabricatio/models/usages.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import traceback
|
4
4
|
from asyncio import gather
|
5
|
-
from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Union, Unpack, overload
|
5
|
+
from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload
|
6
6
|
|
7
7
|
import asyncstdlib
|
8
8
|
import litellm
|
@@ -13,7 +13,6 @@ from fabricatio.models.generic import ScopedConfig, WithBriefing
|
|
13
13
|
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
|
14
14
|
from fabricatio.models.task import Task
|
15
15
|
from fabricatio.models.tool import Tool, ToolBox
|
16
|
-
from fabricatio.models.utils import Messages
|
17
16
|
from fabricatio.parser import GenericCapture, JsonCapture
|
18
17
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
19
18
|
from fabricatio.utils import ok
|
@@ -28,7 +27,7 @@ from litellm.types.utils import (
|
|
28
27
|
)
|
29
28
|
from litellm.utils import CustomStreamWrapper, token_counter # pyright: ignore [reportPrivateImportUsage]
|
30
29
|
from more_itertools import duplicates_everseen
|
31
|
-
from pydantic import Field, NonNegativeInt, PositiveInt
|
30
|
+
from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
|
32
31
|
|
33
32
|
if configs.cache.enabled and configs.cache.type:
|
34
33
|
litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
|
@@ -64,7 +63,7 @@ class LLMUsage(ScopedConfig):
|
|
64
63
|
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
65
64
|
return ROUTER
|
66
65
|
|
67
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
66
|
+
# noinspection PyTypeChecker,PydanticTypeChecker,t
|
68
67
|
async def aquery(
|
69
68
|
self,
|
70
69
|
messages: List[Dict[str, str]],
|
@@ -123,6 +122,12 @@ class LLMUsage(ScopedConfig):
|
|
123
122
|
"cache-ttl": kwargs.get("cache_ttl"),
|
124
123
|
"s-maxage": kwargs.get("s_maxage"),
|
125
124
|
},
|
125
|
+
presence_penalty=kwargs.get("presence_penalty")
|
126
|
+
or self.llm_presence_penalty
|
127
|
+
or configs.llm.presence_penalty,
|
128
|
+
frequency_penalty=kwargs.get("frequency_penalty")
|
129
|
+
or self.llm_frequency_penalty
|
130
|
+
or configs.llm.frequency_penalty,
|
126
131
|
)
|
127
132
|
|
128
133
|
async def ainvoke(
|
@@ -299,10 +304,11 @@ class LLMUsage(ScopedConfig):
|
|
299
304
|
for lap in range(max_validations):
|
300
305
|
try:
|
301
306
|
if ((validated := validator(response := await self.aask(question=q, **kwargs))) is not None) or (
|
302
|
-
co_extractor
|
307
|
+
co_extractor is not None
|
308
|
+
and logger.debug("Co-extraction is enabled.") is None
|
303
309
|
and (
|
304
310
|
validated := validator(
|
305
|
-
await self.aask(
|
311
|
+
response := await self.aask(
|
306
312
|
question=(
|
307
313
|
TEMPLATE_MANAGER.render_template(
|
308
314
|
configs.templates.co_validation_template,
|
@@ -319,12 +325,13 @@ class LLMUsage(ScopedConfig):
|
|
319
325
|
return validated
|
320
326
|
|
321
327
|
except RateLimitError as e:
|
322
|
-
logger.warning(f"Rate limit error
|
328
|
+
logger.warning(f"Rate limit error:\n{e}")
|
323
329
|
continue
|
324
330
|
except Exception as e: # noqa: BLE001
|
325
|
-
logger.error(f"Error during validation
|
331
|
+
logger.error(f"Error during validation:\n{e}")
|
326
332
|
logger.debug(traceback.format_exc())
|
327
333
|
break
|
334
|
+
logger.error(f"Failed to validate the response at {lap}th attempt:\n{response}")
|
328
335
|
if not kwargs.get("no_cache"):
|
329
336
|
kwargs["no_cache"] = True
|
330
337
|
logger.debug("Closed the cache for the next attempt")
|
@@ -493,7 +500,7 @@ class LLMUsage(ScopedConfig):
|
|
493
500
|
affirm_case: str = "",
|
494
501
|
deny_case: str = "",
|
495
502
|
**kwargs: Unpack[ValidateKwargs[bool]],
|
496
|
-
) -> bool:
|
503
|
+
) -> Optional[bool]:
|
497
504
|
"""Asynchronously judges a prompt using AI validation.
|
498
505
|
|
499
506
|
Args:
|
@@ -730,3 +737,72 @@ class ToolBoxUsage(LLMUsage):
|
|
730
737
|
for other in (x for x in others if isinstance(x, ToolBoxUsage)):
|
731
738
|
other.toolboxes.update(self.toolboxes)
|
732
739
|
return self
|
740
|
+
|
741
|
+
|
742
|
+
class Message(BaseModel):
|
743
|
+
"""A class representing a message."""
|
744
|
+
|
745
|
+
model_config = ConfigDict(use_attribute_docstrings=True)
|
746
|
+
role: Literal["user", "system", "assistant"]
|
747
|
+
"""The role of the message sender."""
|
748
|
+
content: str
|
749
|
+
"""The content of the message."""
|
750
|
+
|
751
|
+
|
752
|
+
class Messages(list):
|
753
|
+
"""A list of messages."""
|
754
|
+
|
755
|
+
def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
|
756
|
+
"""Adds a message to the list with the specified role and content.
|
757
|
+
|
758
|
+
Args:
|
759
|
+
role (Literal["user", "system", "assistant"]): The role of the message sender.
|
760
|
+
content (str): The content of the message.
|
761
|
+
|
762
|
+
Returns:
|
763
|
+
Self: The current instance of Messages to allow method chaining.
|
764
|
+
"""
|
765
|
+
if content:
|
766
|
+
self.append(Message(role=role, content=content))
|
767
|
+
return self
|
768
|
+
|
769
|
+
def add_user_message(self, content: str) -> Self:
|
770
|
+
"""Adds a user message to the list with the specified content.
|
771
|
+
|
772
|
+
Args:
|
773
|
+
content (str): The content of the user message.
|
774
|
+
|
775
|
+
Returns:
|
776
|
+
Self: The current instance of Messages to allow method chaining.
|
777
|
+
"""
|
778
|
+
return self.add_message("user", content)
|
779
|
+
|
780
|
+
def add_system_message(self, content: str) -> Self:
|
781
|
+
"""Adds a system message to the list with the specified content.
|
782
|
+
|
783
|
+
Args:
|
784
|
+
content (str): The content of the system message.
|
785
|
+
|
786
|
+
Returns:
|
787
|
+
Self: The current instance of Messages to allow method chaining.
|
788
|
+
"""
|
789
|
+
return self.add_message("system", content)
|
790
|
+
|
791
|
+
def add_assistant_message(self, content: str) -> Self:
|
792
|
+
"""Adds an assistant message to the list with the specified content.
|
793
|
+
|
794
|
+
Args:
|
795
|
+
content (str): The content of the assistant message.
|
796
|
+
|
797
|
+
Returns:
|
798
|
+
Self: The current instance of Messages to allow method chaining.
|
799
|
+
"""
|
800
|
+
return self.add_message("assistant", content)
|
801
|
+
|
802
|
+
def as_list(self) -> List[Dict[str, str]]:
|
803
|
+
"""Converts the messages to a list of dictionaries.
|
804
|
+
|
805
|
+
Returns:
|
806
|
+
list[dict]: A list of dictionaries representing the messages.
|
807
|
+
"""
|
808
|
+
return [message.model_dump() for message in self]
|
fabricatio/parser.py
CHANGED
@@ -48,10 +48,10 @@ class Capture(BaseModel):
|
|
48
48
|
case "json" if configs.general.use_json_repair:
|
49
49
|
logger.debug("Applying json repair to text.")
|
50
50
|
if isinstance(text, str):
|
51
|
-
return repair_json(text, ensure_ascii=False)
|
52
|
-
return [repair_json(item, ensure_ascii=False) for item in text]
|
51
|
+
return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
|
52
|
+
return [repair_json(item, ensure_ascii=False) for item in text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
|
53
53
|
case _:
|
54
|
-
return text
|
54
|
+
return text # pyright: ignore [reportReturnType]
|
55
55
|
|
56
56
|
def capture(self, text: str) -> Tuple[str, ...] | str | None:
|
57
57
|
"""Capture the first occurrence of the pattern in the given text.
|
@@ -88,7 +88,7 @@ class Capture(BaseModel):
|
|
88
88
|
if (cap := self.capture(text)) is None:
|
89
89
|
return None
|
90
90
|
try:
|
91
|
-
return convertor(cap)
|
91
|
+
return convertor(cap) # pyright: ignore [reportArgumentType]
|
92
92
|
except (ValueError, SyntaxError, ValidationError) as e:
|
93
93
|
logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
|
94
94
|
return None
|
@@ -120,7 +120,7 @@ class Capture(BaseModel):
|
|
120
120
|
judges.append(lambda output_obj: len(output_obj) == length)
|
121
121
|
|
122
122
|
if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
|
123
|
-
return out
|
123
|
+
return out # pyright: ignore [reportReturnType]
|
124
124
|
return None
|
125
125
|
|
126
126
|
@classmethod
|
Binary file
|
fabricatio/rust.pyi
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
Python interface definitions for Rust-based functionality.
|
1
|
+
"""Python interface definitions for Rust-based functionality.
|
3
2
|
|
4
3
|
This module provides type stubs and documentation for Rust-implemented utilities,
|
5
4
|
including template rendering, cryptographic hashing, language detection, and
|
@@ -12,11 +11,8 @@ Key Features:
|
|
12
11
|
- Text utilities: Word boundary splitting and word counting.
|
13
12
|
"""
|
14
13
|
|
15
|
-
|
16
14
|
from pathlib import Path
|
17
|
-
from typing import List, Optional
|
18
|
-
|
19
|
-
from pydantic import JsonValue
|
15
|
+
from typing import Any, Dict, List, Optional
|
20
16
|
|
21
17
|
|
22
18
|
class TemplateManager:
|
@@ -29,7 +25,7 @@ class TemplateManager:
|
|
29
25
|
"""
|
30
26
|
|
31
27
|
def __init__(
|
32
|
-
|
28
|
+
self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
|
33
29
|
) -> None:
|
34
30
|
"""Initialize the template manager.
|
35
31
|
|
@@ -59,7 +55,7 @@ class TemplateManager:
|
|
59
55
|
This refreshes the template cache, finding any new or modified templates.
|
60
56
|
"""
|
61
57
|
|
62
|
-
def render_template(self, name: str, data:
|
58
|
+
def render_template(self, name: str, data: Dict[str, Any]) -> str:
|
63
59
|
"""Render a template with context data.
|
64
60
|
|
65
61
|
Args:
|
@@ -73,7 +69,7 @@ class TemplateManager:
|
|
73
69
|
RuntimeError: If template rendering fails
|
74
70
|
"""
|
75
71
|
|
76
|
-
def render_template_raw(self, template: str, data:
|
72
|
+
def render_template_raw(self, template: str, data: Dict[str, Any]) -> str:
|
77
73
|
"""Render a template with context data.
|
78
74
|
|
79
75
|
Args:
|
@@ -84,6 +80,7 @@ class TemplateManager:
|
|
84
80
|
Rendered template content as string
|
85
81
|
"""
|
86
82
|
|
83
|
+
|
87
84
|
def blake3_hash(content: bytes) -> str:
|
88
85
|
"""Calculate the BLAKE3 cryptographic hash of data.
|
89
86
|
|
@@ -94,6 +91,7 @@ def blake3_hash(content: bytes) -> str:
|
|
94
91
|
Hex-encoded BLAKE3 hash string
|
95
92
|
"""
|
96
93
|
|
94
|
+
|
97
95
|
def detect_language(string: str) -> str:
|
98
96
|
"""Detect the language of a given string."""
|
99
97
|
|
@@ -107,6 +105,32 @@ def split_word_bounds(string: str) -> List[str]:
|
|
107
105
|
Returns:
|
108
106
|
A list of words extracted from the string.
|
109
107
|
"""
|
108
|
+
|
109
|
+
|
110
|
+
def split_sentence_bounds(string: str) -> List[str]:
|
111
|
+
"""Split the string into sentences based on sentence boundaries.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
string: The input string to be split.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
A list of sentences extracted from the string.
|
118
|
+
"""
|
119
|
+
|
120
|
+
|
121
|
+
def split_into_chunks(string: str, max_chunk_size: int, max_overlapping_rate: float = 0.3) -> List[str]:
|
122
|
+
"""Split the string into chunks of a specified size.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
string: The input string to be split.
|
126
|
+
max_chunk_size: The maximum size of each chunk.
|
127
|
+
max_overlapping_rate: The minimum overlapping rate between chunks.
|
128
|
+
|
129
|
+
Returns:
|
130
|
+
A list of chunks extracted from the string.
|
131
|
+
"""
|
132
|
+
|
133
|
+
|
110
134
|
def word_count(string: str) -> int:
|
111
135
|
"""Count the number of words in the string.
|
112
136
|
|
@@ -118,6 +142,98 @@ def word_count(string: str) -> int:
|
|
118
142
|
"""
|
119
143
|
|
120
144
|
|
145
|
+
def is_chinese(string: str) -> bool:
|
146
|
+
"""Check if the given string is in Chinese."""
|
147
|
+
|
148
|
+
|
149
|
+
def is_english(string: str) -> bool:
|
150
|
+
"""Check if the given string is in English."""
|
151
|
+
|
152
|
+
|
153
|
+
def is_japanese(string: str) -> bool:
|
154
|
+
"""Check if the given string is in Japanese."""
|
155
|
+
|
156
|
+
|
157
|
+
def is_korean(string: str) -> bool:
|
158
|
+
"""Check if the given string is in Korean."""
|
159
|
+
|
160
|
+
|
161
|
+
def is_arabic(string: str) -> bool:
|
162
|
+
"""Check if the given string is in Arabic."""
|
163
|
+
|
164
|
+
|
165
|
+
def is_russian(string: str) -> bool:
|
166
|
+
"""Check if the given string is in Russian."""
|
167
|
+
|
168
|
+
|
169
|
+
def is_german(string: str) -> bool:
|
170
|
+
"""Check if the given string is in German."""
|
171
|
+
|
172
|
+
|
173
|
+
def is_french(string: str) -> bool:
|
174
|
+
"""Check if the given string is in French."""
|
175
|
+
|
176
|
+
|
177
|
+
def is_hindi(string: str) -> bool:
|
178
|
+
"""Check if the given string is in Hindi."""
|
179
|
+
|
180
|
+
|
181
|
+
def is_italian(string: str) -> bool:
|
182
|
+
"""Check if the given string is in Italian."""
|
183
|
+
|
184
|
+
|
185
|
+
def is_dutch(string: str) -> bool:
|
186
|
+
"""Check if the given string is in Dutch."""
|
187
|
+
|
188
|
+
|
189
|
+
def is_portuguese(string: str) -> bool:
|
190
|
+
"""Check if the given string is in Portuguese."""
|
191
|
+
|
192
|
+
|
193
|
+
def is_swedish(string: str) -> bool:
|
194
|
+
"""Check if the given string is in Swedish."""
|
195
|
+
|
196
|
+
|
197
|
+
def is_turkish(string: str) -> bool:
|
198
|
+
"""Check if the given string is in Turkish."""
|
199
|
+
|
200
|
+
|
201
|
+
def is_vietnamese(string: str) -> bool:
|
202
|
+
"""Check if the given string is in Vietnamese."""
|
203
|
+
|
204
|
+
|
205
|
+
def tex_to_typst(string: str) -> str:
|
206
|
+
"""Convert TeX to Typst.
|
207
|
+
|
208
|
+
Args:
|
209
|
+
string: The input TeX string to be converted.
|
210
|
+
|
211
|
+
Returns:
|
212
|
+
The converted Typst string.
|
213
|
+
"""
|
214
|
+
|
215
|
+
|
216
|
+
def convert_all_inline_tex(string: str) -> str:
|
217
|
+
"""Convert all inline TeX code in the string.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
string: The input string containing inline TeX code wrapped in $code$.
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
The converted string with inline TeX code replaced.
|
224
|
+
"""
|
225
|
+
|
226
|
+
|
227
|
+
def convert_all_block_tex(string: str) -> str:
|
228
|
+
"""Convert all block TeX code in the string.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
string: The input string containing block TeX code wrapped in $$code$$.
|
232
|
+
|
233
|
+
Returns:
|
234
|
+
The converted string with block TeX code replaced.
|
235
|
+
"""
|
236
|
+
|
121
237
|
|
122
238
|
class BibManager:
|
123
239
|
"""BibTeX bibliography manager for parsing and querying citation data."""
|
@@ -132,7 +248,7 @@ class BibManager:
|
|
132
248
|
RuntimeError: If file cannot be read or parsed
|
133
249
|
"""
|
134
250
|
|
135
|
-
def
|
251
|
+
def get_cite_key_by_title(self, title: str) -> Optional[str]:
|
136
252
|
"""Find citation key by exact title match.
|
137
253
|
|
138
254
|
Args:
|
@@ -142,6 +258,16 @@ class BibManager:
|
|
142
258
|
Citation key if exact match found, None otherwise
|
143
259
|
"""
|
144
260
|
|
261
|
+
def get_cite_key_by_title_fuzzy(self, title: str) -> Optional[str]:
|
262
|
+
"""Find citation key by fuzzy title match.
|
263
|
+
|
264
|
+
Args:
|
265
|
+
title: Search term to find in bibliography entries
|
266
|
+
|
267
|
+
Returns:
|
268
|
+
Citation key of best matching entry, or None if no good match
|
269
|
+
"""
|
270
|
+
|
145
271
|
def get_cite_key_fuzzy(self, query: str) -> Optional[str]:
|
146
272
|
"""Find best matching citation using fuzzy text search.
|
147
273
|
|
@@ -195,6 +321,7 @@ class BibManager:
|
|
195
321
|
Returns:
|
196
322
|
Abstract if found, None otherwise
|
197
323
|
"""
|
324
|
+
|
198
325
|
def get_title_by_key(self, key: str) -> Optional[str]:
|
199
326
|
"""Retrieve the title by citation key.
|
200
327
|
|