fabricatio 0.2.3.dev3__cp312-cp312-win_amd64.whl → 0.2.4.dev0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +4 -2
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/actions/__init__.py +2 -2
- fabricatio/actions/article.py +127 -0
- fabricatio/capabilities/propose.py +55 -0
- fabricatio/capabilities/rag.py +129 -44
- fabricatio/capabilities/task.py +6 -23
- fabricatio/config.py +37 -2
- fabricatio/models/action.py +1 -0
- fabricatio/models/events.py +36 -0
- fabricatio/models/generic.py +158 -7
- fabricatio/models/kwargs_types.py +14 -0
- fabricatio/models/task.py +5 -23
- fabricatio/models/usages.py +103 -162
- fabricatio/models/utils.py +19 -0
- fabricatio/parser.py +34 -3
- fabricatio-0.2.4.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev0.dist-info}/METADATA +66 -178
- fabricatio-0.2.4.dev0.dist-info/RECORD +37 -0
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio-0.2.3.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.3.dev3.dist-info/RECORD +0 -37
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/events.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
from typing import List, Self, Union
|
4
4
|
|
5
5
|
from fabricatio.config import configs
|
6
|
+
from fabricatio.models.utils import TaskStatus
|
6
7
|
from pydantic import BaseModel, ConfigDict, Field
|
7
8
|
|
8
9
|
type EventLike = Union[str, List[str], "Event"]
|
@@ -33,6 +34,21 @@ class Event(BaseModel):
|
|
33
34
|
|
34
35
|
return cls(segments=event)
|
35
36
|
|
37
|
+
@classmethod
|
38
|
+
def quick_instantiate(cls, event: EventLike) -> Self:
|
39
|
+
"""Create an Event instance from a string or list of strings or an Event instance and push a wildcard and pending segment.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
event (EventLike): The event to instantiate from.
|
43
|
+
|
44
|
+
Returns:
|
45
|
+
Event: The Event instance.
|
46
|
+
|
47
|
+
Notes:
|
48
|
+
This method is used to create an Event instance from a string or list of strings or an Event instance and push a wildcard and pending segment.
|
49
|
+
"""
|
50
|
+
return cls.instantiate_from(event).push_wildcard().push_pending()
|
51
|
+
|
36
52
|
def derive(self, event: EventLike) -> Self:
|
37
53
|
"""Derive a new event from this event and another event or a string."""
|
38
54
|
return self.clone().concat(event)
|
@@ -59,6 +75,26 @@ class Event(BaseModel):
|
|
59
75
|
"""Push a wildcard segment to the event."""
|
60
76
|
return self.push("*")
|
61
77
|
|
78
|
+
def push_pending(self) -> Self:
|
79
|
+
"""Push a pending segment to the event."""
|
80
|
+
return self.push(TaskStatus.Pending.value)
|
81
|
+
|
82
|
+
def push_running(self) -> Self:
|
83
|
+
"""Push a running segment to the event."""
|
84
|
+
return self.push(TaskStatus.Running.value)
|
85
|
+
|
86
|
+
def push_finished(self) -> Self:
|
87
|
+
"""Push a finished segment to the event."""
|
88
|
+
return self.push(TaskStatus.Finished.value)
|
89
|
+
|
90
|
+
def push_failed(self) -> Self:
|
91
|
+
"""Push a failed segment to the event."""
|
92
|
+
return self.push(TaskStatus.Failed.value)
|
93
|
+
|
94
|
+
def push_cancelled(self) -> Self:
|
95
|
+
"""Push a cancelled segment to the event."""
|
96
|
+
return self.push(TaskStatus.Cancelled.value)
|
97
|
+
|
62
98
|
def pop(self) -> str:
|
63
99
|
"""Pop a segment from the event."""
|
64
100
|
return self.segments.pop()
|
fabricatio/models/generic.py
CHANGED
@@ -1,17 +1,23 @@
|
|
1
1
|
"""This module defines generic classes for models in the Fabricatio library."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Callable, List, Self
|
4
|
+
from typing import Callable, Iterable, List, Optional, Self, Union, final
|
5
5
|
|
6
6
|
import orjson
|
7
7
|
from fabricatio._rust import blake3_hash
|
8
8
|
from fabricatio._rust_instances import template_manager
|
9
9
|
from fabricatio.config import configs
|
10
10
|
from fabricatio.fs.readers import magika, safe_text_read
|
11
|
+
from fabricatio.parser import JsonCapture
|
11
12
|
from pydantic import (
|
12
13
|
BaseModel,
|
13
14
|
ConfigDict,
|
14
15
|
Field,
|
16
|
+
HttpUrl,
|
17
|
+
NonNegativeFloat,
|
18
|
+
PositiveFloat,
|
19
|
+
PositiveInt,
|
20
|
+
SecretStr,
|
15
21
|
)
|
16
22
|
|
17
23
|
|
@@ -48,22 +54,63 @@ class WithBriefing(Named, Described):
|
|
48
54
|
return f"{self.name}: {self.description}" if self.description else self.name
|
49
55
|
|
50
56
|
|
51
|
-
class
|
52
|
-
"""Class that provides a JSON schema
|
57
|
+
class WithFormatedJsonSchema(Base):
|
58
|
+
"""Class that provides a formatted JSON schema of the model."""
|
53
59
|
|
54
60
|
@classmethod
|
55
|
-
def
|
56
|
-
"""
|
61
|
+
def formated_json_schema(cls) -> str:
|
62
|
+
"""Get the JSON schema of the model in a formatted string.
|
57
63
|
|
58
64
|
Returns:
|
59
|
-
str:
|
65
|
+
str: The JSON schema of the model in a formatted string.
|
60
66
|
"""
|
61
67
|
return orjson.dumps(
|
62
|
-
|
68
|
+
cls.model_json_schema(),
|
63
69
|
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS,
|
64
70
|
).decode()
|
65
71
|
|
66
72
|
|
73
|
+
class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
74
|
+
"""Class that provides a prompt for creating a JSON object."""
|
75
|
+
|
76
|
+
@classmethod
|
77
|
+
def create_json_prompt(cls, requirement: str) -> str:
|
78
|
+
"""Create the prompt for creating a JSON object with given requirement.
|
79
|
+
|
80
|
+
Args:
|
81
|
+
requirement (str): The requirement for the JSON object.
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
str: The prompt for creating a JSON object with given requirement.
|
85
|
+
"""
|
86
|
+
return template_manager.render_template(
|
87
|
+
configs.templates.create_json_obj_template,
|
88
|
+
{"requirement": requirement, "json_schema": cls.formated_json_schema()},
|
89
|
+
)
|
90
|
+
|
91
|
+
|
92
|
+
class InstantiateFromString(Base):
|
93
|
+
"""Class that provides a method to instantiate the class from a string."""
|
94
|
+
|
95
|
+
@classmethod
|
96
|
+
def instantiate_from_string(cls, string: str) -> Self | None:
|
97
|
+
"""Instantiate the class from a string.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
string (str): The string to instantiate the class from.
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
Self | None: The instance of the class or None if the string is not valid.
|
104
|
+
"""
|
105
|
+
return JsonCapture.convert_with(string, cls.model_validate_json)
|
106
|
+
|
107
|
+
|
108
|
+
class ProposedAble(CreateJsonObjPrompt, InstantiateFromString):
|
109
|
+
"""Class that provides methods for proposing a task."""
|
110
|
+
|
111
|
+
pass
|
112
|
+
|
113
|
+
|
67
114
|
class WithDependency(Base):
|
68
115
|
"""Class that manages file dependencies."""
|
69
116
|
|
@@ -150,3 +197,107 @@ class WithDependency(Base):
|
|
150
197
|
for p in self.dependencies
|
151
198
|
},
|
152
199
|
)
|
200
|
+
|
201
|
+
|
202
|
+
class ScopedConfig(Base):
|
203
|
+
"""Class that manages a scoped configuration."""
|
204
|
+
|
205
|
+
llm_api_endpoint: Optional[HttpUrl] = None
|
206
|
+
"""The OpenAI API endpoint."""
|
207
|
+
|
208
|
+
llm_api_key: Optional[SecretStr] = None
|
209
|
+
"""The OpenAI API key."""
|
210
|
+
|
211
|
+
llm_timeout: Optional[PositiveInt] = None
|
212
|
+
"""The timeout of the LLM model."""
|
213
|
+
|
214
|
+
llm_max_retries: Optional[PositiveInt] = None
|
215
|
+
"""The maximum number of retries."""
|
216
|
+
|
217
|
+
llm_model: Optional[str] = None
|
218
|
+
"""The LLM model name."""
|
219
|
+
|
220
|
+
llm_temperature: Optional[NonNegativeFloat] = None
|
221
|
+
"""The temperature of the LLM model."""
|
222
|
+
|
223
|
+
llm_stop_sign: Optional[str | List[str]] = None
|
224
|
+
"""The stop sign of the LLM model."""
|
225
|
+
|
226
|
+
llm_top_p: Optional[NonNegativeFloat] = None
|
227
|
+
"""The top p of the LLM model."""
|
228
|
+
|
229
|
+
llm_generation_count: Optional[PositiveInt] = None
|
230
|
+
"""The number of generations to generate."""
|
231
|
+
|
232
|
+
llm_stream: Optional[bool] = None
|
233
|
+
"""Whether to stream the LLM model's response."""
|
234
|
+
|
235
|
+
llm_max_tokens: Optional[PositiveInt] = None
|
236
|
+
"""The maximum number of tokens to generate."""
|
237
|
+
|
238
|
+
embedding_api_endpoint: Optional[HttpUrl] = None
|
239
|
+
"""The OpenAI API endpoint."""
|
240
|
+
|
241
|
+
embedding_api_key: Optional[SecretStr] = None
|
242
|
+
"""The OpenAI API key."""
|
243
|
+
|
244
|
+
embedding_timeout: Optional[PositiveInt] = None
|
245
|
+
"""The timeout of the LLM model."""
|
246
|
+
|
247
|
+
embedding_model: Optional[str] = None
|
248
|
+
"""The LLM model name."""
|
249
|
+
|
250
|
+
embedding_max_sequence_length: Optional[PositiveInt] = None
|
251
|
+
"""The maximum sequence length."""
|
252
|
+
|
253
|
+
embedding_dimensions: Optional[PositiveInt] = None
|
254
|
+
"""The dimensions of the embedding."""
|
255
|
+
embedding_caching: Optional[bool] = False
|
256
|
+
"""Whether to cache the embedding result."""
|
257
|
+
|
258
|
+
milvus_uri: Optional[HttpUrl] = Field(default=None)
|
259
|
+
"""The URI of the Milvus server."""
|
260
|
+
milvus_token: Optional[SecretStr] = Field(default=None)
|
261
|
+
"""The token for the Milvus server."""
|
262
|
+
milvus_timeout: Optional[PositiveFloat] = Field(default=None)
|
263
|
+
"""The timeout for the Milvus server."""
|
264
|
+
milvus_dimensions: Optional[PositiveInt] = Field(default=None)
|
265
|
+
"""The dimensions of the Milvus server."""
|
266
|
+
|
267
|
+
@final
|
268
|
+
def fallback_to(self, other: "ScopedConfig") -> Self:
|
269
|
+
"""Fallback to another instance's attribute values if the current instance's attributes are None.
|
270
|
+
|
271
|
+
Args:
|
272
|
+
other (LLMUsage): Another instance from which to copy attribute values.
|
273
|
+
|
274
|
+
Returns:
|
275
|
+
Self: The current instance, allowing for method chaining.
|
276
|
+
"""
|
277
|
+
# Iterate over the attribute names and copy values from 'other' to 'self' where applicable
|
278
|
+
# noinspection PydanticTypeChecker,PyTypeChecker
|
279
|
+
for attr_name in ScopedConfig.model_fields:
|
280
|
+
# Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
|
281
|
+
if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
|
282
|
+
setattr(self, attr_name, attr)
|
283
|
+
|
284
|
+
# Return the current instance to allow for method chaining
|
285
|
+
return self
|
286
|
+
|
287
|
+
@final
|
288
|
+
def hold_to(self, others: Union["ScopedConfig", Iterable["ScopedConfig"]]) -> Self:
|
289
|
+
"""Hold to another instance's attribute values if the current instance's attributes are None.
|
290
|
+
|
291
|
+
Args:
|
292
|
+
others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
|
293
|
+
|
294
|
+
Returns:
|
295
|
+
Self: The current instance, allowing for method chaining.
|
296
|
+
"""
|
297
|
+
if not isinstance(others, Iterable):
|
298
|
+
others = [others]
|
299
|
+
for other in others:
|
300
|
+
# noinspection PyTypeChecker,PydanticTypeChecker
|
301
|
+
for attr_name in ScopedConfig.model_fields:
|
302
|
+
if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
|
303
|
+
setattr(other, attr_name, attr)
|
@@ -5,6 +5,20 @@ from typing import List, NotRequired, TypedDict
|
|
5
5
|
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
|
6
6
|
|
7
7
|
|
8
|
+
class CollectionSimpleConfigKwargs(TypedDict):
|
9
|
+
"""A type representing the configuration for a collection."""
|
10
|
+
|
11
|
+
dimension: NotRequired[int]
|
12
|
+
timeout: NotRequired[float]
|
13
|
+
|
14
|
+
|
15
|
+
class FetchKwargs(TypedDict):
|
16
|
+
"""A type representing the keyword arguments for the fetch method."""
|
17
|
+
|
18
|
+
similarity_threshold: NotRequired[float]
|
19
|
+
result_per_query: NotRequired[int]
|
20
|
+
|
21
|
+
|
8
22
|
class EmbeddingKwargs(TypedDict):
|
9
23
|
"""A type representing the keyword arguments for the embedding method."""
|
10
24
|
|
fabricatio/models/task.py
CHANGED
@@ -4,7 +4,6 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
from asyncio import Queue
|
7
|
-
from enum import Enum
|
8
7
|
from typing import Any, List, Optional, Self
|
9
8
|
|
10
9
|
from fabricatio._rust_instances import template_manager
|
@@ -12,35 +11,18 @@ from fabricatio.config import configs
|
|
12
11
|
from fabricatio.core import env
|
13
12
|
from fabricatio.journal import logger
|
14
13
|
from fabricatio.models.events import Event, EventLike
|
15
|
-
from fabricatio.models.generic import WithBriefing, WithDependency
|
14
|
+
from fabricatio.models.generic import ProposedAble, WithBriefing, WithDependency
|
15
|
+
from fabricatio.models.utils import TaskStatus
|
16
16
|
from pydantic import Field, PrivateAttr
|
17
17
|
|
18
18
|
|
19
|
-
class
|
20
|
-
"""An enumeration representing the status of a task.
|
21
|
-
|
22
|
-
Attributes:
|
23
|
-
Pending: The task is pending.
|
24
|
-
Running: The task is currently running.
|
25
|
-
Finished: The task has been successfully completed.
|
26
|
-
Failed: The task has failed.
|
27
|
-
Cancelled: The task has been cancelled.
|
28
|
-
"""
|
29
|
-
|
30
|
-
Pending = "pending"
|
31
|
-
Running = "running"
|
32
|
-
Finished = "finished"
|
33
|
-
Failed = "failed"
|
34
|
-
Cancelled = "cancelled"
|
35
|
-
|
36
|
-
|
37
|
-
class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
19
|
+
class Task[T](WithBriefing, ProposedAble, WithDependency):
|
38
20
|
"""A class representing a task with a status and output.
|
39
21
|
|
40
22
|
Attributes:
|
41
23
|
name (str): The name of the task.
|
42
24
|
description (str): The description of the task.
|
43
|
-
|
25
|
+
goals (str): The goal of the task.
|
44
26
|
dependencies (List[str]): The file dependencies of the task, a list of file paths.
|
45
27
|
namespace (List[str]): The namespace of the task, a list of namespace segment, as string.
|
46
28
|
"""
|
@@ -58,7 +40,7 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
58
40
|
"""A list of string segments that identify the task's location in the system. If not specified, defaults to an empty list."""
|
59
41
|
|
60
42
|
dependencies: List[str] = Field(default_factory=list)
|
61
|
-
"""A list of file paths that are needed (either reading or writing) to complete this task. If not specified, defaults to an empty list."""
|
43
|
+
"""A list of file paths that are needed or mentioned in the task's description (either reading or writing) to complete this task. If not specified, defaults to an empty list."""
|
62
44
|
|
63
45
|
_output: Queue[T | None] = PrivateAttr(default_factory=Queue)
|
64
46
|
"""The output queue of the task."""
|
fabricatio/models/usages.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""This module contains classes that manage the usage of language models and tools in tasks."""
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
|
-
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack, overload
|
4
|
+
from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
|
5
5
|
|
6
6
|
import asyncstdlib
|
7
7
|
import litellm
|
@@ -9,11 +9,11 @@ import orjson
|
|
9
9
|
from fabricatio._rust_instances import template_manager
|
10
10
|
from fabricatio.config import configs
|
11
11
|
from fabricatio.journal import logger
|
12
|
-
from fabricatio.models.generic import
|
12
|
+
from fabricatio.models.generic import ScopedConfig, WithBriefing
|
13
13
|
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
|
14
14
|
from fabricatio.models.task import Task
|
15
15
|
from fabricatio.models.tool import Tool, ToolBox
|
16
|
-
from fabricatio.models.utils import Messages
|
16
|
+
from fabricatio.models.utils import Messages
|
17
17
|
from fabricatio.parser import JsonCapture
|
18
18
|
from litellm import stream_chunk_builder
|
19
19
|
from litellm.types.utils import (
|
@@ -23,135 +23,15 @@ from litellm.types.utils import (
|
|
23
23
|
StreamingChoices,
|
24
24
|
)
|
25
25
|
from litellm.utils import CustomStreamWrapper
|
26
|
-
from pydantic import Field,
|
26
|
+
from pydantic import Field, NonNegativeInt, PositiveInt
|
27
27
|
|
28
28
|
|
29
|
-
class LLMUsage(
|
29
|
+
class LLMUsage(ScopedConfig):
|
30
30
|
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
31
31
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
llm_api_key: Optional[SecretStr] = None
|
36
|
-
"""The OpenAI API key."""
|
37
|
-
|
38
|
-
llm_timeout: Optional[PositiveInt] = None
|
39
|
-
"""The timeout of the LLM model."""
|
40
|
-
|
41
|
-
llm_max_retries: Optional[PositiveInt] = None
|
42
|
-
"""The maximum number of retries."""
|
43
|
-
|
44
|
-
llm_model: Optional[str] = None
|
45
|
-
"""The LLM model name."""
|
46
|
-
|
47
|
-
llm_temperature: Optional[NonNegativeFloat] = None
|
48
|
-
"""The temperature of the LLM model."""
|
49
|
-
|
50
|
-
llm_stop_sign: Optional[str | List[str]] = None
|
51
|
-
"""The stop sign of the LLM model."""
|
52
|
-
|
53
|
-
llm_top_p: Optional[NonNegativeFloat] = None
|
54
|
-
"""The top p of the LLM model."""
|
55
|
-
|
56
|
-
llm_generation_count: Optional[PositiveInt] = None
|
57
|
-
"""The number of generations to generate."""
|
58
|
-
|
59
|
-
llm_stream: Optional[bool] = None
|
60
|
-
"""Whether to stream the LLM model's response."""
|
61
|
-
|
62
|
-
llm_max_tokens: Optional[PositiveInt] = None
|
63
|
-
"""The maximum number of tokens to generate."""
|
64
|
-
|
65
|
-
async def aembedding(
|
66
|
-
self,
|
67
|
-
input_text: List[str],
|
68
|
-
model: Optional[str] = None,
|
69
|
-
dimensions: Optional[int] = None,
|
70
|
-
timeout: Optional[PositiveInt] = None,
|
71
|
-
caching: Optional[bool] = False,
|
72
|
-
) -> EmbeddingResponse:
|
73
|
-
"""Asynchronously generates embeddings for the given input text.
|
74
|
-
|
75
|
-
Args:
|
76
|
-
input_text (List[str]): A list of strings to generate embeddings for.
|
77
|
-
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
78
|
-
dimensions (Optional[int]): The dimensions of the embedding. Defaults to None.
|
79
|
-
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
80
|
-
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
81
|
-
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
EmbeddingResponse: The response containing the embeddings.
|
85
|
-
"""
|
86
|
-
return await litellm.aembedding(
|
87
|
-
input=input_text,
|
88
|
-
caching=caching,
|
89
|
-
dimensions=dimensions,
|
90
|
-
model=model or self.llm_model or configs.llm.model,
|
91
|
-
timeout=timeout or self.llm_timeout or configs.llm.timeout,
|
92
|
-
api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
|
93
|
-
api_base=self.llm_api_endpoint.unicode_string().rstrip(
|
94
|
-
"/"
|
95
|
-
) # seems embedding function takes no base_url end with a slash
|
96
|
-
if self.llm_api_endpoint
|
97
|
-
else configs.llm.api_endpoint.unicode_string().rstrip("/"),
|
98
|
-
)
|
99
|
-
|
100
|
-
@overload
|
101
|
-
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
102
|
-
@overload
|
103
|
-
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
104
|
-
|
105
|
-
async def vectorize(
|
106
|
-
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
107
|
-
) -> List[List[float]] | List[float]:
|
108
|
-
"""Asynchronously generates vector embeddings for the given input text.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
112
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
113
|
-
|
114
|
-
Returns:
|
115
|
-
List[List[float]] | List[float]: The generated embeddings.
|
116
|
-
"""
|
117
|
-
if isinstance(input_text, str):
|
118
|
-
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
119
|
-
|
120
|
-
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
121
|
-
|
122
|
-
@overload
|
123
|
-
async def pack(
|
124
|
-
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
125
|
-
) -> List[MilvusData]: ...
|
126
|
-
@overload
|
127
|
-
async def pack(
|
128
|
-
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
129
|
-
) -> MilvusData: ...
|
130
|
-
|
131
|
-
async def pack(
|
132
|
-
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
133
|
-
) -> List[MilvusData] | MilvusData:
|
134
|
-
"""Asynchronously generates MilvusData objects for the given input text.
|
135
|
-
|
136
|
-
Args:
|
137
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
138
|
-
subject (Optional[str]): The subject of the input text. Defaults to None.
|
139
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
140
|
-
|
141
|
-
Returns:
|
142
|
-
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
143
|
-
"""
|
144
|
-
if isinstance(input_text, str):
|
145
|
-
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
146
|
-
vecs = await self.vectorize(input_text, **kwargs)
|
147
|
-
return [
|
148
|
-
MilvusData(
|
149
|
-
vector=vec,
|
150
|
-
text=text,
|
151
|
-
subject=subject,
|
152
|
-
)
|
153
|
-
for text, vec in zip(input_text, vecs, strict=True)
|
154
|
-
]
|
32
|
+
@classmethod
|
33
|
+
def _scoped_model(cls) -> Type["LLMUsage"]:
|
34
|
+
return LLMUsage
|
155
35
|
|
156
36
|
async def aquery(
|
157
37
|
self,
|
@@ -181,10 +61,8 @@ class LLMUsage(Base):
|
|
181
61
|
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
182
62
|
timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
|
183
63
|
max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
|
184
|
-
api_key=
|
185
|
-
base_url=self.llm_api_endpoint.unicode_string()
|
186
|
-
if self.llm_api_endpoint
|
187
|
-
else configs.llm.api_endpoint.unicode_string(),
|
64
|
+
api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
|
65
|
+
base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
|
188
66
|
)
|
189
67
|
|
190
68
|
async def ainvoke(
|
@@ -213,13 +91,13 @@ class LLMUsage(Base):
|
|
213
91
|
if isinstance(resp, ModelResponse):
|
214
92
|
return resp.choices
|
215
93
|
if isinstance(resp, CustomStreamWrapper):
|
216
|
-
if configs.debug.streaming_visible:
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
return stream_chunk_builder(
|
94
|
+
if not configs.debug.streaming_visible:
|
95
|
+
return stream_chunk_builder(await asyncstdlib.list()).choices
|
96
|
+
chunks = []
|
97
|
+
async for chunk in resp:
|
98
|
+
chunks.append(chunk)
|
99
|
+
print(chunk.choices[0].delta.content or "", end="") # noqa: T201
|
100
|
+
return stream_chunk_builder(chunks).choices
|
223
101
|
logger.critical(err := f"Unexpected response type: {type(resp)}")
|
224
102
|
raise ValueError(err)
|
225
103
|
|
@@ -361,6 +239,26 @@ class LLMUsage(Base):
|
|
361
239
|
"""
|
362
240
|
return await gather(*[self.aask_validate(question, validator, **kwargs) for question in questions])
|
363
241
|
|
242
|
+
async def aliststr(self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[GenerateKwargs]) -> List[str]:
|
243
|
+
"""Asynchronously generates a list of strings based on a given requirement.
|
244
|
+
|
245
|
+
Args:
|
246
|
+
requirement (str): The requirement for the list of strings.
|
247
|
+
k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
|
248
|
+
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
249
|
+
|
250
|
+
Returns:
|
251
|
+
List[str]: The validated response as a list of strings.
|
252
|
+
"""
|
253
|
+
return await self.aask_validate(
|
254
|
+
template_manager.render_template(
|
255
|
+
configs.templates.liststr_template,
|
256
|
+
{"requirement": requirement, "k": k},
|
257
|
+
),
|
258
|
+
lambda resp: JsonCapture.validate_with(resp, orjson.loads, list, str, k),
|
259
|
+
**kwargs,
|
260
|
+
)
|
261
|
+
|
364
262
|
async def achoose[T: WithBriefing](
|
365
263
|
self,
|
366
264
|
instruction: str,
|
@@ -388,7 +286,7 @@ class LLMUsage(Base):
|
|
388
286
|
configs.templates.make_choice_template,
|
389
287
|
{
|
390
288
|
"instruction": instruction,
|
391
|
-
"options": [{"name"
|
289
|
+
"options": [m.model_dump(include={"name", "briefing"}) for m in choices],
|
392
290
|
"k": k,
|
393
291
|
},
|
394
292
|
)
|
@@ -475,39 +373,82 @@ class LLMUsage(Base):
|
|
475
373
|
**kwargs,
|
476
374
|
)
|
477
375
|
|
478
|
-
|
479
|
-
|
376
|
+
|
377
|
+
class EmbeddingUsage(LLMUsage):
|
378
|
+
"""A class representing the embedding model."""
|
379
|
+
|
380
|
+
async def aembedding(
|
381
|
+
self,
|
382
|
+
input_text: List[str],
|
383
|
+
model: Optional[str] = None,
|
384
|
+
dimensions: Optional[int] = None,
|
385
|
+
timeout: Optional[PositiveInt] = None,
|
386
|
+
caching: Optional[bool] = False,
|
387
|
+
) -> EmbeddingResponse:
|
388
|
+
"""Asynchronously generates embeddings for the given input text.
|
480
389
|
|
481
390
|
Args:
|
482
|
-
|
391
|
+
input_text (List[str]): A list of strings to generate embeddings for.
|
392
|
+
model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
|
393
|
+
dimensions (Optional[int]): The dimensions of the embedding output should have, which is used to validate the result. Defaults to None.
|
394
|
+
timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
|
395
|
+
caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
|
396
|
+
|
483
397
|
|
484
398
|
Returns:
|
485
|
-
|
399
|
+
EmbeddingResponse: The response containing the embeddings.
|
486
400
|
"""
|
487
|
-
#
|
488
|
-
|
489
|
-
for
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
401
|
+
# check seq length
|
402
|
+
max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
|
403
|
+
if any(len(t) > max_len for t in input_text):
|
404
|
+
logger.error(err := f"Input text exceeds maximum sequence length {max_len}.")
|
405
|
+
raise ValueError(err)
|
406
|
+
|
407
|
+
return await litellm.aembedding(
|
408
|
+
input=input_text,
|
409
|
+
caching=caching or self.embedding_caching or configs.embedding.caching,
|
410
|
+
dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
|
411
|
+
model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
|
412
|
+
timeout=timeout
|
413
|
+
or self.embedding_timeout
|
414
|
+
or configs.embedding.timeout
|
415
|
+
or self.llm_timeout
|
416
|
+
or configs.llm.timeout,
|
417
|
+
api_key=(
|
418
|
+
self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
|
419
|
+
).get_secret_value(),
|
420
|
+
api_base=(
|
421
|
+
self.embedding_api_endpoint
|
422
|
+
or configs.embedding.api_endpoint
|
423
|
+
or self.llm_api_endpoint
|
424
|
+
or configs.llm.api_endpoint
|
425
|
+
)
|
426
|
+
.unicode_string()
|
427
|
+
.rstrip("/"),
|
428
|
+
# seems embedding function takes no base_url end with a slash
|
429
|
+
)
|
430
|
+
|
431
|
+
@overload
|
432
|
+
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
|
433
|
+
@overload
|
434
|
+
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
|
496
435
|
|
497
|
-
def
|
498
|
-
|
436
|
+
async def vectorize(
|
437
|
+
self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
|
438
|
+
) -> List[List[float]] | List[float]:
|
439
|
+
"""Asynchronously generates vector embeddings for the given input text.
|
499
440
|
|
500
441
|
Args:
|
501
|
-
|
442
|
+
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
443
|
+
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
502
444
|
|
503
445
|
Returns:
|
504
|
-
|
446
|
+
List[List[float]] | List[float]: The generated embeddings.
|
505
447
|
"""
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
setattr(other, attr_name, attr)
|
448
|
+
if isinstance(input_text, str):
|
449
|
+
return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
|
450
|
+
|
451
|
+
return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
511
452
|
|
512
453
|
|
513
454
|
class ToolBoxUsage(LLMUsage):
|