fabricatio 0.2.9.dev4__cp312-cp312-win_amd64.whl → 0.2.10.dev1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,28 +3,22 @@
3
3
  try:
4
4
  from pymilvus import MilvusClient
5
5
  except ImportError as e:
6
- raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?") from e
6
+ raise RuntimeError(
7
+ "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
8
+ ) from e
7
9
  from functools import lru_cache
8
10
  from operator import itemgetter
9
- from os import PathLike
10
- from pathlib import Path
11
- from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
11
+ from typing import List, Optional, Self, Type, Unpack
12
12
 
13
13
  from more_itertools.recipes import flatten, unique
14
14
  from pydantic import Field, PrivateAttr
15
15
 
16
16
  from fabricatio.config import configs
17
17
  from fabricatio.journal import logger
18
- from fabricatio.models.kwargs_types import (
19
- ChooseKwargs,
20
- CollectionConfigKwargs,
21
- EmbeddingKwargs,
22
- FetchKwargs,
23
- LLMKwargs,
24
- RetrievalKwargs,
25
- )
18
+ from fabricatio.models.adv_kwargs_types import CollectionConfigKwargs, FetchKwargs
19
+ from fabricatio.models.extra.rag import MilvusDataBase
20
+ from fabricatio.models.kwargs_types import ChooseKwargs
26
21
  from fabricatio.models.usages import EmbeddingUsage
27
- from fabricatio.models.utils import MilvusData
28
22
  from fabricatio.rust_instances import TEMPLATE_MANAGER
29
23
  from fabricatio.utils import ok
30
24
 
@@ -78,40 +72,6 @@ class RAG(EmbeddingUsage):
78
72
  raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
79
73
  return self
80
74
 
81
- @overload
82
- async def pack(
83
- self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
84
- ) -> List[MilvusData]: ...
85
- @overload
86
- async def pack(
87
- self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
88
- ) -> MilvusData: ...
89
-
90
- async def pack(
91
- self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
92
- ) -> List[MilvusData] | MilvusData:
93
- """Asynchronously generates MilvusData objects for the given input text.
94
-
95
- Args:
96
- input_text (List[str] | str): A string or list of strings to generate embeddings for.
97
- subject (Optional[str]): The subject of the input text. Defaults to None.
98
- **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
99
-
100
- Returns:
101
- List[MilvusData] | MilvusData: The generated MilvusData objects.
102
- """
103
- if isinstance(input_text, str):
104
- return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
105
- vecs = await self.vectorize(input_text, **kwargs)
106
- return [
107
- MilvusData(
108
- vector=vec,
109
- text=text,
110
- subject=subject,
111
- )
112
- for text, vec in zip(input_text, vecs, strict=True)
113
- ]
114
-
115
75
  def view(
116
76
  self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
117
77
  ) -> Self:
@@ -152,29 +112,27 @@ class RAG(EmbeddingUsage):
152
112
  Returns:
153
113
  str: The name of the collection being viewed.
154
114
  """
155
- if self.target_collection is None:
156
- raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
157
- return self.target_collection
115
+ return ok(self.target_collection, "No collection is being viewed. Have you called `self.view()`?")
158
116
 
159
- def add_document[D: Union[Dict[str, Any], MilvusData]](
160
- self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
117
+ async def add_document[D: MilvusDataBase](
118
+ self, data: List[D] | D, collection_name: Optional[str] = None, flush: bool = False
161
119
  ) -> Self:
162
120
  """Adds a document to the specified collection.
163
121
 
164
122
  Args:
165
- data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
123
+ data (Union[Dict[str, Any], MilvusDataBase] | List[Union[Dict[str, Any], MilvusDataBase]]): The data to be added to the collection.
166
124
  collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
167
125
  flush (bool): Whether to flush the collection after insertion.
168
126
 
169
127
  Returns:
170
128
  Self: The current instance, allowing for method chaining.
