fabricatio 0.2.10.dev1__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +55 -10
- fabricatio/actions/article_rag.py +268 -14
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +3 -3
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +70 -0
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +43 -26
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/action.py +16 -11
- fabricatio/models/extra/aricle_rag.py +143 -9
- fabricatio/models/extra/article_base.py +56 -7
- fabricatio/models/extra/article_main.py +102 -6
- fabricatio/models/extra/problem.py +5 -1
- fabricatio/models/generic.py +31 -13
- fabricatio/models/kwargs_types.py +4 -2
- fabricatio/models/task.py +13 -1
- fabricatio/models/usages.py +10 -27
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +167 -62
- fabricatio/utils.py +38 -11
- fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -9
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/RECORD +31 -29
- fabricatio-0.2.10.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
@@ -14,7 +14,7 @@ from fabricatio.models.generic import Display, ProposedAble
|
|
14
14
|
from fabricatio.models.kwargs_types import CompositeScoreKwargs, ValidateKwargs
|
15
15
|
from fabricatio.parser import JsonCapture
|
16
16
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
17
|
-
from fabricatio.utils import
|
17
|
+
from fabricatio.utils import ok, override_kwargs
|
18
18
|
|
19
19
|
|
20
20
|
class Rating(Propose):
|
@@ -137,7 +137,7 @@ class Rating(Propose):
|
|
137
137
|
or dict(zip(criteria, criteria, strict=True))
|
138
138
|
)
|
139
139
|
|
140
|
-
return await self.rate_fine_grind(to_rate, manual, score_range, **
|
140
|
+
return await self.rate_fine_grind(to_rate, manual, score_range, **kwargs)
|
141
141
|
|
142
142
|
async def draft_rating_manual(
|
143
143
|
self, topic: str, criteria: Optional[Set[str]] = None, **kwargs: Unpack[ValidateKwargs[Dict[str, str]]]
|
@@ -338,6 +338,7 @@ class Rating(Propose):
|
|
338
338
|
criteria: Optional[Set[str]] = None,
|
339
339
|
weights: Optional[Dict[str, float]] = None,
|
340
340
|
manual: Optional[Dict[str, str]] = None,
|
341
|
+
approx: bool = False,
|
341
342
|
**kwargs: Unpack[ValidateKwargs[List[Dict[str, float]]]],
|
342
343
|
) -> List[float]:
|
343
344
|
"""Calculates the composite scores for a list of items based on a given topic and criteria.
|
@@ -348,6 +349,7 @@ class Rating(Propose):
|
|
348
349
|
criteria (Optional[Set[str]]): A set of criteria for the rating. Defaults to None.
|
349
350
|
weights (Optional[Dict[str, float]]): A dictionary of rating weights for each criterion. Defaults to None.
|
350
351
|
manual (Optional[Dict[str, str]]): A dictionary of manual ratings for each item. Defaults to None.
|
352
|
+
approx (bool): Whether to use approximate rating criteria. Defaults to False.
|
351
353
|
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
352
354
|
|
353
355
|
Returns:
|
@@ -355,6 +357,7 @@ class Rating(Propose):
|
|
355
357
|
"""
|
356
358
|
criteria = ok(
|
357
359
|
criteria
|
360
|
+
or (await self.draft_rating_criteria(topic, **override_kwargs(kwargs, default=None)) if approx else None)
|
358
361
|
or await self.draft_rating_criteria_from_examples(topic, to_rate, **override_kwargs(kwargs, default=None))
|
359
362
|
)
|
360
363
|
weights = ok(
|
fabricatio/capabilities/task.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
from types import CodeType
|
4
4
|
from typing import Any, Dict, List, Optional, Tuple, Unpack
|
5
5
|
|
6
|
-
import
|
6
|
+
import ujson
|
7
7
|
|
8
8
|
from fabricatio.capabilities.propose import Propose
|
9
9
|
from fabricatio.config import configs
|
@@ -20,9 +20,9 @@ class ProposeTask(Propose):
|
|
20
20
|
"""A class that proposes a task based on a prompt."""
|
21
21
|
|
22
22
|
async def propose_task[T](
|
23
|
-
|
24
|
-
|
25
|
-
|
23
|
+
self,
|
24
|
+
prompt: str,
|
25
|
+
**kwargs: Unpack[ValidateKwargs[Task[T]]],
|
26
26
|
) -> Optional[Task[T]]:
|
27
27
|
"""Asynchronously proposes a task based on a given prompt and parameters.
|
28
28
|
|
@@ -44,11 +44,11 @@ class HandleTask(ToolBoxUsage):
|
|
44
44
|
"""A class that handles a task based on a task object."""
|
45
45
|
|
46
46
|
async def draft_tool_usage_code(
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
47
|
+
self,
|
48
|
+
task: Task,
|
49
|
+
tools: List[Tool],
|
50
|
+
data: Dict[str, Any],
|
51
|
+
**kwargs: Unpack[ValidateKwargs],
|
52
52
|
) -> Optional[Tuple[CodeType, List[str]]]:
|
53
53
|
"""Asynchronously drafts the tool usage code for a task based on a given task object and tools."""
|
54
54
|
logger.info(f"Drafting tool usage code for task: {task.briefing}")
|
@@ -60,7 +60,7 @@ class HandleTask(ToolBoxUsage):
|
|
60
60
|
|
61
61
|
def _validator(response: str) -> Tuple[CodeType, List[str]] | None:
|
62
62
|
if (source := PythonCapture.convert_with(response, lambda resp: compile(resp, "<string>", "exec"))) and (
|
63
|
-
|
63
|
+
to_extract := JsonCapture.convert_with(response, ujson.loads)
|
64
64
|
):
|
65
65
|
return source, to_extract
|
66
66
|
|
@@ -85,12 +85,12 @@ class HandleTask(ToolBoxUsage):
|
|
85
85
|
)
|
86
86
|
|
87
87
|
async def handle_fine_grind(
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
88
|
+
self,
|
89
|
+
task: Task,
|
90
|
+
data: Dict[str, Any],
|
91
|
+
box_choose_kwargs: Optional[ChooseKwargs] = None,
|
92
|
+
tool_choose_kwargs: Optional[ChooseKwargs] = None,
|
93
|
+
**kwargs: Unpack[ValidateKwargs],
|
94
94
|
) -> Optional[Tuple]:
|
95
95
|
"""Asynchronously handles a task based on a given task object and parameters."""
|
96
96
|
logger.info(f"Handling task: \n{task.briefing}")
|
fabricatio/config.py
CHANGED
@@ -86,8 +86,10 @@ class LLMConfig(BaseModel):
|
|
86
86
|
|
87
87
|
tpm: Optional[PositiveInt] = Field(default=1000000)
|
88
88
|
"""The rate limit of the LLM model in tokens per minute. None means not checked."""
|
89
|
-
|
90
|
-
|
89
|
+
presence_penalty:Optional[PositiveFloat]=None
|
90
|
+
"""The presence penalty of the LLM model."""
|
91
|
+
frequency_penalty:Optional[PositiveFloat]=None
|
92
|
+
"""The frequency penalty of the LLM model."""
|
91
93
|
class EmbeddingConfig(BaseModel):
|
92
94
|
"""Embedding configuration class."""
|
93
95
|
|
@@ -249,6 +251,11 @@ class TemplateConfig(BaseModel):
|
|
249
251
|
|
250
252
|
rule_requirement_template: str = Field(default="rule_requirement")
|
251
253
|
"""The name of the rule requirement template which will be used to generate a rule requirement."""
|
254
|
+
|
255
|
+
|
256
|
+
extract_template: str = Field(default="extract")
|
257
|
+
"""The name of the extract template which will be used to extract model from string."""
|
258
|
+
|
252
259
|
class MagikaConfig(BaseModel):
|
253
260
|
"""Magika configuration class."""
|
254
261
|
|
fabricatio/decorators.py
CHANGED
@@ -6,14 +6,47 @@ from importlib.util import find_spec
|
|
6
6
|
from inspect import signature
|
7
7
|
from shutil import which
|
8
8
|
from types import ModuleType
|
9
|
-
from typing import Callable, List, Optional
|
10
|
-
|
11
|
-
from questionary import confirm
|
9
|
+
from typing import Callable, Coroutine, List, Optional
|
12
10
|
|
13
11
|
from fabricatio.config import configs
|
14
12
|
from fabricatio.journal import logger
|
15
13
|
|
16
14
|
|
15
|
+
def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
16
|
+
"""Check if a package exists in the current environment.
|
17
|
+
|
18
|
+
Args:
|
19
|
+
package_name (str): The name of the package to check.
|
20
|
+
msg (str): The message to display if the package is not found.
|
21
|
+
|
22
|
+
Returns:
|
23
|
+
bool: True if the package exists, False otherwise.
|
24
|
+
"""
|
25
|
+
|
26
|
+
def _wrapper(
|
27
|
+
func: Callable[P, R] | Callable[P, Coroutine[None, None, R]],
|
28
|
+
) -> Callable[P, R] | Callable[P, Coroutine[None, None, R]]:
|
29
|
+
if iscoroutinefunction(func):
|
30
|
+
|
31
|
+
@wraps(func)
|
32
|
+
async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
33
|
+
if find_spec(package_name):
|
34
|
+
return await func(*args, **kwargs)
|
35
|
+
raise RuntimeError(msg)
|
36
|
+
|
37
|
+
return _async_inner
|
38
|
+
|
39
|
+
@wraps(func)
|
40
|
+
def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
41
|
+
if find_spec(package_name):
|
42
|
+
return func(*args, **kwargs)
|
43
|
+
raise RuntimeError(msg)
|
44
|
+
|
45
|
+
return _inner
|
46
|
+
|
47
|
+
return _wrapper
|
48
|
+
|
49
|
+
|
17
50
|
def depend_on_external_cmd[**P, R](
|
18
51
|
bin_name: str, install_tip: Optional[str], homepage: Optional[str] = None
|
19
52
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
@@ -68,6 +101,9 @@ def logging_execution_info[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
68
101
|
return _wrapper
|
69
102
|
|
70
103
|
|
104
|
+
@precheck_package(
|
105
|
+
"questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
|
106
|
+
)
|
71
107
|
def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]] | Callable[P, R]:
|
72
108
|
"""Decorator to confirm before executing a function.
|
73
109
|
|
@@ -80,6 +116,7 @@ def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]]
|
|
80
116
|
if not configs.general.confirm_on_ops:
|
81
117
|
# Skip confirmation if the configuration is set to False
|
82
118
|
return func
|
119
|
+
from questionary import confirm
|
83
120
|
|
84
121
|
if iscoroutinefunction(func):
|
85
122
|
|
@@ -180,7 +217,9 @@ def use_temp_module[**P, R](modules: ModuleType | List[ModuleType]) -> Callable[
|
|
180
217
|
return _decorator
|
181
218
|
|
182
219
|
|
183
|
-
def logging_exec_time[**P, R](
|
220
|
+
def logging_exec_time[**P, R](
|
221
|
+
func: Callable[P, R] | Callable[P, Coroutine[None, None, R]],
|
222
|
+
) -> Callable[P, R] | Callable[P, Coroutine[None, None, R]]:
|
184
223
|
"""Decorator to log the execution time of a function.
|
185
224
|
|
186
225
|
Args:
|
@@ -210,25 +249,3 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
210
249
|
return result
|
211
250
|
|
212
251
|
return _wrapper
|
213
|
-
|
214
|
-
|
215
|
-
def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
216
|
-
"""Check if a package exists in the current environment.
|
217
|
-
|
218
|
-
Args:
|
219
|
-
package_name (str): The name of the package to check.
|
220
|
-
msg (str): The message to display if the package is not found.
|
221
|
-
|
222
|
-
Returns:
|
223
|
-
bool: True if the package exists, False otherwise.
|
224
|
-
"""
|
225
|
-
|
226
|
-
def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
|
227
|
-
def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
228
|
-
if find_spec(package_name):
|
229
|
-
return func(*args, **kwargs)
|
230
|
-
raise RuntimeError(msg)
|
231
|
-
|
232
|
-
return _inner
|
233
|
-
|
234
|
-
return _wrapper
|
fabricatio/fs/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""FileSystem manipulation module for Fabricatio."""
|
2
|
+
from importlib.util import find_spec
|
2
3
|
|
4
|
+
from fabricatio.config import configs
|
3
5
|
from fabricatio.fs.curd import (
|
4
6
|
absolute_path,
|
5
7
|
copy_file,
|
@@ -11,10 +13,9 @@ from fabricatio.fs.curd import (
|
|
11
13
|
move_file,
|
12
14
|
tree,
|
13
15
|
)
|
14
|
-
from fabricatio.fs.readers import
|
16
|
+
from fabricatio.fs.readers import safe_json_read, safe_text_read
|
15
17
|
|
16
18
|
__all__ = [
|
17
|
-
"MAGIKA",
|
18
19
|
"absolute_path",
|
19
20
|
"copy_file",
|
20
21
|
"create_directory",
|
@@ -27,3 +28,9 @@ __all__ = [
|
|
27
28
|
"safe_text_read",
|
28
29
|
"tree",
|
29
30
|
]
|
31
|
+
|
32
|
+
if find_spec("magika"):
|
33
|
+
from magika import Magika
|
34
|
+
|
35
|
+
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
36
|
+
__all__ += ["MAGIKA"]
|
fabricatio/fs/readers.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1
1
|
"""Filesystem readers for Fabricatio."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
5
|
from typing import Dict, List, Tuple
|
5
6
|
|
6
|
-
import
|
7
|
-
import regex
|
8
|
-
from magika import Magika
|
7
|
+
import ujson
|
9
8
|
|
10
|
-
from fabricatio.config import configs
|
11
9
|
from fabricatio.journal import logger
|
12
10
|
|
13
|
-
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
14
|
-
|
15
11
|
|
16
12
|
def safe_text_read(path: Path | str) -> str:
|
17
13
|
"""Safely read the text from a file.
|
@@ -41,8 +37,8 @@ def safe_json_read(path: Path | str) -> Dict:
|
|
41
37
|
"""
|
42
38
|
path = Path(path)
|
43
39
|
try:
|
44
|
-
return
|
45
|
-
except (
|
40
|
+
return ujson.loads(path.read_text(encoding="utf-8"))
|
41
|
+
except (ujson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
|
46
42
|
logger.error(f"Failed to read file {path}: {e!s}")
|
47
43
|
return {}
|
48
44
|
|
@@ -58,8 +54,8 @@ def extract_sections(string: str, level: int, section_char: str = "#") -> List[T
|
|
58
54
|
Returns:
|
59
55
|
List[Tuple[str, str]]: List of (header_text, section_content) tuples
|
60
56
|
"""
|
61
|
-
return
|
57
|
+
return re.findall(
|
62
58
|
r"^%s{%d}\s+(.+?)\n((?:(?!^%s{%d}\s).|\n)*)" % (section_char, level, section_char, level),
|
63
59
|
string,
|
64
|
-
|
60
|
+
re.MULTILINE,
|
65
61
|
)
|
fabricatio/models/action.py
CHANGED
@@ -12,12 +12,13 @@ Classes:
|
|
12
12
|
import traceback
|
13
13
|
from abc import abstractmethod
|
14
14
|
from asyncio import Queue, create_task
|
15
|
-
from typing import Any, Dict, Self, Tuple, Type, Union, final
|
15
|
+
from typing import Any, Dict, Self, Sequence, Tuple, Type, Union, final
|
16
16
|
|
17
17
|
from fabricatio.journal import logger
|
18
18
|
from fabricatio.models.generic import WithBriefing
|
19
19
|
from fabricatio.models.task import Task
|
20
20
|
from fabricatio.models.usages import LLMUsage, ToolBoxUsage
|
21
|
+
from fabricatio.utils import override_kwargs
|
21
22
|
from pydantic import Field, PrivateAttr
|
22
23
|
|
23
24
|
OUTPUT_KEY = "task_output"
|
@@ -55,7 +56,7 @@ class Action(WithBriefing, LLMUsage):
|
|
55
56
|
self.description = self.description or self.__class__.__doc__ or ""
|
56
57
|
|
57
58
|
@abstractmethod
|
58
|
-
async def _execute(self, *_:Any, **cxt) -> Any:
|
59
|
+
async def _execute(self, *_: Any, **cxt) -> Any:
|
59
60
|
"""Implement the core logic of the action.
|
60
61
|
|
61
62
|
Args:
|
@@ -95,11 +96,12 @@ class Action(WithBriefing, LLMUsage):
|
|
95
96
|
return f"## Your personality: \n{self.personality}\n# The action you are going to perform: \n{super().briefing}"
|
96
97
|
return f"# The action you are going to perform: \n{super().briefing}"
|
97
98
|
|
98
|
-
def to_task_output(self)->Self:
|
99
|
+
def to_task_output(self, task_output_key: str = OUTPUT_KEY) -> Self:
|
99
100
|
"""Set the output key to OUTPUT_KEY and return the action instance."""
|
100
|
-
self.output_key=
|
101
|
+
self.output_key = task_output_key
|
101
102
|
return self
|
102
103
|
|
104
|
+
|
103
105
|
class WorkFlow(WithBriefing, ToolBoxUsage):
|
104
106
|
"""Manages sequences of actions to fulfill tasks.
|
105
107
|
|
@@ -121,9 +123,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
121
123
|
_instances: Tuple[Action, ...] = PrivateAttr(default_factory=tuple)
|
122
124
|
"""Instantiated action objects to be executed in this workflow."""
|
123
125
|
|
124
|
-
steps:
|
125
|
-
frozen=True,
|
126
|
-
)
|
126
|
+
steps: Sequence[Union[Type[Action], Action]] = Field(frozen=True)
|
127
127
|
"""The sequence of actions to be executed, can be action classes or instances."""
|
128
128
|
|
129
129
|
task_input_key: str = Field(default=INPUT_KEY)
|
@@ -177,7 +177,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
177
177
|
current_action = None
|
178
178
|
try:
|
179
179
|
# Process each action in sequence
|
180
|
-
for i,step in enumerate(self._instances):
|
180
|
+
for i, step in enumerate(self._instances):
|
181
181
|
current_action = step.name
|
182
182
|
logger.info(f"Executing step [{i}] >> {current_action}")
|
183
183
|
|
@@ -227,8 +227,13 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
227
227
|
- Any extra_init_context values
|
228
228
|
"""
|
229
229
|
logger.debug(f"Initializing context for workflow: {self.name}")
|
230
|
-
|
231
|
-
|
230
|
+
ctx = override_kwargs(self.extra_init_context, **task.extra_init_context)
|
231
|
+
if self.task_input_key in ctx:
|
232
|
+
raise ValueError(
|
233
|
+
f"Task input key: `{self.task_input_key}`, which is reserved, is already set in the init context"
|
234
|
+
)
|
235
|
+
|
236
|
+
await self._context.put({self.task_input_key: task, **ctx})
|
232
237
|
|
233
238
|
def steps_fallback_to_self(self) -> Self:
|
234
239
|
"""Configure all steps to use this workflow's configuration as fallback.
|
@@ -245,7 +250,7 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
245
250
|
Returns:
|
246
251
|
Self: The workflow instance for method chaining.
|
247
252
|
"""
|
248
|
-
self.provide_tools_to(i for i in self._instances if isinstance(i,ToolBoxUsage))
|
253
|
+
self.provide_tools_to(i for i in self._instances if isinstance(i, ToolBoxUsage))
|
249
254
|
return self
|
250
255
|
|
251
256
|
def update_init_context(self, /, **kwargs) -> Self:
|
@@ -1,22 +1,27 @@
|
|
1
1
|
"""A Module containing the article rag models."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import ClassVar, Dict, List, Self, Unpack
|
5
|
+
from typing import ClassVar, Dict, List, Optional, Self, Unpack
|
5
6
|
|
6
7
|
from fabricatio.fs import safe_text_read
|
7
8
|
from fabricatio.journal import logger
|
8
9
|
from fabricatio.models.extra.rag import MilvusDataBase
|
9
10
|
from fabricatio.models.generic import AsPrompt
|
10
11
|
from fabricatio.models.kwargs_types import ChunkKwargs
|
11
|
-
from fabricatio.rust import BibManager, split_into_chunks
|
12
|
-
from fabricatio.utils import ok
|
13
|
-
from more_itertools.recipes import flatten
|
12
|
+
from fabricatio.rust import BibManager, blake3_hash, is_chinese, split_into_chunks
|
13
|
+
from fabricatio.utils import ok
|
14
|
+
from more_itertools.recipes import flatten, unique
|
14
15
|
from pydantic import Field
|
15
16
|
|
16
17
|
|
17
18
|
class ArticleChunk(MilvusDataBase, AsPrompt):
|
18
19
|
"""The chunk of an article."""
|
19
20
|
|
21
|
+
etc_word: ClassVar[str] = "等"
|
22
|
+
and_word: ClassVar[str] = "与"
|
23
|
+
_cite_number: Optional[int] = None
|
24
|
+
|
20
25
|
head_split: ClassVar[List[str]] = [
|
21
26
|
"引 言",
|
22
27
|
"引言",
|
@@ -48,12 +53,14 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
48
53
|
|
49
54
|
def _as_prompt_inner(self) -> Dict[str, str]:
|
50
55
|
return {
|
51
|
-
|
52
|
-
f"Authors: {';'.join(self.authors)}\n"
|
53
|
-
f"Published Year: {self.year}\n"
|
54
|
-
f"Bibtex Key: {self.bibtex_cite_key}\n",
|
56
|
+
f"[[{ok(self._cite_number, 'You need to update cite number first.')}]] reference `{self.article_title}` from {self.as_auther_seq()}": self.chunk
|
55
57
|
}
|
56
58
|
|
59
|
+
@property
|
60
|
+
def cite_number(self) -> int:
|
61
|
+
"""Get the cite number."""
|
62
|
+
return ok(self._cite_number, "cite number not set")
|
63
|
+
|
57
64
|
def _prepare_vectorization_inner(self) -> str:
|
58
65
|
return self.chunk
|
59
66
|
|
@@ -89,8 +96,9 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
89
96
|
|
90
97
|
result = [
|
91
98
|
cls(chunk=c, year=year, authors=authors, article_title=article_title, bibtex_cite_key=key)
|
92
|
-
for c in split_into_chunks(cls.strip(safe_text_read(path)), **kwargs)
|
99
|
+
for c in split_into_chunks(cls.purge_numeric_citation(cls.strip(safe_text_read(path))), **kwargs)
|
93
100
|
]
|
101
|
+
|
94
102
|
logger.debug(f"Number of chunks created from file {path.as_posix()}: {len(result)}")
|
95
103
|
return result
|
96
104
|
|
@@ -118,3 +126,129 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
118
126
|
logger.warning("No decrease at tail strip, which is might be abnormal.")
|
119
127
|
|
120
128
|
return string
|
129
|
+
|
130
|
+
def as_typst_cite(self) -> str:
|
131
|
+
"""As typst cite."""
|
132
|
+
return f"#cite(<{self.bibtex_cite_key}>)"
|
133
|
+
|
134
|
+
@staticmethod
|
135
|
+
def purge_numeric_citation(string: str) -> str:
|
136
|
+
"""Purge numeric citation."""
|
137
|
+
import re
|
138
|
+
|
139
|
+
return re.sub(r"\[[\d\s,\\~–-]+]", "", string)
|
140
|
+
|
141
|
+
@property
|
142
|
+
def auther_firstnames(self) -> List[str]:
|
143
|
+
"""Get the first name of the authors."""
|
144
|
+
ret = []
|
145
|
+
for n in self.authors:
|
146
|
+
if is_chinese(n):
|
147
|
+
ret.append(n[0])
|
148
|
+
else:
|
149
|
+
ret.append(n.split()[-1])
|
150
|
+
return ret
|
151
|
+
|
152
|
+
def as_auther_seq(self) -> str:
|
153
|
+
"""Get the auther sequence."""
|
154
|
+
match len(self.authors):
|
155
|
+
case 0:
|
156
|
+
raise ValueError("No authors found")
|
157
|
+
case 1:
|
158
|
+
return f"({self.auther_firstnames[0]},{self.year}){self.as_typst_cite()}"
|
159
|
+
case 2:
|
160
|
+
return f"({self.auther_firstnames[0]}{self.and_word}{self.auther_firstnames[1]},{self.year}){self.as_typst_cite()}"
|
161
|
+
case 3:
|
162
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]},{self.year}){self.as_typst_cite()}"
|
163
|
+
case _:
|
164
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]}{self.etc_word},{self.year}){self.as_typst_cite()}"
|
165
|
+
|
166
|
+
def update_cite_number(self, cite_number: int) -> Self:
|
167
|
+
"""Update the cite number."""
|
168
|
+
self._cite_number = cite_number
|
169
|
+
return self
|
170
|
+
|
171
|
+
|
172
|
+
class CitationManager(AsPrompt):
|
173
|
+
"""Citation manager."""
|
174
|
+
|
175
|
+
article_chunks: List[ArticleChunk] = Field(default_factory=list)
|
176
|
+
"""Article chunks."""
|
177
|
+
|
178
|
+
pat: str = r"(\[\[([\d\s,-]*)]])"
|
179
|
+
"""Regex pattern to match citations."""
|
180
|
+
sep: str = ","
|
181
|
+
"""Separator for citation numbers."""
|
182
|
+
abbr_sep: str = "-"
|
183
|
+
"""Separator for abbreviated citation numbers."""
|
184
|
+
|
185
|
+
def update_chunks(
|
186
|
+
self, article_chunks: List[ArticleChunk], set_cite_number: bool = True, dedup: bool = True
|
187
|
+
) -> Self:
|
188
|
+
"""Update article chunks."""
|
189
|
+
self.article_chunks.clear()
|
190
|
+
self.article_chunks.extend(article_chunks)
|
191
|
+
if dedup:
|
192
|
+
self.article_chunks = list(unique(self.article_chunks, lambda c: blake3_hash(c.chunk.encode())))
|
193
|
+
if set_cite_number:
|
194
|
+
self.set_cite_number_all()
|
195
|
+
return self
|
196
|
+
|
197
|
+
def empty(self) -> Self:
|
198
|
+
"""Empty the article chunks."""
|
199
|
+
self.article_chunks.clear()
|
200
|
+
return self
|
201
|
+
|
202
|
+
def add_chunks(self, article_chunks: List[ArticleChunk], set_cite_number: bool = True, dedup: bool = True) -> Self:
|
203
|
+
"""Add article chunks."""
|
204
|
+
self.article_chunks.extend(article_chunks)
|
205
|
+
if dedup:
|
206
|
+
self.article_chunks = list(unique(self.article_chunks, lambda c: blake3_hash(c.chunk.encode())))
|
207
|
+
if set_cite_number:
|
208
|
+
self.set_cite_number_all()
|
209
|
+
return self
|
210
|
+
|
211
|
+
def set_cite_number_all(self) -> Self:
|
212
|
+
"""Set citation numbers for all article chunks."""
|
213
|
+
for i, a in enumerate(self.article_chunks, 1):
|
214
|
+
a.update_cite_number(i)
|
215
|
+
return self
|
216
|
+
|
217
|
+
def _as_prompt_inner(self) -> Dict[str, str]:
|
218
|
+
"""Generate prompt inner representation."""
|
219
|
+
return {"References": "\n".join(r.as_prompt() for r in self.article_chunks)}
|
220
|
+
|
221
|
+
def apply(self, string: str) -> str:
|
222
|
+
"""Apply citation replacements to the input string."""
|
223
|
+
for origin, m in re.findall(self.pat, string):
|
224
|
+
logger.info(f"Matching citation: {m}")
|
225
|
+
notations = self.convert_to_numeric_notations(m)
|
226
|
+
logger.info(f"Citing Notations: {notations}")
|
227
|
+
citation_number_seq = list(flatten(self.decode_expr(n) for n in notations))
|
228
|
+
logger.info(f"Citation Number Sequence: {citation_number_seq}")
|
229
|
+
dedup = self.deduplicate_citation(citation_number_seq)
|
230
|
+
logger.info(f"Deduplicated Citation Number Sequence: {dedup}")
|
231
|
+
string = string.replace(origin, self.unpack_cite_seq(dedup))
|
232
|
+
return string
|
233
|
+
|
234
|
+
def decode_expr(self, string: str) -> List[int]:
|
235
|
+
"""Decode citation expression into a list of integers."""
|
236
|
+
if self.abbr_sep in string:
|
237
|
+
start, end = string.split(self.abbr_sep)
|
238
|
+
return list(range(int(start), int(end) + 1))
|
239
|
+
return [int(string)]
|
240
|
+
|
241
|
+
def convert_to_numeric_notations(self, string: str) -> List[str]:
|
242
|
+
"""Convert citation string into numeric notations."""
|
243
|
+
return [s.strip() for s in string.split(self.sep)]
|
244
|
+
|
245
|
+
def deduplicate_citation(self, citation_seq: List[int]) -> List[int]:
|
246
|
+
"""Deduplicate citation sequence."""
|
247
|
+
chunk_seq = [a for a in self.article_chunks if a.cite_number in citation_seq]
|
248
|
+
deduped = unique(chunk_seq, lambda a: a.bibtex_cite_key)
|
249
|
+
return [a.cite_number for a in deduped]
|
250
|
+
|
251
|
+
def unpack_cite_seq(self, citation_seq: List[int]) -> str:
|
252
|
+
"""Unpack citation sequence into a string."""
|
253
|
+
chunk_seq = [a for a in self.article_chunks if a.cite_number in citation_seq]
|
254
|
+
return "".join(a.as_typst_cite() for a in chunk_seq)
|