fabricatio 0.3.15.dev5__cp312-cp312-win_amd64.whl → 0.4.5.dev0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +7 -8
- fabricatio/actions/__init__.py +69 -1
- fabricatio/capabilities/__init__.py +63 -1
- fabricatio/models/__init__.py +51 -0
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/toolboxes/__init__.py +2 -1
- fabricatio/toolboxes/arithmetic.py +1 -1
- fabricatio/toolboxes/fs.py +2 -2
- fabricatio/workflows/__init__.py +9 -0
- fabricatio-0.4.5.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.5.dev0.dist-info}/METADATA +58 -27
- fabricatio-0.4.5.dev0.dist-info/RECORD +15 -0
- fabricatio/actions/article.py +0 -415
- fabricatio/actions/article_rag.py +0 -407
- fabricatio/actions/fs.py +0 -25
- fabricatio/actions/output.py +0 -247
- fabricatio/actions/rag.py +0 -96
- fabricatio/actions/rules.py +0 -83
- fabricatio/capabilities/advanced_judge.py +0 -20
- fabricatio/capabilities/advanced_rag.py +0 -61
- fabricatio/capabilities/censor.py +0 -105
- fabricatio/capabilities/check.py +0 -212
- fabricatio/capabilities/correct.py +0 -228
- fabricatio/capabilities/extract.py +0 -74
- fabricatio/capabilities/propose.py +0 -65
- fabricatio/capabilities/rag.py +0 -264
- fabricatio/capabilities/rating.py +0 -404
- fabricatio/capabilities/review.py +0 -114
- fabricatio/capabilities/task.py +0 -113
- fabricatio/decorators.py +0 -253
- fabricatio/emitter.py +0 -177
- fabricatio/fs/__init__.py +0 -35
- fabricatio/fs/curd.py +0 -153
- fabricatio/fs/readers.py +0 -61
- fabricatio/journal.py +0 -12
- fabricatio/models/action.py +0 -263
- fabricatio/models/adv_kwargs_types.py +0 -63
- fabricatio/models/extra/__init__.py +0 -1
- fabricatio/models/extra/advanced_judge.py +0 -32
- fabricatio/models/extra/aricle_rag.py +0 -286
- fabricatio/models/extra/article_base.py +0 -488
- fabricatio/models/extra/article_essence.py +0 -98
- fabricatio/models/extra/article_main.py +0 -285
- fabricatio/models/extra/article_outline.py +0 -45
- fabricatio/models/extra/article_proposal.py +0 -52
- fabricatio/models/extra/patches.py +0 -20
- fabricatio/models/extra/problem.py +0 -165
- fabricatio/models/extra/rag.py +0 -98
- fabricatio/models/extra/rule.py +0 -51
- fabricatio/models/generic.py +0 -904
- fabricatio/models/kwargs_types.py +0 -121
- fabricatio/models/role.py +0 -156
- fabricatio/models/task.py +0 -310
- fabricatio/models/tool.py +0 -328
- fabricatio/models/usages.py +0 -791
- fabricatio/parser.py +0 -114
- fabricatio/rust.pyi +0 -846
- fabricatio/utils.py +0 -156
- fabricatio/workflows/articles.py +0 -24
- fabricatio/workflows/rag.py +0 -11
- fabricatio-0.3.15.dev5.data/scripts/tdown.exe +0 -0
- fabricatio-0.3.15.dev5.data/scripts/ttm.exe +0 -0
- fabricatio-0.3.15.dev5.dist-info/RECORD +0 -63
- {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.5.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.5.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/action.py
DELETED
@@ -1,263 +0,0 @@
|
|
1
|
-
"""Module that contains the classes for defining and executing task workflows.
|
2
|
-
|
3
|
-
This module provides the Action and WorkFlow classes for creating structured
|
4
|
-
task execution pipelines. Actions represent atomic operations, while WorkFlows
|
5
|
-
orchestrate sequences of actions with shared context and error handling.
|
6
|
-
|
7
|
-
Classes:
|
8
|
-
Action: Base class for defining executable actions with context management.
|
9
|
-
WorkFlow: Manages action sequences, context propagation, and task lifecycle.
|
10
|
-
"""
|
11
|
-
|
12
|
-
import traceback
|
13
|
-
from abc import abstractmethod
|
14
|
-
from asyncio import Queue, create_task
|
15
|
-
from typing import Any, ClassVar, Dict, Generator, Self, Sequence, Tuple, Type, Union, final
|
16
|
-
|
17
|
-
from fabricatio.journal import logger
|
18
|
-
from fabricatio.models.generic import WithBriefing
|
19
|
-
from fabricatio.models.task import Task
|
20
|
-
from fabricatio.utils import override_kwargs
|
21
|
-
from pydantic import Field, PrivateAttr
|
22
|
-
|
23
|
-
# Default context key for a workflow's final result (used as `WorkFlow.task_output_key`
# and as the default target of `Action.to_task_output`).
OUTPUT_KEY = "task_output"

# Default context key under which the served task is stored (used as `WorkFlow.task_input_key`).
INPUT_KEY = "task_input"
|
26
|
-
|
27
|
-
|
28
|
-
class Action(WithBriefing):
    """Class that represents an action to be executed in a workflow.

    Actions are the atomic units of work in a workflow. Each action performs
    a specific operation and can modify the shared context data.
    """

    ctx_override: ClassVar[bool] = False
    """Whether to override the instance attr by the context variable."""

    name: str = Field(default="")
    """The name of the action."""

    description: str = Field(default="")
    """The description of the action."""

    personality: str = Field(default="")
    """The personality traits or context for the action executor."""

    output_key: str = Field(default="")
    """The key used to store this action's output in the context dictionary."""

    @final
    def model_post_init(self, __context: Any) -> None:
        """Fill in `name` and `description` from the concrete class when they are empty.

        Args:
            __context: Initialization context; not used here.
        """
        concrete = type(self)
        self.name = self.name or concrete.__name__
        # Falls back to the class docstring, then to an empty string.
        self.description = self.description or concrete.__doc__ or ""

    @abstractmethod
    async def _execute(self, *_: Any, **cxt) -> Any:
        """Run the action's own logic; subclasses must provide this.

        Args:
            **cxt: The shared context, unpacked into keyword arguments.

        Returns:
            The value `act` stores under `output_key` when that key is set.
        """
        ...

    @final
    async def act(self, cxt: Dict[str, Any]) -> Dict[str, Any]:
        """Run `_execute` against the context and record its result.

        Args:
            cxt (Dict[str, Any]): Shared context dictionary; mutated in place.

        Returns:
            The same dictionary, with the result stored under `output_key` if one is set.
        """
        outcome = await self._execute(**cxt)

        # Without an output key the result is discarded and the context is left untouched.
        if not self.output_key:
            return cxt

        logger.debug(f"Setting output: {self.output_key}")
        cxt[self.output_key] = outcome
        return cxt

    @property
    def briefing(self) -> str:
        """Build the briefing text, prefixed by the personality section when one is set.

        Returns:
            The inherited briefing under an action heading, optionally preceded by the personality.
        """
        action_section = f"# The action you are going to perform: \n{super().briefing}"
        if not self.personality:
            return action_section
        return f"## Your personality: \n{self.personality}\n{action_section}"

    def to_task_output(self, to: Union[str, "WorkFlow"] = OUTPUT_KEY) -> Self:
        """Point this action's output at a task output key and return the action.

        Args:
            to: Either a key name, or a workflow whose `task_output_key` is used.

        Returns:
            This action, for chaining.
        """
        if isinstance(to, WorkFlow):
            self.output_key = to.task_output_key
        else:
            self.output_key = to
        return self
|
105
|
-
|
106
|
-
|
107
|
-
class WorkFlow(WithBriefing):
|
108
|
-
"""Manages sequences of actions to fulfill tasks.
|
109
|
-
|
110
|
-
Handles context propagation between actions, error handling, and task lifecycle
|
111
|
-
events like cancellation and completion.
|
112
|
-
"""
|
113
|
-
|
114
|
-
name: str = "WorkFlow"
|
115
|
-
"""The name of the workflow, which is used to identify and describe the workflow."""
|
116
|
-
description: str = ""
|
117
|
-
"""The description of the workflow, which describes the workflow's purpose and requirements."""
|
118
|
-
|
119
|
-
_context: Queue[Dict[str, Any]] = PrivateAttr(default_factory=lambda: Queue(maxsize=1))
|
120
|
-
"""Queue for storing the workflow execution context."""
|
121
|
-
|
122
|
-
_instances: Tuple[Action, ...] = PrivateAttr(default_factory=tuple)
|
123
|
-
"""Instantiated action objects to be executed in this workflow."""
|
124
|
-
|
125
|
-
steps: Sequence[Union[Type[Action], Action]] = Field(frozen=True)
|
126
|
-
"""The sequence of actions to be executed, can be action classes or instances."""
|
127
|
-
|
128
|
-
task_input_key: ClassVar[str] = INPUT_KEY
|
129
|
-
"""Key used to store the input task in the context dictionary."""
|
130
|
-
|
131
|
-
task_output_key: ClassVar[str] = OUTPUT_KEY
|
132
|
-
"""Key used to extract the final result from the context dictionary."""
|
133
|
-
|
134
|
-
extra_init_context: Dict[str, Any] = Field(default_factory=dict, frozen=True)
|
135
|
-
"""Additional initial context values to be included at workflow start."""
|
136
|
-
|
137
|
-
def model_post_init(self, __context: Any) -> None:
|
138
|
-
"""Initialize the workflow by instantiating any action classes.
|
139
|
-
|
140
|
-
Args:
|
141
|
-
__context: The context to be used for initialization.
|
142
|
-
|
143
|
-
"""
|
144
|
-
self.name = self.name or self.__class__.__name__
|
145
|
-
# Convert any action classes to instances
|
146
|
-
self._instances = tuple(step if isinstance(step, Action) else step() for step in self.steps)
|
147
|
-
|
148
|
-
def iter_actions(self) -> Generator[Action, None, None]:
|
149
|
-
"""Iterate over action instances."""
|
150
|
-
yield from self._instances
|
151
|
-
|
152
|
-
def inject_personality(self, personality: str) -> Self:
|
153
|
-
"""Set personality for actions without existing personality.
|
154
|
-
|
155
|
-
Args:
|
156
|
-
personality (str): Shared personality context
|
157
|
-
|
158
|
-
Returns:
|
159
|
-
Workflow instance with updated actions
|
160
|
-
"""
|
161
|
-
for action in filter(lambda a: not a.personality, self._instances):
|
162
|
-
action.personality = personality
|
163
|
-
return self
|
164
|
-
|
165
|
-
def override_action_variable(self, action: Action, ctx: Dict[str, Any]) -> Self:
|
166
|
-
"""Override action variable with context values."""
|
167
|
-
if action.ctx_override:
|
168
|
-
for k, v in ctx.items():
|
169
|
-
if hasattr(action, k):
|
170
|
-
setattr(action, k, v)
|
171
|
-
|
172
|
-
return self
|
173
|
-
|
174
|
-
async def serve(self, task: Task) -> None:
|
175
|
-
"""Execute workflow to complete given task.
|
176
|
-
|
177
|
-
Args:
|
178
|
-
task (Task): Task instance to be processed.
|
179
|
-
|
180
|
-
Steps:
|
181
|
-
1. Initialize context with task instance and extra data
|
182
|
-
2. Execute each action sequentially
|
183
|
-
3. Handle task cancellation and exceptions
|
184
|
-
4. Extract final result from context
|
185
|
-
"""
|
186
|
-
logger.info(f"Start execute workflow: {self.name}")
|
187
|
-
|
188
|
-
await task.start()
|
189
|
-
await self._init_context(task)
|
190
|
-
|
191
|
-
current_action = None
|
192
|
-
try:
|
193
|
-
# Process each action in sequence
|
194
|
-
for i, step in enumerate(self._instances):
|
195
|
-
logger.info(f"Executing step [{i}] >> {(current_action := step.name)}")
|
196
|
-
|
197
|
-
# Get current context and execute action
|
198
|
-
context = await self._context.get()
|
199
|
-
|
200
|
-
self.override_action_variable(step, context)
|
201
|
-
act_task = create_task(step.act(context))
|
202
|
-
# Handle task cancellation
|
203
|
-
if task.is_cancelled():
|
204
|
-
logger.warning(f"Workflow cancelled by task: {task.name}")
|
205
|
-
act_task.cancel(f"Cancelled by task: {task.name}")
|
206
|
-
break
|
207
|
-
|
208
|
-
# Update context with modified values
|
209
|
-
modified_ctx = await act_task
|
210
|
-
logger.success(f"Step [{i}] `{current_action}` execution finished.")
|
211
|
-
if step.output_key:
|
212
|
-
logger.success(f"Setting action `{current_action}` output to `{step.output_key}`")
|
213
|
-
await self._context.put(modified_ctx)
|
214
|
-
|
215
|
-
logger.success(f"Workflow `{self.name}` execution finished.")
|
216
|
-
|
217
|
-
# Get final context and extract result
|
218
|
-
final_ctx = await self._context.get()
|
219
|
-
result = final_ctx.get(self.task_output_key)
|
220
|
-
|
221
|
-
if self.task_output_key not in final_ctx:
|
222
|
-
logger.warning(
|
223
|
-
f"Task output key: `{self.task_output_key}` not found in the context, None will be returned. "
|
224
|
-
f"You can check if `Action.output_key` is set the same as `WorkFlow.task_output_key`."
|
225
|
-
)
|
226
|
-
|
227
|
-
await task.finish(result)
|
228
|
-
|
229
|
-
except Exception as e: # noqa: BLE001
|
230
|
-
logger.critical(f"Error during task: {current_action} execution: {e}")
|
231
|
-
logger.critical(traceback.format_exc())
|
232
|
-
await task.fail()
|
233
|
-
|
234
|
-
async def _init_context[T](self, task: Task[T]) -> None:
|
235
|
-
"""Initialize workflow execution context.
|
236
|
-
|
237
|
-
Args:
|
238
|
-
task (Task[T]): Task being processed
|
239
|
-
|
240
|
-
Context includes:
|
241
|
-
- Task instance stored under task_input_key
|
242
|
-
- Any extra_init_context values
|
243
|
-
"""
|
244
|
-
logger.debug(f"Initializing context for workflow: {self.name}")
|
245
|
-
ctx = override_kwargs(self.extra_init_context, **task.extra_init_context)
|
246
|
-
if self.task_input_key in ctx:
|
247
|
-
raise ValueError(
|
248
|
-
f"Task input key: `{self.task_input_key}`, which is reserved, is already set in the init context"
|
249
|
-
)
|
250
|
-
|
251
|
-
await self._context.put({self.task_input_key: task, **ctx})
|
252
|
-
|
253
|
-
def update_init_context(self, /, **kwargs) -> Self:
|
254
|
-
"""Update the initial context with additional key-value pairs.
|
255
|
-
|
256
|
-
Args:
|
257
|
-
**kwargs: Key-value pairs to add to the initial context.
|
258
|
-
|
259
|
-
Returns:
|
260
|
-
Self: The workflow instance for method chaining.
|
261
|
-
"""
|
262
|
-
self.extra_init_context.update(kwargs)
|
263
|
-
return self
|
@@ -1,63 +0,0 @@
|
|
1
|
-
"""A module containing kwargs types for content correction and checking operations."""
|
2
|
-
|
3
|
-
from importlib.util import find_spec
|
4
|
-
from typing import NotRequired, Optional, TypedDict
|
5
|
-
|
6
|
-
from fabricatio.models.extra.problem import Improvement
|
7
|
-
from fabricatio.models.extra.rule import RuleSet
|
8
|
-
from fabricatio.models.generic import SketchedAble
|
9
|
-
from fabricatio.models.kwargs_types import ReferencedKwargs
|
10
|
-
|
11
|
-
|
12
|
-
class CorrectKwargs[T: SketchedAble](ReferencedKwargs[T], total=False):
    """Arguments for content correction operations.

    Extends ``ReferencedKwargs[T]`` with the improvement that drives the correction.
    Declared with ``total=False``, so the key added here is optional.
    """

    # The `Improvement` to apply to the content being corrected.
    improvement: Improvement
|
20
|
-
|
21
|
-
|
22
|
-
class CheckKwargs(ReferencedKwargs[Improvement], total=False):
    """Arguments for content checking operations.

    Extends ``ReferencedKwargs[Improvement]`` with the rule set to check against.
    Declared with ``total=False``, so the key added here is optional.
    """

    # The `RuleSet` the content is checked against.
    ruleset: RuleSet
|
30
|
-
|
31
|
-
|
32
|
-
# The definitions below exist only when the optional `pymilvus` package can be found.
# NOTE(review): both classes are assumed to sit inside this guard — confirm against the original layout.
if find_spec("pymilvus"):
    from pymilvus import CollectionSchema
    from pymilvus.milvus_client import IndexParams

    class CollectionConfigKwargs(TypedDict, total=False):
        """Configuration parameters for a vector collection.

        These arguments are typically used when configuring connections to vector databases.
        Declared with ``total=False``, so every key is optional.
        """

        dimension: int | None
        primary_field_name: str
        id_type: str
        vector_field_name: str
        metric_type: str
        timeout: float | None
        schema: CollectionSchema | None
        index_params: IndexParams | None

    class FetchKwargs(TypedDict):
        """Arguments for fetching data from vector collections.

        Controls how data is retrieved from vector databases, including filtering
        and result limiting parameters. Every key is marked ``NotRequired``.
        """

        collection_name: NotRequired[str | None]
        similarity_threshold: NotRequired[float]
        result_per_query: NotRequired[int]
        tei_endpoint: NotRequired[Optional[str]]
        reranker_threshold: NotRequired[float]
        filter_expr: NotRequired[str]
|
@@ -1 +0,0 @@
|
|
1
|
-
"""A module contains extra models for fabricatio."""
|
@@ -1,32 +0,0 @@
|
|
1
|
-
"""Module containing the JudgeMent class for holding judgment results."""
|
2
|
-
|
3
|
-
from typing import List
|
4
|
-
|
5
|
-
from fabricatio.models.generic import SketchedAble
|
6
|
-
|
7
|
-
|
8
|
-
class JudgeMent(SketchedAble):
    """Represents a judgment result containing supporting/denying evidence and final verdict.

    The class stores both affirmative and denies evidence, truth and reasons lists along with the final boolean judgment.
    """

    # NOTE(review): the string literals after each field are presumably picked up as field
    # descriptions by the base model — confirm before rewording them.
    issue_to_judge: str
    """The issue to be judged, including the original question and context"""

    deny_evidence: List[str]
    """List of clues supporting the denial."""

    affirm_evidence: List[str]
    """List of clues supporting the affirmation."""

    final_judgement: bool
    """The final judgment made according to all extracted clues. true for the `issue_to_judge` is correct and false for incorrect."""

    def __bool__(self) -> bool:
        """Make the instance truthy exactly when `final_judgement` is true.

        Returns:
            bool: The stored `final_judgement` value.
        """
        return self.final_judgement
|
@@ -1,286 +0,0 @@
|
|
1
|
-
"""A Module containing the article rag models."""
|
2
|
-
|
3
|
-
import re
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from itertools import groupby
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import ClassVar, Dict, List, Optional, Self, Unpack
|
8
|
-
|
9
|
-
from fabricatio.fs import safe_text_read
|
10
|
-
from fabricatio.journal import logger
|
11
|
-
from fabricatio.models.extra.rag import MilvusDataBase
|
12
|
-
from fabricatio.models.generic import AsPrompt
|
13
|
-
from fabricatio.models.kwargs_types import ChunkKwargs
|
14
|
-
from fabricatio.rust import BibManager, blake3_hash, split_into_chunks
|
15
|
-
from fabricatio.utils import ok, wrapp_in_block
|
16
|
-
from more_itertools.more import first
|
17
|
-
from more_itertools.recipes import flatten, unique
|
18
|
-
from pydantic import Field
|
19
|
-
|
20
|
-
|
21
|
-
class ArticleChunk(MilvusDataBase):
|
22
|
-
"""The chunk of an article."""
|
23
|
-
|
24
|
-
etc_word: ClassVar[str] = "等"
|
25
|
-
and_word: ClassVar[str] = "与"
|
26
|
-
_cite_number: Optional[int] = None
|
27
|
-
|
28
|
-
head_split: ClassVar[List[str]] = [
|
29
|
-
"引 言",
|
30
|
-
"引言",
|
31
|
-
"绪 论",
|
32
|
-
"绪论",
|
33
|
-
"前言",
|
34
|
-
"INTRODUCTION",
|
35
|
-
"Introduction",
|
36
|
-
]
|
37
|
-
tail_split: ClassVar[List[str]] = [
|
38
|
-
"参 考 文 献",
|
39
|
-
"参 考 文 献",
|
40
|
-
"参考文献",
|
41
|
-
"REFERENCES",
|
42
|
-
"References",
|
43
|
-
"Bibliography",
|
44
|
-
"Reference",
|
45
|
-
]
|
46
|
-
chunk: str
|
47
|
-
"""The segment of the article"""
|
48
|
-
year: int
|
49
|
-
"""The year of the article"""
|
50
|
-
authors: List[str] = Field(default_factory=list)
|
51
|
-
"""The authors of the article"""
|
52
|
-
article_title: str
|
53
|
-
"""The title of the article"""
|
54
|
-
bibtex_cite_key: str
|
55
|
-
"""The bibtex cite key of the article"""
|
56
|
-
|
57
|
-
@property
|
58
|
-
def reference_header(self) -> str:
|
59
|
-
"""Get the reference header."""
|
60
|
-
return f"[[{ok(self._cite_number, 'You need to update cite number first.')}]] reference `{self.article_title}` from {self.as_auther_seq()}"
|
61
|
-
|
62
|
-
@property
|
63
|
-
def cite_number(self) -> int:
|
64
|
-
"""Get the cite number."""
|
65
|
-
return ok(self._cite_number, "cite number not set")
|
66
|
-
|
67
|
-
def _prepare_vectorization_inner(self) -> str:
|
68
|
-
return self.chunk
|
69
|
-
|
70
|
-
@classmethod
|
71
|
-
def from_file[P: str | Path](
|
72
|
-
cls, path: P | List[P], bib_mgr: BibManager, **kwargs: Unpack[ChunkKwargs]
|
73
|
-
) -> List[Self]:
|
74
|
-
"""Load the article chunks from the file."""
|
75
|
-
if isinstance(path, list):
|
76
|
-
result = list(flatten(cls._from_file_inner(p, bib_mgr, **kwargs) for p in path))
|
77
|
-
logger.debug(f"Number of chunks created from list of files: {len(result)}")
|
78
|
-
return result
|
79
|
-
|
80
|
-
return cls._from_file_inner(path, bib_mgr, **kwargs)
|
81
|
-
|
82
|
-
@classmethod
|
83
|
-
def _from_file_inner(cls, path: str | Path, bib_mgr: BibManager, **kwargs: Unpack[ChunkKwargs]) -> List[Self]:
|
84
|
-
path = Path(path)
|
85
|
-
|
86
|
-
title_seg = path.stem.split(" - ").pop()
|
87
|
-
|
88
|
-
key = (
|
89
|
-
bib_mgr.get_cite_key_by_title(title_seg)
|
90
|
-
or bib_mgr.get_cite_key_by_title_fuzzy(title_seg)
|
91
|
-
or bib_mgr.get_cite_key_fuzzy(path.stem)
|
92
|
-
)
|
93
|
-
if key is None:
|
94
|
-
logger.warning(f"no cite key found for {path.as_posix()}, skip.")
|
95
|
-
return []
|
96
|
-
authors = ok(bib_mgr.get_author_by_key(key), f"no author found for {key}")
|
97
|
-
year = ok(bib_mgr.get_year_by_key(key), f"no year found for {key}")
|
98
|
-
article_title = ok(bib_mgr.get_title_by_key(key), f"no title found for {key}")
|
99
|
-
|
100
|
-
result = [
|
101
|
-
cls(chunk=c, year=year, authors=authors, article_title=article_title, bibtex_cite_key=key)
|
102
|
-
for c in split_into_chunks(cls.purge_numeric_citation(cls.strip(safe_text_read(path))), **kwargs)
|
103
|
-
]
|
104
|
-
|
105
|
-
logger.debug(f"Number of chunks created from file {path.as_posix()}: {len(result)}")
|
106
|
-
return result
|
107
|
-
|
108
|
-
@classmethod
|
109
|
-
def strip(cls, string: str) -> str:
|
110
|
-
"""Strip the head and tail of the string."""
|
111
|
-
logger.debug(f"String length before strip: {(original := len(string))}")
|
112
|
-
for split in (s for s in cls.head_split if s in string):
|
113
|
-
logger.debug(f"Strip head using {split}")
|
114
|
-
parts = string.split(split)
|
115
|
-
string = split.join(parts[1:]) if len(parts) > 1 else parts[0]
|
116
|
-
break
|
117
|
-
logger.debug(
|
118
|
-
f"String length after head strip: {(stripped_len := len(string))}, decreased by {(d := original - stripped_len)}"
|
119
|
-
)
|
120
|
-
if not d:
|
121
|
-
logger.warning("No decrease at head strip, which is might be abnormal.")
|
122
|
-
for split in (s for s in cls.tail_split if s in string):
|
123
|
-
logger.debug(f"Strip tail using {split}")
|
124
|
-
parts = string.split(split)
|
125
|
-
string = split.join(parts[:-1]) if len(parts) > 1 else parts[0]
|
126
|
-
break
|
127
|
-
logger.debug(f"String length after tail strip: {len(string)}, decreased by {(d := stripped_len - len(string))}")
|
128
|
-
if not d:
|
129
|
-
logger.warning("No decrease at tail strip, which is might be abnormal.")
|
130
|
-
|
131
|
-
return string
|
132
|
-
|
133
|
-
def as_typst_cite(self) -> str:
|
134
|
-
"""As typst cite."""
|
135
|
-
return f"#cite(<{self.bibtex_cite_key}>)"
|
136
|
-
|
137
|
-
@staticmethod
|
138
|
-
def purge_numeric_citation(string: str) -> str:
|
139
|
-
"""Purge numeric citation."""
|
140
|
-
import re
|
141
|
-
|
142
|
-
return re.sub(r"\[[\d\s,\\~–-]+]", "", string)
|
143
|
-
|
144
|
-
@property
|
145
|
-
def auther_lastnames(self) -> List[str]:
|
146
|
-
"""Get the last name of the authors."""
|
147
|
-
return [n.split()[-1] for n in self.authors]
|
148
|
-
|
149
|
-
def as_auther_seq(self) -> str:
|
150
|
-
"""Get the auther sequence."""
|
151
|
-
match len(self.authors):
|
152
|
-
case 0:
|
153
|
-
raise ValueError("No authors found")
|
154
|
-
case 1:
|
155
|
-
return f"({self.auther_lastnames[0]},{self.year}){self.as_typst_cite()}"
|
156
|
-
case 2:
|
157
|
-
return f"({self.auther_lastnames[0]}{self.and_word}{self.auther_lastnames[1]},{self.year}){self.as_typst_cite()}"
|
158
|
-
case 3:
|
159
|
-
return f"({self.auther_lastnames[0]},{self.auther_lastnames[1]}{self.and_word}{self.auther_lastnames[2]},{self.year}){self.as_typst_cite()}"
|
160
|
-
case _:
|
161
|
-
return f"({self.auther_lastnames[0]},{self.auther_lastnames[1]}{self.and_word}{self.auther_lastnames[2]}{self.etc_word},{self.year}){self.as_typst_cite()}"
|
162
|
-
|
163
|
-
def update_cite_number(self, cite_number: int) -> Self:
|
164
|
-
"""Update the cite number."""
|
165
|
-
self._cite_number = cite_number
|
166
|
-
return self
|
167
|
-
|
168
|
-
|
169
|
-
@dataclass
class CitationManager(AsPrompt):
    """Citation manager."""

    article_chunks: List[ArticleChunk] = field(default_factory=list)
    """Article chunks."""

    pat: str = r"(\[\[([\d\s,-]*)]])"
    """Regex pattern to match citations."""
    sep: str = ","
    """Separator for citation numbers."""
    abbr_sep: str = "-"
    """Separator for abbreviated citation numbers."""

    def update_chunks(
        self, article_chunks: List[ArticleChunk], set_cite_number: bool = True, dedup: bool = True
    ) -> Self:
        """Replace the managed chunks with `article_chunks`.

        Args:
            article_chunks: The new chunks.
            set_cite_number: Whether to renumber all chunks afterwards.
            dedup: Whether to drop chunks with identical text.

        Returns:
            This manager, for chaining.
        """
        # Copy before clearing, so passing `self.article_chunks` itself does not wipe the input.
        incoming = list(article_chunks)
        self.article_chunks.clear()
        return self.add_chunks(incoming, set_cite_number=set_cite_number, dedup=dedup)

    def empty(self) -> Self:
        """Empty the article chunks."""
        self.article_chunks.clear()
        return self

    def add_chunks(self, article_chunks: List[ArticleChunk], set_cite_number: bool = True, dedup: bool = True) -> Self:
        """Append `article_chunks` to the managed chunks.

        Args:
            article_chunks: The chunks to add.
            set_cite_number: Whether to renumber all chunks afterwards.
            dedup: Whether to drop chunks with identical text.

        Returns:
            This manager, for chaining.
        """
        self.article_chunks.extend(article_chunks)
        if dedup:
            # NOTE: `more_itertools.unique` is documented to yield in sorted-by-key order, so chunks
            # of one article are not guaranteed to stay adjacent after this step.
            self.article_chunks = list(unique(self.article_chunks, lambda c: blake3_hash(c.chunk.encode())))
        if set_cite_number:
            self.set_cite_number_all()
        return self

    def set_cite_number_all(self) -> Self:
        """Give every chunk a cite number shared by all chunks with the same cite key.

        Numbers start at 0 and follow the order in which cite keys first appear.
        """
        number_mapping: Dict[str, int] = {}
        for a in self.article_chunks:
            number_mapping.setdefault(a.bibtex_cite_key, len(number_mapping))

        for a in self.article_chunks:
            a.update_cite_number(number_mapping[a.bibtex_cite_key])
        return self

    def _as_prompt_inner(self) -> Dict[str, str]:
        """Render one reference block per article.

        Returns:
            A single-entry mapping from "References" to the joined blocks.
        """
        # Group by cite key regardless of position: `itertools.groupby` only merges adjacent
        # items, which would split one article into several blocks when its chunks are scattered.
        grouped: Dict[str, List[ArticleChunk]] = {}
        for a in self.article_chunks:
            grouped.setdefault(a.bibtex_cite_key, []).append(a)

        seg = []
        for k, g in grouped.items():
            logger.debug(f"Group [{k}]: {len(g)}")
            seg.append(wrapp_in_block("\n\n".join(a.chunk for a in g), g[0].reference_header))
        return {"References": "\n".join(seg)}

    def _resolve_citation(self, expr: str) -> List[int]:
        """Turn the inside of one `[[...]]` marker into deduplicated cite numbers."""
        logger.info(f"Matching citation: {expr}")
        notations = self.convert_to_numeric_notations(expr)
        logger.info(f"Citing Notations: {notations}")
        citation_number_seq = list(flatten(self.decode_expr(n) for n in notations))
        logger.info(f"Citation Number Sequence: {citation_number_seq}")
        dedup = self.deduplicate_citation(citation_number_seq)
        logger.info(f"Deduplicated Citation Number Sequence: {dedup}")
        return dedup

    def apply(self, string: str) -> str:
        """Replace every `[[...]]` citation marker in `string` with typst cites.

        Args:
            string: Text containing citation markers matching `pat`.

        Returns:
            The text with each marker replaced; a marker that resolves to nothing is removed.
        """
        for origin, m in re.findall(self.pat, string):
            string = string.replace(origin, self.unpack_cite_seq(self._resolve_citation(m)))
        return string

    def citation_count(self, string: str) -> int:
        """Sum, over all citation markers in `string`, the number of distinct articles each cites."""
        return sum(len(self._resolve_citation(m)) for _, m in re.findall(self.pat, string))

    def citation_coverage(self, string: str) -> float:
        """Return the citation count divided by the number of managed chunks.

        Returns:
            The ratio, or 0.0 when there are no chunks (instead of dividing by zero).
        """
        if not self.article_chunks:
            return 0.0
        return self.citation_count(string) / len(self.article_chunks)

    def decode_expr(self, string: str) -> List[int]:
        """Decode one notation into cite numbers.

        Args:
            string: A single number ("3") or an inclusive range ("2-5").

        Returns:
            The listed number, or every number in the range.

        Raises:
            ValueError: If the notation is not a number or a two-sided range.
        """
        if self.abbr_sep in string:
            start, end = string.split(self.abbr_sep)
            return list(range(int(start), int(end) + 1))
        return [int(string)]

    def convert_to_numeric_notations(self, string: str) -> List[str]:
        """Split a citation body on `sep` into trimmed notations.

        Empty pieces (from `[[]]`, a trailing separator, or blanks) are skipped so they
        never reach `decode_expr`, where `int("")` would raise.
        """
        return [piece for piece in (s.strip() for s in string.split(self.sep)) if piece]

    def deduplicate_citation(self, citation_seq: List[int]) -> List[int]:
        """Reduce cite numbers to one entry per cited article among the managed chunks."""
        wanted = set(citation_seq)
        chunk_seq = [a for a in self.article_chunks if a.cite_number in wanted]
        deduped = unique(chunk_seq, lambda a: a.bibtex_cite_key)
        return [a.cite_number for a in deduped]

    def unpack_cite_seq(self, citation_seq: List[int]) -> str:
        """Concatenate the typst cites of the articles matching the given cite numbers."""
        wanted = set(citation_seq)
        # Keyed by cite key, so each article is cited once.
        chunk_seq = {a.bibtex_cite_key: a for a in self.article_chunks if a.cite_number in wanted}
        return "".join(a.as_typst_cite() for a in chunk_seq.values())

    def as_milvus_filter_expr(self, blacklist: bool = True) -> str:
        """Build a filter expression over `bibtex_cite_key` from the managed chunks.

        Args:
            blacklist: When true, exclude the managed cite keys (`!=` joined by `and`);
                otherwise select only them (`==` joined by `or`).

        Returns:
            The filter expression, or an empty string when there are no chunks.
        """
        if blacklist:
            return " and ".join(f'bibtex_cite_key != "{a.bibtex_cite_key}"' for a in self.article_chunks)
        return " or ".join(f'bibtex_cite_key == "{a.bibtex_cite_key}"' for a in self.article_chunks)
|