171
129
  """
172
- if isinstance(data, MilvusData):
173
- prepared_data = data.prepare_insertion()
174
- elif isinstance(data, list):
175
- prepared_data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
176
- else:
177
- raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
130
+ if isinstance(data, MilvusDataBase):
131
+ data = [data]
132
+
133
+ data_vec = await self.vectorize([d.prepare_vectorization() for d in data])
134
+ prepared_data = [d.prepare_insertion(vec) for d, vec in zip(data, data_vec, strict=True)]
135
+
178
136
  c_name = collection_name or self.safe_target_collection
179
137
  self.check_client().client.insert(c_name, prepared_data)
180
138
 
@@ -183,84 +141,33 @@ class RAG(EmbeddingUsage):
183
141
  self.client.flush(c_name)
184
142
  return self
185
143
 
186
- async def consume_file(
187
- self,
188
- source: List[PathLike] | PathLike,
189
- reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
190
- collection_name: Optional[str] = None,
191
- ) -> Self:
192
- """Consume a file and add its content to the collection.
193
-
194
- Args:
195
- source (PathLike): The path to the file to be consumed.
196
- reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
197
- collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
198
-
199
- Returns:
200
- Self: The current instance, allowing for method chaining.
201
- """
202
- if not isinstance(source, list):
203
- source = [source]
204
- return await self.consume_string([reader(s) for s in source], collection_name)
205
-
206
- async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
207
- """Consume a string and add it to the collection.
208
-
209
- Args:
210
- text (List[str] | str): The text to be added to the collection.
211
- collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
212
-
213
- Returns:
214
- Self: The current instance, allowing for method chaining.
215
- """
216
- self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
217
- return self
218
-
219
- @overload
220
- async def afetch_document[V: (int, str, float, bytes)](
144
+ async def afetch_document[D: MilvusDataBase](
221
145
  self,
222
146
  vecs: List[List[float]],
223
- desired_fields: List[str],
147
+ document_model: Type[D],
224
148
  collection_name: Optional[str] = None,
225
149
  similarity_threshold: float = 0.37,
226
150
  result_per_query: int = 10,
227
- ) -> List[Dict[str, V]]: ...
228
-
229
- @overload
230
- async def afetch_document[V: (int, str, float, bytes)](
231
- self,
232
- vecs: List[List[float]],
233
- desired_fields: str,
234
- collection_name: Optional[str] = None,
235
- similarity_threshold: float = 0.37,
236
- result_per_query: int = 10,
237
- ) -> List[V]: ...
238
- async def afetch_document[V: (int, str, float, bytes)](
239
- self,
240
- vecs: List[List[float]],
241
- desired_fields: List[str] | str,
242
- collection_name: Optional[str] = None,
243
- similarity_threshold: float = 0.37,
244
- result_per_query: int = 10,
245
- ) -> List[Dict[str, Any]] | List[V]:
246
- """Fetch data from the collection.
151
+ ) -> List[D]:
152
+ """Asynchronously fetches documents from a Milvus database based on input vectors.
247
153
 
248
154
  Args:
249
- vecs (List[List[float]]): The vectors to search for.
250
- desired_fields (List[str] | str): The fields to retrieve.
251
- collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
252
- similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
253
- result_per_query (int): The number of results to return per query.
155
+ vecs (List[List[float]]): A list of vectors to search for in the database.
156
+ document_model (Type[D]): The model class used to convert fetched data into document objects.
157
+ collection_name (Optional[str]): The name of the collection to search within.
158
+ If None, the currently viewed collection is used.
159
+ similarity_threshold (float): The similarity threshold for vector search. Defaults to 0.37.
160
+ result_per_query (int): The maximum number of results to return per query. Defaults to 10.
254
161
 
255
162
  Returns:
256
- List[Dict[str, Any]] | List[Any]: The retrieved data.
163
+ List[D]: A list of document objects created from the fetched data.
257
164
  """
258
165
  # Step 1: Search for vectors
259
166
  search_results = self.check_client().client.search(
260
167
  collection_name or self.safe_target_collection,
261
168
  vecs,
262
169
  search_params={"radius": similarity_threshold},
263
- output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
170
+ output_fields=list(document_model.model_fields),
264
171
  limit=result_per_query,
265
172
  )
266
173
 
@@ -270,104 +177,42 @@ class RAG(EmbeddingUsage):
270
177
  # Step 3: Sort by distance (descending)
271
178
  sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
272
179
 
