fabricatio 0.2.9.dev4__cp312-cp312-win_amd64.whl → 0.2.10__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +20 -106
- fabricatio/actions/article_rag.py +153 -22
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +40 -18
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/check.py +2 -1
- fabricatio/capabilities/rag.py +41 -231
- fabricatio/config.py +4 -2
- fabricatio/constants.py +20 -0
- fabricatio/decorators.py +23 -0
- fabricatio/models/adv_kwargs_types.py +35 -0
- fabricatio/models/events.py +6 -6
- fabricatio/models/extra/advanced_judge.py +2 -2
- fabricatio/models/extra/aricle_rag.py +170 -0
- fabricatio/models/extra/article_base.py +2 -186
- fabricatio/models/extra/article_essence.py +8 -7
- fabricatio/models/extra/article_main.py +39 -107
- fabricatio/models/extra/problem.py +12 -17
- fabricatio/models/extra/rag.py +98 -0
- fabricatio/models/extra/rule.py +1 -2
- fabricatio/models/generic.py +35 -12
- fabricatio/models/kwargs_types.py +8 -36
- fabricatio/models/task.py +3 -3
- fabricatio/models/usages.py +80 -6
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +138 -6
- fabricatio/utils.py +62 -4
- fabricatio-0.2.10.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev4.dist-info → fabricatio-0.2.10.dist-info}/METADATA +1 -4
- fabricatio-0.2.10.dist-info/RECORD +64 -0
- fabricatio/models/utils.py +0 -148
- fabricatio-0.2.9.dev4.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.9.dev4.dist-info/RECORD +0 -61
- {fabricatio-0.2.9.dev4.dist-info → fabricatio-0.2.10.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev4.dist-info → fabricatio-0.2.10.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/extra/problem.py
CHANGED
@@ -1,11 +1,12 @@
 """A class representing a problem-solution pair identified during a review process."""

 from itertools import chain
-from typing import Any, List,
+from typing import Any, List, Optional, Self, Tuple, Unpack

 from fabricatio.journal import logger
 from fabricatio.models.generic import SketchedAble, WithBriefing
 from fabricatio.utils import ask_edit
+from pydantic import Field
 from questionary import Choice, checkbox, text
 from rich import print as r_print

@@ -13,36 +14,30 @@ from rich import print as r_print

 class Problem(SketchedAble, WithBriefing):
     """Represents a problem identified during review."""

-    description: str
-    """
+    description: str = Field(alias="cause")
+    """The cause of the problem, including the root cause, the context, and the impact, make detailed enough for engineer to understand the problem and its impact."""

-
-    """Severity level of the problem."""
-
-    category: str
-    """Category of the problem."""
+    severity_level: int = Field(ge=0, le=10)
+    """Severity level of the problem, which is a number between 0 and 10, 0 means the problem is not severe, 10 means the problem is extremely severe."""

     location: str
     """Location where the problem was identified."""

-    recommendation: str
-    """Recommended solution or action."""
-

 class Solution(SketchedAble, WithBriefing):
     """Represents a proposed solution to a problem."""

-    description: str
+    description: str = Field(alias="mechanism")
     """Description of the solution, including a detailed description of the execution steps, and the mechanics, principle or fact."""

     execute_steps: List[str]
-    """A list of steps to execute to implement the solution, which is expected to be able to finally solve the corresponding problem."""
+    """A list of steps to execute to implement the solution, which is expected to be able to finally solve the corresponding problem, and which should be an Idiot-proof tutorial."""

-
-    """Feasibility level of the solution."""
+    feasibility_level: int = Field(ge=0, le=10)
+    """Feasibility level of the solution, which is a number between 0 and 10, 0 means the solution is not feasible, 10 means the solution is complete feasible."""

-
-    """Impact level of the solution."""
+    impact_level: int = Field(ge=0, le=10)
+    """Impact level of the solution, which is a number between 0 and 10, 0 means the solution is not impactful, 10 means the solution is extremely impactful."""


 class ProblemSolutions(SketchedAble):
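The net effect of this change is that Problem and Solution now take their descriptions through pydantic aliases ("cause" and "mechanism") and score their levels on a validated 0-10 scale. Below is a minimal sketch of how the new fields behave, with made-up values, assuming the usual pydantic v2 alias and ge/le semantics; the exact set of required fields depends on WithBriefing:

    from fabricatio.models.extra.problem import Problem

    # Hypothetical values; the "cause" key feeds the aliased `description` field.
    problem = Problem(
        name="missing-timeout",
        cause="The HTTP client sets no timeout, so a stalled upstream hangs the worker.",
        severity_level=7,                      # validated against Field(ge=0, le=10)
        location="fabricatio/actions/rag.py",
    )
    print(problem.description)                 # populated via the "cause" alias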
fabricatio/models/extra/rag.py
ADDED
@@ -0,0 +1,98 @@
+"""A module containing the RAG (Retrieval-Augmented Generation) models."""
+
+from abc import ABC
+from functools import partial
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Self, Sequence, Set
+
+from fabricatio.decorators import precheck_package
+from fabricatio.models.generic import Vectorizable
+from fabricatio.utils import ok
+from pydantic import JsonValue
+
+if TYPE_CHECKING:
+    from importlib.util import find_spec
+
+    from pydantic.fields import FieldInfo
+
+    if find_spec("pymilvus"):
+        from pymilvus import CollectionSchema
+
+
+class MilvusDataBase(Vectorizable, ABC):
+    """A base class for Milvus data."""
+
+    primary_field_name: ClassVar[str] = "id"
+    """The name of the primary field in Milvus."""
+    vector_field_name: ClassVar[str] = "vector"
+    """The name of the vector field in Milvus."""
+
+    index_type: ClassVar[str] = "FLAT"
+    """The type of index to be used in Milvus."""
+    metric_type: ClassVar[str] = "COSINE"
+    """The type of metric to be used in Milvus."""
+
+    def prepare_insertion(self, vector: List[float]) -> Dict[str, Any]:
+        """Prepares the data for insertion into Milvus.
+
+        Returns:
+            dict: A dictionary containing the data to be inserted into Milvus.
+        """
+        return {**self.model_dump(exclude_none=True, by_alias=True), self.vector_field_name: vector}
+
+    @classmethod
+    @precheck_package(
+        "pymilvus", "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
+    )
+    def as_milvus_schema(cls, dimension: int = 1024) -> "CollectionSchema":
+        """Generates the schema for Milvus collection."""
+        from pymilvus import CollectionSchema, DataType, FieldSchema
+
+        fields = [
+            FieldSchema(cls.primary_field_name, dtype=DataType.INT64, is_primary=True, auto_id=True),
+            FieldSchema(cls.vector_field_name, dtype=DataType.FLOAT_VECTOR, dim=dimension),
+        ]
+
+        for k, v in cls.model_fields.items():
+            k: str
+            v: FieldInfo
+            schema = partial(FieldSchema, k, description=v.description or "")
+            anno = ok(v.annotation)
+
+            if anno == int:
+                fields.append(schema(dtype=DataType.INT64))
+            elif anno == str:
+                fields.append(schema(dtype=DataType.VARCHAR, max_length=65535))
+            elif anno == float:
+                fields.append(schema(dtype=DataType.DOUBLE))
+            elif anno == list[str] or anno == List[str] or anno == set[str] or anno == Set[str]:
+                fields.append(
+                    schema(dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=65535, max_capacity=4096)
+                )
+            elif anno == list[int] or anno == List[int] or anno == set[int] or anno == Set[int]:
+                fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=4096))
+            elif anno == list[float] or anno == List[float] or anno == set[float] or anno == Set[float]:
+                fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=4096))
+            elif anno == JsonValue:
+                fields.append(schema(dtype=DataType.JSON))
+
+            else:
+                raise NotImplementedError(f"{k}:{anno} is not supported")
+
+        return CollectionSchema(fields)
+
+    @classmethod
+    def from_sequence(cls, data: Sequence[Dict[str, Any]]) -> List[Self]:
+        """Constructs a list of instances from a sequence of dictionaries."""
+        return [cls(**d) for d in data]
+
+
+class MilvusClassicModel(MilvusDataBase):
+    """A class representing a classic model stored in Milvus."""
+
+    text: str
+    """The text to be stored in Milvus."""
+    subject: str = ""
+    """The subject of the text."""
+
+    def _prepare_vectorization_inner(self) -> str:
+        return self.text
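The new MilvusDataBase turns a subclass's pydantic fields into a Milvus CollectionSchema and into dump-ready rows. A sketch of the intended usage, assuming pymilvus is installed via fabricatio[rag]; the Chunk class and its fields are hypothetical:

    from typing import List

    from fabricatio.models.extra.rag import MilvusDataBase


    class Chunk(MilvusDataBase):
        """Hypothetical record type stored in Milvus."""

        text: str             # mapped to VARCHAR(65535)
        page: int             # mapped to INT64
        keywords: List[str]   # mapped to an ARRAY of VARCHAR

        def _prepare_vectorization_inner(self) -> str:
            return self.text


    schema = Chunk.as_milvus_schema(dimension=1024)  # id + vector plus the three fields above
    row = Chunk(text="...", page=3, keywords=["rag"]).prepare_insertion(vector=[0.0] * 1024)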
fabricatio/models/extra/rule.py
CHANGED
@@ -40,12 +40,11 @@ class RuleSet(SketchedAble, PersistentAble, WithBriefing, Language):
     framework for the topic or domain covered by the rule set."""

     @classmethod
-    def gather(cls, *rulesets: Unpack[Tuple["RuleSet"
+    def gather(cls, *rulesets: Unpack[Tuple["RuleSet", ...]]) -> Self:
         """Gathers multiple rule sets into a single rule set."""
         if not rulesets:
             raise ValueError("No rulesets provided")
         return cls(
-            language=rulesets[0].language,
             name=";".join(ruleset.name for ruleset in rulesets),
             description=";".join(ruleset.description for ruleset in rulesets),
             rules=list(flatten(r.rules for r in rulesets)),
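gather now accepts any number of rule sets as a variadic tuple and no longer copies language into the merged set, since language became a derived property on the Language mixin (see the generic.py diff below). A short usage sketch; the two rule sets are assumed to already exist:

    # style_rules and citation_rules are hypothetical RuleSet instances.
    combined = RuleSet.gather(style_rules, citation_rules)
    print(combined.name)  # names joined with ";", e.g. "style;citation"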
fabricatio/models/generic.py
CHANGED
@@ -3,15 +3,14 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload

 import orjson
-import rtoml
 from fabricatio.config import configs
 from fabricatio.fs.readers import MAGIKA, safe_text_read
 from fabricatio.journal import logger
 from fabricatio.parser import JsonCapture
-from fabricatio.rust import blake3_hash
+from fabricatio.rust import blake3_hash, detect_language
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from fabricatio.utils import ok
 from litellm.utils import token_counter
@@ -53,7 +52,7 @@ class Display(Base):
         Returns:
             str: JSON string with 1-level indentation for readability
         """
-        return self.model_dump_json(indent=1)
+        return self.model_dump_json(indent=1, by_alias=True)

     def compact(self) -> str:
         """Generate compact JSON representation.
@@ -61,7 +60,7 @@ class Display(Base):
         Returns:
             str: Minified JSON string without whitespace
         """
-        return self.model_dump_json()
+        return self.model_dump_json(by_alias=True)

     @staticmethod
     def seq_display(seq: Iterable["Display"], compact: bool = False) -> str:
@@ -118,6 +117,15 @@ class WordCount(Base):
     """Expected word count of this research component."""


+class FromMapping(Base):
+    """Class that provides a method to generate a list of objects from a mapping."""
+
+    @classmethod
+    @abstractmethod
+    def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
+        """Generate a list of objects from a mapping."""
+
+
 class AsPrompt(Base):
     """Class that provides a method to generate a prompt from the model.

@@ -225,7 +233,7 @@ class PersistentAble(Base):
         - Hash generated from JSON content ensures uniqueness
         """
         p = Path(path)
-        out = self.model_dump_json(indent=1)
+        out = self.model_dump_json(indent=1, by_alias=True)

         # Generate a timestamp in the format YYYYMMDD_HHMMSS
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -298,8 +306,17 @@ class PersistentAble(Base):
 class Language(Base):
     """Class that provides a language attribute."""

-
-
+    @property
+    def language(self) -> str:
+        """Get the language of the object."""
+        if isinstance(self, Described):
+            return detect_language(self.description)
+        if isinstance(self, Titled):
+            return detect_language(self.title)
+        if isinstance(self, Named):
+            return detect_language(self.name)
+
+        return detect_language(self.model_dump_json(by_alias=True))


 class ModelHash(Base):
@@ -543,7 +560,7 @@ class FinalizedDumpAble(Base):
         Returns:
             str: The finalized dump of the object.
         """
-        return self.model_dump_json()
+        return self.model_dump_json(indent=1, by_alias=True)

     def finalized_dump_to(self, path: str | Path) -> Self:
         """Finalize the dump of the object to a file.
@@ -655,8 +672,9 @@ class Vectorizable(Base):
     This class includes methods to prepare the model for vectorization, ensuring it fits within a specified token length.
     """

+    @abstractmethod
     def _prepare_vectorization_inner(self) -> str:
-
+        """Prepare the model for vectorization."""

     @final
     def prepare_vectorization(self, max_length: Optional[int] = None) -> str:
@@ -674,8 +692,7 @@ class Vectorizable(Base):
         max_length = max_length or configs.embedding.max_sequence_length
         chunk = self._prepare_vectorization_inner()
         if max_length and (length := token_counter(text=chunk)) > max_length:
-
-
+            raise ValueError(f"Chunk exceeds maximum sequence length {max_length}, got {length}, see \n{chunk}")

         return chunk

@@ -726,6 +743,12 @@ class ScopedConfig(Base):
     llm_rpm: Optional[PositiveInt] = None
     """The requests per minute of the LLM model."""

+    llm_presence_penalty: Optional[PositiveFloat] = None
+    """The presence penalty of the LLM model."""
+
+    llm_frequency_penalty: Optional[PositiveFloat] = None
+    """The frequency penalty of the LLM model."""
+
     embedding_api_endpoint: Optional[HttpUrl] = None
     """The OpenAI API endpoint."""
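The new Language.language property replaces the old stored attribute: it picks the first available text source in the order description, then title, then name, falls back to the aliased JSON dump, and delegates detection to the Rust detect_language helper. A sketch in which Memo is a hypothetical model mixing in Language and Described, and the returned language code is illustrative only:

    from fabricatio.models.generic import Described, Language


    class Memo(Language, Described):
        """Hypothetical model; only `description` is inspected for detection."""


    memo = Memo(description="Una breve nota sobre la generación aumentada por recuperación.")
    print(memo.language)  # detected from `description`, e.g. "es"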
fabricatio/models/kwargs_types.py
CHANGED
@@ -1,47 +1,16 @@
 """This module contains the types for the keyword arguments of the methods in the models module."""

-from
-from typing import Any, Dict, List, Optional, Required, TypedDict
+from typing import Any, Dict, List, NotRequired, Optional, Required, TypedDict

 from litellm.caching.caching import CacheMode
 from litellm.types.caching import CachingSupportedCallTypes

-if find_spec("pymilvus"):
-    from pymilvus import CollectionSchema
-    from pymilvus.milvus_client import IndexParams

-
-
+class ChunkKwargs(TypedDict):
+    """Configuration parameters for chunking operations."""

-
-
-
-    dimension: int | None
-    primary_field_name: str
-    id_type: str
-    vector_field_name: str
-    metric_type: str
-    timeout: float | None
-    schema: CollectionSchema | None
-    index_params: IndexParams | None
-
-
-class FetchKwargs(TypedDict, total=False):
-    """Arguments for fetching data from vector collections.
-
-    Controls how data is retrieved from vector databases, including filtering
-    and result limiting parameters.
-    """
-
-    collection_name: str | None
-    similarity_threshold: float
-    result_per_query: int
-
-
-class RetrievalKwargs(FetchKwargs, total=False):
-    """Arguments for retrieval operations."""
-
-    final_limit: int
+    max_chunk_size: int
+    max_overlapping_rate: NotRequired[float]


 class EmbeddingKwargs(TypedDict, total=False):
@@ -76,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
     no_store: bool  # If store the response of this call to cache
     cache_ttl: int  # how long the stored cache is alive, in seconds
     s_maxage: int  # max accepted age of cached response, in seconds
+    presence_penalty: float
+    frequency_penalty: float


 class GenerateKwargs(LLMKwargs, total=False):
@@ -139,6 +110,7 @@ class ReviewKwargs[T](ReviewInnerKwargs[T], total=False):

 class ReferencedKwargs[T](ValidateKwargs[T], total=False):
     """Arguments for content review operations."""
+
     reference: str

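The Milvus-specific TypedDicts were removed from this module, leaving a slimmer ChunkKwargs in which only max_chunk_size is required. A minimal sketch of the two accepted shapes:

    from fabricatio.models.kwargs_types import ChunkKwargs

    minimal: ChunkKwargs = {"max_chunk_size": 512}   # max_overlapping_rate is NotRequired
    full: ChunkKwargs = {"max_chunk_size": 512, "max_overlapping_rate": 0.2}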
fabricatio/models/task.py
CHANGED
@@ -7,11 +7,11 @@ from asyncio import Queue
 from typing import Any, List, Optional, Self

 from fabricatio.config import configs
+from fabricatio.constants import TaskStatus
 from fabricatio.core import env
 from fabricatio.journal import logger
 from fabricatio.models.events import Event, EventLike
 from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
-from fabricatio.models.utils import TaskStatus
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from pydantic import Field, PrivateAttr

@@ -112,12 +112,12 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
         """Return a formatted status label for the task.

         Args:
-            status (TaskStatus): The status of the task.
+            status (fabricatio.constants.TaskStatus): The status of the task.

         Returns:
             str: The formatted status label.
         """
-        return self._namespace.derive(self.name).push(status
+        return self._namespace.derive(self.name).push(status).collapse()

     @property
     def pending_label(self) -> str:
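TaskStatus now lives in fabricatio.constants (the old fabricatio.models.utils module was removed), and a status label is built by pushing the status onto the task's namespace event and collapsing it back to a string. A rough sketch; the exact label text depends on the configured event delimiter, and the Task constructor arguments shown are assumptions:

    from fabricatio.constants import TaskStatus  # moved here from fabricatio.models.utils
    from fabricatio.models.task import Task

    task = Task(name="summarize")  # hypothetical; remaining fields left at their defaults
    print(task.pending_label)      # namespace, task name and pending status collapsed into one string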
fabricatio/models/usages.py
CHANGED
@@ -2,7 +2,7 @@

 import traceback
 from asyncio import gather
-from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Union, Unpack, overload
+from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload

 import asyncstdlib
 import litellm
@@ -13,7 +13,6 @@ from fabricatio.models.generic import ScopedConfig, WithBriefing
 from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
-from fabricatio.models.utils import Messages
 from fabricatio.parser import GenericCapture, JsonCapture
 from fabricatio.rust_instances import TEMPLATE_MANAGER
 from fabricatio.utils import ok
@@ -28,7 +27,7 @@ from litellm.types.utils import (
 )
 from litellm.utils import CustomStreamWrapper, token_counter  # pyright: ignore [reportPrivateImportUsage]
 from more_itertools import duplicates_everseen
-from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt

 if configs.cache.enabled and configs.cache.type:
     litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
@@ -64,7 +63,7 @@ class LLMUsage(ScopedConfig):
         self._added_deployment = ROUTER.upsert_deployment(deployment)
         return ROUTER

-    # noinspection PyTypeChecker,PydanticTypeChecker
+    # noinspection PyTypeChecker,PydanticTypeChecker,t
     async def aquery(
         self,
         messages: List[Dict[str, str]],
@@ -123,6 +122,12 @@ class LLMUsage(ScopedConfig):
                "cache-ttl": kwargs.get("cache_ttl"),
                "s-maxage": kwargs.get("s_maxage"),
            },
+            presence_penalty=kwargs.get("presence_penalty")
+            or self.llm_presence_penalty
+            or configs.llm.presence_penalty,
+            frequency_penalty=kwargs.get("frequency_penalty")
+            or self.llm_frequency_penalty
+            or configs.llm.frequency_penalty,
        )

    async def ainvoke(
@@ -303,7 +308,7 @@ class LLMUsage(ScopedConfig):
            and logger.debug("Co-extraction is enabled.") is None
            and (
                validated := validator(
-                    response:=await self.aask(
+                    response := await self.aask(
                        question=(
                            TEMPLATE_MANAGER.render_template(
                                configs.templates.co_validation_template,
@@ -495,7 +500,7 @@ class LLMUsage(ScopedConfig):
        affirm_case: str = "",
        deny_case: str = "",
        **kwargs: Unpack[ValidateKwargs[bool]],
-    ) -> bool:
+    ) -> Optional[bool]:
        """Asynchronously judges a prompt using AI validation.

        Args:
@@ -732,3 +737,72 @@ class ToolBoxUsage(LLMUsage):
        for other in (x for x in others if isinstance(x, ToolBoxUsage)):
            other.toolboxes.update(self.toolboxes)
        return self
+
+
+class Message(BaseModel):
+    """A class representing a message."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+    role: Literal["user", "system", "assistant"]
+    """The role of the message sender."""
+    content: str
+    """The content of the message."""
+
+
+class Messages(list):
+    """A list of messages."""
+
+    def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
+        """Adds a message to the list with the specified role and content.
+
+        Args:
+            role (Literal["user", "system", "assistant"]): The role of the message sender.
+            content (str): The content of the message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        if content:
+            self.append(Message(role=role, content=content))
+        return self
+
+    def add_user_message(self, content: str) -> Self:
+        """Adds a user message to the list with the specified content.
+
+        Args:
+            content (str): The content of the user message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("user", content)
+
+    def add_system_message(self, content: str) -> Self:
+        """Adds a system message to the list with the specified content.
+
+        Args:
+            content (str): The content of the system message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("system", content)
+
+    def add_assistant_message(self, content: str) -> Self:
+        """Adds an assistant message to the list with the specified content.
+
+        Args:
+            content (str): The content of the assistant message.
+
+        Returns:
+            Self: The current instance of Messages to allow method chaining.
+        """
+        return self.add_message("assistant", content)
+
+    def as_list(self) -> List[Dict[str, str]]:
+        """Converts the messages to a list of dictionaries.
+
+        Returns:
+            list[dict]: A list of dictionaries representing the messages.
+        """
+        return [message.model_dump() for message in self]
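Message and Messages moved here from the removed fabricatio.models.utils, and the new penalty settings resolve in the order call kwargs, then the instance's scoped config, then the global configs.llm defaults. A short sketch of the chaining helper, using made-up message text:

    from fabricatio.models.usages import Messages

    payload = (
        Messages()
        .add_system_message("You are a terse code reviewer.")
        .add_user_message("Summarize the risk of this diff.")
        .as_list()
    )
    # [{"role": "system", "content": ...}, {"role": "user", "content": ...}],
    # the shape expected by aquery(messages=...)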
Binary file (contents not shown)