273
- logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
180
+ logger.debug(
181
+ f"Fetched {len(sorted_results)} document,searched similarities: {[t['distance'] for t in sorted_results]}"
182
+ )
274
183
  # Step 4: Extract the entities
275
184
  resp = [result["entity"] for result in sorted_results]
276
185
 
277
- if isinstance(desired_fields, list):
278
- return resp
279
- return [r.get(desired_fields) for r in resp] # extract the single field as list
186
+ return document_model.from_sequence(resp)
280
187
 
281
- async def aretrieve(
188
+ async def aretrieve[D: MilvusDataBase](
282
189
  self,
283
190
  query: List[str] | str,
191
+ document_model: Type[D],
284
192
  final_limit: int = 20,
285
193
  **kwargs: Unpack[FetchKwargs],
286
- ) -> List[str]:
194
+ ) -> List[D]:
287
195
  """Retrieve data from the collection.
288
196
 
289
197
  Args:
290
198
  query (List[str] | str): The query to be used for retrieval.
199
+ document_model (Type[D]): The model class used to convert retrieved data into document objects.
291
200
  final_limit (int): The final limit on the number of results to return.
292
201
  **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
293
202
 
294
203
  Returns:
295
- List[str]: A list of strings containing the retrieved data.
204
+ List[D]: A list of document objects created from the retrieved data.
296
205
  """
297
206
  if isinstance(query, str):
298
207
  query = [query]
299
- return cast(
300
- "List[str]",
208
+ return (
301
209
  await self.afetch_document(
302
210
  vecs=(await self.vectorize(query)),
303
- desired_fields="text",
211
+ document_model=document_model,
304
212
  **kwargs,
305
- ),
213
+ )
306
214
  )[:final_limit]
307
215
 
308
- async def aretrieve_compact(
309
- self,
310
- query: List[str] | str,
311
- **kwargs: Unpack[RetrievalKwargs],
312
- ) -> str:
313
- """Retrieve data from the collection and format it for display.
314
-
315
- Args:
316
- query (List[str] | str): The query to be used for retrieval.
317
- **kwargs (Unpack[RetrievalKwargs]): Additional keyword arguments for retrieval.
318
-
319
- Returns:
320
- str: A formatted string containing the retrieved data.
321
- """
322
- return TEMPLATE_MANAGER.render_template(
323
- configs.templates.retrieved_display_template, {"docs": (await self.aretrieve(query, **kwargs))}
324
- )
325
-
326
- async def aask_retrieved(
327
- self,
328
- question: str,
329
- query: Optional[List[str] | str] = None,
330
- collection_name: Optional[str] = None,
331
- extra_system_message: str = "",
332
- result_per_query: int = 10,
333
- final_limit: int = 20,
334
- similarity_threshold: float = 0.37,
335
- **kwargs: Unpack[LLMKwargs],
336
- ) -> str:
337
- """Asks a question by retrieving relevant documents based on the provided query.
338
-
339
- This method performs document retrieval using the given query, then asks the
340
- specified question using the retrieved documents as context.
341
-
342
- Args:
343
- question (str): The question to be asked.
344
- query (List[str] | str): The query or list of queries used for document retrieval.
345
- collection_name (Optional[str]): The name of the collection to retrieve documents from.
346
- If not provided, the currently viewed collection is used.
347
- extra_system_message (str): An additional system message to be included in the prompt.
348
- result_per_query (int): The number of results to return per query. Default is 10.
349
- final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
350
- similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
351
- **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
352
-
353
- Returns:
354
- str: A string response generated after asking with the context of retrieved documents.
355
- """
356
- rendered = await self.aretrieve_compact(
357
- query or question,
358
- final_limit=final_limit,
359
- collection_name=collection_name,
360
- result_per_query=result_per_query,
361
- similarity_threshold=similarity_threshold,
362
- )
363
-
364
- logger.debug(f"Retrieved Documents: \n{rendered}")
365
- return await self.aask(
366
- question,
367
- f"{rendered}\n\n{extra_system_message}",
368
- **kwargs,
369
- )
370
-
371
216
  async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> Optional[List[str]]:
372
217
  """Refines the given question using a template.
373
218
 
@@ -385,38 +230,3 @@ class RAG(EmbeddingUsage):
385
230
  ),
386
231
  **kwargs,
387
232
  )
388
-
389
- async def aask_refined(
390
- self,
391
- question: str,
392
- collection_name: Optional[str] = None,
393
- extra_system_message: str = "",
394
- result_per_query: int = 10,
395
- final_limit: int = 20,
396
- similarity_threshold: float = 0.37,
397
- **kwargs: Unpack[LLMKwargs],
398
- ) -> str:
399
- """Asks a question using a refined query based on the provided question.
400
-
401
- Args:
402
- question (str): The question to be asked.
403
- collection_name (Optional[str]): The name of the collection to retrieve documents from.
404
- extra_system_message (str): An additional system message to be included in the prompt.
405
- result_per_query (int): The number of results to return per query. Default is 10.
406
- final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
407
- similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
408
- **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
409
-
410
- Returns:
411
- str: A string response generated after asking with the refined question.
412
- """
413
- return await self.aask_retrieved(
414
- question,
415
- await self.arefined_query(question, **kwargs),
416
- collection_name=collection_name,
417
- extra_system_message=extra_system_message,
418
- result_per_query=result_per_query,
419
- final_limit=final_limit,
420
- similarity_threshold=similarity_threshold,
421
- **kwargs,
422
- )
@@ -0,0 +1,20 @@
1
+ """A module containing constants used throughout the library."""
2
+ from enum import StrEnum
3
+
4
+
5
+ class TaskStatus(StrEnum):
6
+ """An enumeration representing the status of a task.
7
+
8
+ Attributes:
9
+ Pending: The task is pending.
10
+ Running: The task is currently running.
11
+ Finished: The task has been successfully completed.
12
+ Failed: The task has failed.
13
+ Cancelled: The task has been cancelled.
14
+ """
15
+
16
+ Pending = "pending"
17
+ Running = "running"
18
+ Finished = "finished"
19
+ Failed = "failed"
20
+ Cancelled = "cancelled"
fabricatio/decorators.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  from asyncio import iscoroutinefunction
4
4
  from functools import wraps
5
+ from importlib.util import find_spec
5
6
  from inspect import signature
6
7
  from shutil import which
7
8
  from types import ModuleType
@@ -209,3 +210,25 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
209
210
  return result
210
211
 
211
212
  return _wrapper
213
+
214
+
215
+ def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
216
+ """Check if a package exists in the current environment.
217
+
218
+ Args:
219
+ package_name (str): The name of the package to check.
220
+ msg (str): The message to display if the package is not found.
221
+
222
+ Returns:
223
+ bool: True if the package exists, False otherwise.
224
+ """
225
+
226
+ def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
227
+ def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
228
+ if find_spec(package_name):
229
+ return func(*args, **kwargs)
230
+ raise RuntimeError(msg)
231
+
232
+ return _inner
233
+
234
+ return _wrapper
@@ -1,4 +1,8 @@
1
1
  """A module containing kwargs types for content correction and checking operations."""
2
+
3
+ from importlib.util import find_spec
4
+ from typing import NotRequired, TypedDict
5
+
2
6
  from fabricatio.models.extra.problem import Improvement
3
7
  from fabricatio.models.extra.rule import RuleSet
4
8
  from fabricatio.models.generic import SketchedAble
@@ -23,3 +27,34 @@ class CheckKwargs(ReferencedKwargs[Improvement], total=False):
23
27
  """
24
28
 
25
29
  ruleset: RuleSet
30
+
31
+
32
+ if find_spec("pymilvus"):
33
+ from pymilvus import CollectionSchema
34
+ from pymilvus.milvus_client import IndexParams
35
+
36
+ class CollectionConfigKwargs(TypedDict, total=False):
37
+ """Configuration parameters for a vector collection.
38
+
39
+ These arguments are typically used when configuring connections to vector databases.
40
+ """
41
+
42
+ dimension: int | None
43
+ primary_field_name: str
44
+ id_type: str
45
+ vector_field_name: str
46
+ metric_type: str
47
+ timeout: float | None
48
+ schema: CollectionSchema | None
49
+ index_params: IndexParams | None
50
+
51
+ class FetchKwargs(TypedDict):
52
+ """Arguments for fetching data from vector collections.
53
+
54
+ Controls how data is retrieved from vector databases, including filtering
55
+ and result limiting parameters.
56
+ """
57
+
58
+ collection_name: NotRequired[str | None]
59
+ similarity_threshold: NotRequired[float]
60
+ result_per_query: NotRequired[int]
@@ -3,7 +3,7 @@
3
3
  from typing import List, Self, Union
4
4
 
5
5
  from fabricatio.config import configs
6
- from fabricatio.models.utils import TaskStatus
6
+ from fabricatio.constants import TaskStatus
7
7
  from pydantic import BaseModel, ConfigDict, Field
8
8
 
9
9
  type EventLike = Union[str, List[str], "Event"]
@@ -77,23 +77,23 @@ class Event(BaseModel):
77
77
 
78
78
  def push_pending(self) -> Self:
79
79
  """Push a pending segment to the event."""
80
- return self.push(TaskStatus.Pending.value)
80
+ return self.push(TaskStatus.Pending)
81
81
 
82
82
  def push_running(self) -> Self:
83
83
  """Push a running segment to the event."""
84
- return self.push(TaskStatus.Running.value)
84
+ return self.push(TaskStatus.Running)
85
85
 
86
86
  def push_finished(self) -> Self:
87
87
  """Push a finished segment to the event."""
88
- return self.push(TaskStatus.Finished.value)
88
+ return self.push(TaskStatus.Finished)
89
89
 
90
90
  def push_failed(self) -> Self:
91
91
  """Push a failed segment to the event."""
92
- return self.push(TaskStatus.Failed.value)
92
+ return self.push(TaskStatus.Failed)
93
93
 
94
94
  def push_cancelled(self) -> Self:
95
95
  """Push a cancelled segment to the event."""
96
- return self.push(TaskStatus.Cancelled.value)
96
+ return self.push(TaskStatus.Cancelled)
97
97
 
98
98
  def pop(self) -> str:
99
99
  """Pop a segment from the event."""
@@ -12,7 +12,7 @@ class JudgeMent(SketchedAble):
12
12
  """
13
13
 
14
14
  issue_to_judge: str
15
- """The issue to be judged, true for affirmation, false for denial."""
15
+ """The issue to be judged, including the original question and context"""
16
16
 
17
17
  deny_evidence: List[str]
18
18
  """List of clues supporting the denial."""
@@ -21,7 +21,7 @@ class JudgeMent(SketchedAble):
21
21
  """List of clues supporting the affirmation."""
22
22
 
23
23
  final_judgement: bool
24
- """The final judgment made according to all extracted clues."""
24
+ """The final judgment made according to all extracted clues. true for the `issue_to_judge` is correct and false for incorrect."""
25
25
 
26
26
  def __bool__(self) -> bool:
27
27
  """Return the final judgment value.
@@ -0,0 +1,120 @@
1
+ """A Module containing the article rag models."""
2
+
3
+ from pathlib import Path
4
+ from typing import ClassVar, Dict, List, Self, Unpack
5
+
6
+ from fabricatio.fs import safe_text_read
7
+ from fabricatio.journal import logger
8
+ from fabricatio.models.extra.rag import MilvusDataBase
9
+ from fabricatio.models.generic import AsPrompt
10
+ from fabricatio.models.kwargs_types import ChunkKwargs
11
+ from fabricatio.rust import BibManager, split_into_chunks
12
+ from fabricatio.utils import ok, wrapp_in_block
13
+ from more_itertools.recipes import flatten
14
+ from pydantic import Field
15
+
16
+
17
+ class ArticleChunk(MilvusDataBase, AsPrompt):
18
+ """The chunk of an article."""
19
+
20
+ head_split: ClassVar[List[str]] = [
21
+ "引 言",
22
+ "引言",
23
+ "绪 论",
24
+ "绪论",
25
+ "前言",
26
+ "INTRODUCTION",
27
+ "Introduction",
28
+ ]
29
+ tail_split: ClassVar[List[str]] = [
30
+ "参 考 文 献",
31
+ "参 考 文 献",
32
+ "参考文献",
33
+ "REFERENCES",
34
+ "References",
35
+ "Bibliography",
36
+ "Reference",
37
+ ]
38
+ chunk: str
39
+ """The segment of the article"""
40
+ year: int
41
+ """The year of the article"""
42
+ authors: List[str] = Field(default_factory=list)
43
+ """The authors of the article"""
44
+ article_title: str
45
+ """The title of the article"""
46
+ bibtex_cite_key: str
47
+ """The bibtex cite key of the article"""
48
+
49
+ def _as_prompt_inner(self) -> Dict[str, str]:
50
+ return {
51
+ self.article_title: f"{wrapp_in_block(self.chunk, 'Referring Content')}\n"
52
+ f"Authors: {';'.join(self.authors)}\n"
53
+ f"Published Year: {self.year}\n"
54
+ f"Bibtex Key: {self.bibtex_cite_key}\n",
55
+ }
56
+
57
+ def _prepare_vectorization_inner(self) -> str:
58
+ return self.chunk
59
+
60
+ @classmethod
61
+ def from_file[P: str | Path](
62
+ cls, path: P | List[P], bib_mgr: BibManager, **kwargs: Unpack[ChunkKwargs]
63
+ ) -> List[Self]:
64
+ """Load the article chunks from the file."""
65
+ if isinstance(path, list):
66
+ result = list(flatten(cls._from_file_inner(p, bib_mgr, **kwargs) for p in path))
67
+ logger.debug(f"Number of chunks created from list of files: {len(result)}")
68
+ return result
69
+
70
+ return cls._from_file_inner(path, bib_mgr, **kwargs)
71
+
72
+ @classmethod
73
+ def _from_file_inner(cls, path: str | Path, bib_mgr: BibManager, **kwargs: Unpack[ChunkKwargs]) -> List[Self]:
74
+ path = Path(path)
75
+
76
+ title_seg = path.stem.split(" - ").pop()
77
+
78
+ key = (
79
+ bib_mgr.get_cite_key_by_title(title_seg)
80
+ or bib_mgr.get_cite_key_by_title_fuzzy(title_seg)
81
+ or bib_mgr.get_cite_key_fuzzy(path.stem)
82
+ )
83
+ if key is None:
84
+ logger.warning(f"no cite key found for {path.as_posix()}, skip.")
85
+ return []
86
+ authors = ok(bib_mgr.get_author_by_key(key), f"no author found for {key}")
87
+ year = ok(bib_mgr.get_year_by_key(key), f"no year found for {key}")
88
+ article_title = ok(bib_mgr.get_title_by_key(key), f"no title found for {key}")
89
+
90
+ result = [
91
+ cls(chunk=c, year=year, authors=authors, article_title=article_title, bibtex_cite_key=key)
92
+ for c in split_into_chunks(cls.strip(safe_text_read(path)), **kwargs)
93
+ ]
94
+ logger.debug(f"Number of chunks created from file {path.as_posix()}: {len(result)}")
95
+ return result
96
+
97
+ @classmethod
98
+ def strip(cls, string: str) -> str:
99
+ """Strip the head and tail of the string."""
100
+ logger.debug(f"String length before strip: {(original := len(string))}")
101
+ for split in (s for s in cls.head_split if s in string):
102
+ logger.debug(f"Strip head using {split}")
103
+ parts = string.split(split)
104
+ string = split.join(parts[1:]) if len(parts) > 1 else parts[0]
105
+ break
106
+ logger.debug(
107
+ f"String length after head strip: {(stripped_len := len(string))}, decreased by {(d := original - stripped_len)}"
108
+ )
109
+ if not d:
110
+ logger.warning("No decrease at head strip, which is might be abnormal.")
111
+ for split in (s for s in cls.tail_split if s in string):
112
+ logger.debug(f"Strip tail using {split}")
113
+ parts = string.split(split)
114
+ string = split.join(parts[:-1]) if len(parts) > 1 else parts[0]
115
+ break
116
+ logger.debug(f"String length after tail strip: {len(string)}, decreased by {(d := stripped_len - len(string))}")
117
+ if not d:
118
+ logger.warning("No decrease at tail strip, which is might be abnormal.")
119
+
120
+ return string