fabricatio 0.2.6.dev3__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +60 -0
- fabricatio/_rust.cp39-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +116 -0
- fabricatio/_rust_instances.py +10 -0
- fabricatio/actions/article.py +81 -0
- fabricatio/actions/output.py +19 -0
- fabricatio/actions/rag.py +25 -0
- fabricatio/capabilities/correct.py +115 -0
- fabricatio/capabilities/propose.py +49 -0
- fabricatio/capabilities/rag.py +369 -0
- fabricatio/capabilities/rating.py +339 -0
- fabricatio/capabilities/review.py +278 -0
- fabricatio/capabilities/task.py +113 -0
- fabricatio/config.py +400 -0
- fabricatio/core.py +181 -0
- fabricatio/decorators.py +179 -0
- fabricatio/fs/__init__.py +29 -0
- fabricatio/fs/curd.py +149 -0
- fabricatio/fs/readers.py +46 -0
- fabricatio/journal.py +21 -0
- fabricatio/models/action.py +158 -0
- fabricatio/models/events.py +120 -0
- fabricatio/models/extra.py +171 -0
- fabricatio/models/generic.py +406 -0
- fabricatio/models/kwargs_types.py +158 -0
- fabricatio/models/role.py +48 -0
- fabricatio/models/task.py +299 -0
- fabricatio/models/tool.py +189 -0
- fabricatio/models/usages.py +682 -0
- fabricatio/models/utils.py +167 -0
- fabricatio/parser.py +149 -0
- fabricatio/py.typed +0 -0
- fabricatio/toolboxes/__init__.py +15 -0
- fabricatio/toolboxes/arithmetic.py +62 -0
- fabricatio/toolboxes/fs.py +31 -0
- fabricatio/workflows/articles.py +15 -0
- fabricatio/workflows/rag.py +11 -0
- fabricatio-0.2.6.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.6.dev3.dist-info/METADATA +432 -0
- fabricatio-0.2.6.dev3.dist-info/RECORD +42 -0
- fabricatio-0.2.6.dev3.dist-info/WHEEL +4 -0
- fabricatio-0.2.6.dev3.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,682 @@
|
|
1
|
+
"""This module contains classes that manage the usage of language models and tools in tasks."""
|
2
|
+
|
3
|
+
from asyncio import gather
|
4
|
+
from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Type, Union, Unpack, overload
|
5
|
+
|
6
|
+
import asyncstdlib
|
7
|
+
import litellm
|
8
|
+
from fabricatio._rust_instances import TEMPLATE_MANAGER
|
9
|
+
from fabricatio.config import configs
|
10
|
+
from fabricatio.journal import logger
|
11
|
+
from fabricatio.models.generic import ScopedConfig, WithBriefing
|
12
|
+
from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
|
13
|
+
from fabricatio.models.task import Task
|
14
|
+
from fabricatio.models.tool import Tool, ToolBox
|
15
|
+
from fabricatio.models.utils import Messages
|
16
|
+
from fabricatio.parser import GenericCapture, JsonCapture
|
17
|
+
from litellm import Router, stream_chunk_builder
|
18
|
+
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
|
19
|
+
from litellm.types.utils import (
|
20
|
+
Choices,
|
21
|
+
EmbeddingResponse,
|
22
|
+
ModelResponse,
|
23
|
+
StreamingChoices,
|
24
|
+
TextChoices,
|
25
|
+
)
|
26
|
+
from litellm.utils import CustomStreamWrapper # pyright: ignore [reportPrivateImportUsage]
|
27
|
+
from more_itertools import duplicates_everseen
|
28
|
+
from pydantic import Field, NonNegativeInt, PositiveInt
|
29
|
+
|
30
|
+
# litellm's global response cache is switched on only when caching is enabled
# in the configuration AND a concrete cache backend type has been chosen.
if all((configs.cache.enabled, configs.cache.type)):
    litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
    logger.success(f"{configs.cache.type.name} Cache enabled")

# One router shared by the whole module: `LLMUsage._deploy` upserts each
# deployment into it before a completion request is issued.
ROUTER = Router(
    cooldown_time=configs.routing.cooldown_time,
    retry_after=configs.routing.retry_after,
    allowed_fails=configs.routing.allowed_fails,
    routing_strategy="usage-based-routing-v2",
)
|
40
|
+
|
41
|
+
|
42
|
+
class LLMUsage(ScopedConfig):
|
43
|
+
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
44
|
+
|
45
|
+
def _deploy(self, deployment: Deployment) -> Router:
|
46
|
+
"""Add a deployment to the router."""
|
47
|
+
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
48
|
+
return ROUTER
|
49
|
+
|
50
|
+
@classmethod
|
51
|
+
def _scoped_model(cls) -> Type["LLMUsage"]:
|
52
|
+
return LLMUsage
|
53
|
+
|
54
|
+
# noinspection PyTypeChecker,PydanticTypeChecker
|
55
|
+
async def aquery(
|
56
|
+
self,
|
57
|
+
messages: List[Dict[str, str]],
|
58
|
+
n: PositiveInt | None = None,
|
59
|
+
**kwargs: Unpack[LLMKwargs],
|
60
|
+
) -> ModelResponse | CustomStreamWrapper:
|
61
|
+
"""Asynchronously queries the language model to generate a response based on the provided messages and parameters.
|
62
|
+
|
63
|
+
Args:
|
64
|
+
messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
|
65
|
+
n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
|
66
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
ModelResponse | CustomStreamWrapper: An object containing the generated response and other metadata from the model.
|
70
|
+
"""
|
71
|
+
# Call the underlying asynchronous completion function with the provided and default parameters
|
72
|
+
# noinspection PyTypeChecker,PydanticTypeChecker
|
73
|
+
|
74
|
+
return await self._deploy(
|
75
|
+
Deployment(
|
76
|
+
model_name=(m_name := kwargs.get("model") or self.llm_model or configs.llm.model),
|
77
|
+
litellm_params=(
|
78
|
+
p := LiteLLM_Params(
|
79
|
+
api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
|
80
|
+
api_base=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
|
81
|
+
model=m_name,
|
82
|
+
tpm=self.llm_tpm or configs.llm.tpm,
|
83
|
+
rpm=self.llm_rpm or configs.llm.rpm,
|
84
|
+
max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
|
85
|
+
timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
|
86
|
+
)
|
87
|
+
),
|
88
|
+
model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
|
89
|
+
)
|
90
|
+
).acompletion(
|
91
|
+
messages=messages,
|
92
|
+
n=n or self.llm_generation_count or configs.llm.generation_count,
|
93
|
+
model=m_name,
|
94
|
+
temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
|
95
|
+
stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
|
96
|
+
top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
|
97
|
+
max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
|
98
|
+
stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
|
99
|
+
cache={
|
100
|
+
"no-cache": kwargs.get("no_cache"),
|
101
|
+
"no-store": kwargs.get("no_store"),
|
102
|
+
"cache-ttl": kwargs.get("cache_ttl"),
|
103
|
+
"s-maxage": kwargs.get("s_maxage"),
|
104
|
+
},
|
105
|
+
)
|
106
|
+
|
107
|
+
async def ainvoke(
|
108
|
+
self,
|
109
|
+
question: str,
|
110
|
+
system_message: str = "",
|
111
|
+
n: PositiveInt | None = None,
|
112
|
+
**kwargs: Unpack[LLMKwargs],
|
113
|
+
) -> Sequence[TextChoices | Choices | StreamingChoices]:
|
114
|
+
"""Asynchronously invokes the language model with a question and optional system message.
|
115
|
+
|
116
|
+
Args:
|
117
|
+
question (str): The question to ask the model.
|
118
|
+
system_message (str): The system message to provide context to the model. Defaults to an empty string.
|
119
|
+
n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
|
120
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
121
|
+
|
122
|
+
Returns:
|
123
|
+
List[Choices | StreamingChoices]: A list of choices or streaming choices from the model response.
|
124
|
+
"""
|
125
|
+
resp = await self.aquery(
|
126
|
+
messages=Messages().add_system_message(system_message).add_user_message(question),
|
127
|
+
n=n,
|
128
|
+
**kwargs,
|
129
|
+
)
|
130
|
+
if isinstance(resp, ModelResponse):
|
131
|
+
return resp.choices
|
132
|
+
if isinstance(resp, CustomStreamWrapper):
|
133
|
+
if not configs.debug.streaming_visible and (pack := stream_chunk_builder(await asyncstdlib.list())):
|
134
|
+
return pack.choices
|
135
|
+
chunks = []
|
136
|
+
async for chunk in resp:
|
137
|
+
chunks.append(chunk)
|
138
|
+
print(chunk.choices[0].delta.content or "", end="") # noqa: T201
|
139
|
+
if pack := stream_chunk_builder(chunks):
|
140
|
+
return pack.choices
|
141
|
+
logger.critical(err := f"Unexpected response type: {type(resp)}")
|
142
|
+
raise ValueError(err)
|
143
|
+
|
144
|
+
@overload
|
145
|
+
async def aask(
|
146
|
+
self,
|
147
|
+
question: List[str],
|
148
|
+
system_message: List[str],
|
149
|
+
**kwargs: Unpack[LLMKwargs],
|
150
|
+
) -> List[str]: ...
|
151
|
+
@overload
|
152
|
+
async def aask(
|
153
|
+
self,
|
154
|
+
question: str,
|
155
|
+
system_message: List[str],
|
156
|
+
**kwargs: Unpack[LLMKwargs],
|
157
|
+
) -> List[str]: ...
|
158
|
+
@overload
|
159
|
+
async def aask(
|
160
|
+
self,
|
161
|
+
question: List[str],
|
162
|
+
system_message: Optional[str] = None,
|
163
|
+
**kwargs: Unpack[LLMKwargs],
|
164
|
+
) -> List[str]: ...
|
165
|
+
|
166
|
+
@overload
|
167
|
+
async def aask(
|
168
|
+
self,
|
169
|
+
question: str,
|
170
|
+
system_message: Optional[str] = None,
|
171
|
+
**kwargs: Unpack[LLMKwargs],
|
172
|
+
) -> str: ...
|
173
|
+
|
174
|
+
async def aask(
|
175
|
+
self,
|
176
|
+
question: str | List[str],
|
177
|
+
system_message: Optional[str | List[str]] = None,
|
178
|
+
**kwargs: Unpack[LLMKwargs],
|
179
|
+
) -> str | List[str]:
|
180
|
+
"""Asynchronously asks the language model a question and returns the response content.
|
181
|
+
|
182
|
+
Args:
|
183
|
+
question (str | List[str]): The question to ask the model.
|
184
|
+
system_message (str | List[str] | None): The system message to provide context to the model. Defaults to an empty string.
|
185
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
str | List[str]: The content of the model's response message.
|
189
|
+
"""
|
190
|
+
system_message = system_message or ""
|
191
|
+
match (question, system_message):
|
192
|
+
case (list(q_seq), list(sm_seq)):
|
193
|
+
res = await gather(
|
194
|
+
*[
|
195
|
+
self.ainvoke(n=1, question=q, system_message=sm, **kwargs)
|
196
|
+
for q, sm in zip(q_seq, sm_seq, strict=True)
|
197
|
+
]
|
198
|
+
)
|
199
|
+
return [r[0].message.content for r in res]
|
200
|
+
case (list(q_seq), str(sm)):
|
201
|
+
res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
|
202
|
+
return [r[0].message.content for r in res]
|
203
|
+
case (str(q), list(sm_seq)):
|
204
|
+
res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
|
205
|
+
return [r[0].message.content for r in res]
|
206
|
+
case (str(q), str(sm)):
|
207
|
+
return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
|
208
|
+
case _:
|
209
|
+
raise RuntimeError("Should not reach here.")
|
210
|
+
|
211
|
+
@overload
|
212
|
+
async def aask_validate[T](
|
213
|
+
self,
|
214
|
+
question: str,
|
215
|
+
validator: Callable[[str], T | None],
|
216
|
+
default: T,
|
217
|
+
max_validations: PositiveInt = 2,
|
218
|
+
**kwargs: Unpack[GenerateKwargs],
|
219
|
+
) -> T: ...
|
220
|
+
@overload
|
221
|
+
async def aask_validate[T](
|
222
|
+
self,
|
223
|
+
question: List[str],
|
224
|
+
validator: Callable[[str], T | None],
|
225
|
+
default: T,
|
226
|
+
max_validations: PositiveInt = 2,
|
227
|
+
**kwargs: Unpack[GenerateKwargs],
|
228
|
+
) -> List[T]: ...
|
229
|
+
@overload
|
230
|
+
async def aask_validate[T](
|
231
|
+
self,
|
232
|
+
question: str,
|
233
|
+
validator: Callable[[str], T | None],
|
234
|
+
default: None = None,
|
235
|
+
max_validations: PositiveInt = 2,
|
236
|
+
**kwargs: Unpack[GenerateKwargs],
|
237
|
+
) -> Optional[T]: ...
|
238
|
+
|
239
|
+
@overload
|
240
|
+
async def aask_validate[T](
|
241
|
+
self,
|
242
|
+
question: List[str],
|
243
|
+
validator: Callable[[str], T | None],
|
244
|
+
default: None = None,
|
245
|
+
max_validations: PositiveInt = 2,
|
246
|
+
**kwargs: Unpack[GenerateKwargs],
|
247
|
+
) -> List[Optional[T]]: ...
|
248
|
+
|
249
|
+
async def aask_validate[T](
|
250
|
+
self,
|
251
|
+
question: str | List[str],
|
252
|
+
validator: Callable[[str], T | None],
|
253
|
+
default: Optional[T] = None,
|
254
|
+
max_validations: PositiveInt = 2,
|
255
|
+
**kwargs: Unpack[GenerateKwargs],
|
256
|
+
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
257
|
+
"""Asynchronously asks a question and validates the response using a given validator.
|
258
|
+
|
259
|
+
Args:
|
260
|
+
question (str): The question to ask.
|
261
|
+
validator (Callable[[str], T | None]): A function to validate the response.
|
262
|
+
default (T | None): Default value to return if validation fails. Defaults to None.
|
263
|
+
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
|
264
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
265
|
+
|
266
|
+
Returns:
|
267
|
+
T: The validated response.
|
268
|
+
|
269
|
+
"""
|
270
|
+
|
271
|
+
async def _inner(q: str) -> Optional[T]:
|
272
|
+
for lap in range(max_validations):
|
273
|
+
try:
|
274
|
+
if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
|
275
|
+
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
276
|
+
return validated
|
277
|
+
except Exception as e: # noqa: BLE001
|
278
|
+
logger.error(f"Error during validation: \n{e}")
|
279
|
+
break
|
280
|
+
kwargs["no_cache"] = True
|
281
|
+
logger.debug("Closed the cache for the next attempt")
|
282
|
+
if default is None:
|
283
|
+
logger.error(f"Failed to validate the response after {max_validations} attempts.")
|
284
|
+
return default
|
285
|
+
|
286
|
+
if isinstance(question, str):
|
287
|
+
return await _inner(question)
|
288
|
+
|
289
|
+
return await gather(*[_inner(q) for q in question])
|
290
|
+
|
291
|
+
async def aliststr(
|
292
|
+
self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
|
293
|
+
) -> List[str]:
|
294
|
+
"""Asynchronously generates a list of strings based on a given requirement.
|
295
|
+
|
296
|
+
Args:
|
297
|
+
requirement (str): The requirement for the list of strings.
|
298
|
+
k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
|
299
|
+
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
300
|
+
|
301
|
+
Returns:
|
302
|
+
List[str]: The validated response as a list of strings.
|
303
|
+
"""
|
304
|
+
return await self.aask_validate(
|
305
|
+
TEMPLATE_MANAGER.render_template(
|
306
|
+
configs.templates.liststr_template,
|
307
|
+
{"requirement": requirement, "k": k},
|
308
|
+
),
|
309
|
+
lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
|
310
|
+
**kwargs,
|
311
|
+
)
|
312
|
+
|
313
|
+
async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> List[str]:
|
314
|
+
"""Asynchronously generates a list of strings based on a given requirement.
|
315
|
+
|
316
|
+
Args:
|
317
|
+
requirement (str): The requirement for the list of strings.
|
318
|
+
**kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the LLM usage.
|
319
|
+
|
320
|
+
Returns:
|
321
|
+
List[str]: The validated response as a list of strings.
|
322
|
+
"""
|
323
|
+
return await self.aliststr(
|
324
|
+
TEMPLATE_MANAGER.render_template(
|
325
|
+
configs.templates.pathstr_template,
|
326
|
+
{"requirement": requirement},
|
327
|
+
),
|
328
|
+
**kwargs,
|
329
|
+
)
|
330
|
+
|
331
|
+
async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[ValidateKwargs[List[str]]]) -> str:
|
332
|
+
"""Asynchronously generates a single path string based on a given requirement.
|
333
|
+
|
334
|
+
Args:
|
335
|
+
requirement (str): The requirement for the list of strings.
|
336
|
+
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
337
|
+
|
338
|
+
Returns:
|
339
|
+
str: The validated response as a single string.
|
340
|
+
"""
|
341
|
+
return (
|
342
|
+
await self.apathstr(
|
343
|
+
requirement,
|
344
|
+
k=1,
|
345
|
+
**kwargs,
|
346
|
+
)
|
347
|
+
).pop()
|
348
|
+
|
349
|
+
async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> str:
|
350
|
+
"""Asynchronously generates a generic string based on a given requirement.
|
351
|
+
|
352
|
+
Args:
|
353
|
+
requirement (str): The requirement for the string.
|
354
|
+
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
355
|
+
|
356
|
+
Returns:
|
357
|
+
str: The generated string.
|
358
|
+
"""
|
359
|
+
return await self.aask_validate(
|
360
|
+
TEMPLATE_MANAGER.render_template(
|
361
|
+
configs.templates.generic_string_template,
|
362
|
+
{"requirement": requirement, "language": GenericCapture.capture_type},
|
363
|
+
),
|
364
|
+
validator=lambda resp: GenericCapture.capture(resp),
|
365
|
+
**kwargs,
|
366
|
+
)
|
367
|
+
|
368
|
+
async def achoose[T: WithBriefing](
|
369
|
+
self,
|
370
|
+
instruction: str,
|
371
|
+
choices: List[T],
|
372
|
+
k: NonNegativeInt = 0,
|
373
|
+
**kwargs: Unpack[ValidateKwargs[List[T]]],
|
374
|
+
) -> List[T]:
|
375
|
+
"""Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
|
376
|
+
|
377
|
+
Args:
|
378
|
+
instruction (str): The user-provided instruction/question description.
|
379
|
+
choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
|
380
|
+
k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
|
381
|
+
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
382
|
+
|
383
|
+
Returns:
|
384
|
+
List[T]: The final validated selection result list, with element types matching the input `choices`.
|
385
|
+
|
386
|
+
Important:
|
387
|
+
- Uses a template engine to generate structured prompts.
|
388
|
+
- Ensures response compliance through JSON parsing and format validation.
|
389
|
+
- Relies on `aask_validate` to implement retry mechanisms with validation.
|
390
|
+
"""
|
391
|
+
if dup := duplicates_everseen(choices, key=lambda x: x.name):
|
392
|
+
logger.error(err := f"Redundant choices: {dup}")
|
393
|
+
raise ValueError(err)
|
394
|
+
prompt = TEMPLATE_MANAGER.render_template(
|
395
|
+
configs.templates.make_choice_template,
|
396
|
+
{
|
397
|
+
"instruction": instruction,
|
398
|
+
"options": [m.model_dump(include={"name", "briefing"}) for m in choices],
|
399
|
+
"k": k,
|
400
|
+
},
|
401
|
+
)
|
402
|
+
names = {c.name for c in choices}
|
403
|
+
|
404
|
+
logger.debug(f"Start choosing between {names} with prompt: \n{prompt}")
|
405
|
+
|
406
|
+
def _validate(response: str) -> List[T] | None:
|
407
|
+
ret = JsonCapture.validate_with(response, target_type=List, elements_type=str, length=k)
|
408
|
+
if ret is None or set(ret) - names:
|
409
|
+
return None
|
410
|
+
return [
|
411
|
+
next(candidate for candidate in choices if candidate.name == candidate_name) for candidate_name in ret
|
412
|
+
]
|
413
|
+
|
414
|
+
return await self.aask_validate(
|
415
|
+
question=prompt,
|
416
|
+
validator=_validate,
|
417
|
+
**kwargs,
|
418
|
+
)
|
419
|
+
|
420
|
+
async def apick[T: WithBriefing](
|
421
|
+
self,
|
422
|
+
instruction: str,
|
423
|
+
choices: List[T],
|
424
|
+
**kwargs: Unpack[ValidateKwargs[List[T]]],
|
425
|
+
) -> T:
|
426
|
+
"""Asynchronously picks a single choice from a list of options using AI validation.
|
427
|
+
|
428
|
+
Args:
|
429
|
+
instruction (str): The user-provided instruction/question description.
|
430
|
+
choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
|
431
|
+
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
432
|
+
|
433
|
+
Returns:
|
434
|
+
T: The single selected item from the choices list.
|
435
|
+
|
436
|
+
Raises:
|
437
|
+
ValueError: If validation fails after maximum attempts or if no valid selection is made.
|
438
|
+
"""
|
439
|
+
return (
|
440
|
+
await self.achoose(
|
441
|
+
instruction=instruction,
|
442
|
+
choices=choices,
|
443
|
+
k=1,
|
444
|
+
**kwargs,
|
445
|
+
)
|
446
|
+
)[0]
|
447
|
+
|
448
|
+
async def ajudge(
|
449
|
+
self,
|
450
|
+
prompt: str,
|
451
|
+
affirm_case: str = "",
|
452
|
+
deny_case: str = "",
|
453
|
+
**kwargs: Unpack[ValidateKwargs[bool]],
|
454
|
+
) -> bool:
|
455
|
+
"""Asynchronously judges a prompt using AI validation.
|
456
|
+
|
457
|
+
Args:
|
458
|
+
prompt (str): The input prompt to be judged.
|
459
|
+
affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
|
460
|
+
deny_case (str): The negative case for the AI model. Defaults to an empty string.
|
461
|
+
**kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
|
462
|
+
|
463
|
+
Returns:
|
464
|
+
bool: The judgment result (True or False) based on the AI's response.
|
465
|
+
"""
|
466
|
+
return await self.aask_validate(
|
467
|
+
question=TEMPLATE_MANAGER.render_template(
|
468
|
+
configs.templates.make_judgment_template,
|
469
|
+
{"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
|
470
|
+
),
|
471
|
+
validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
|
472
|
+
**kwargs,
|
473
|
+
)
|
474
|
+
|
475
|
+
|
476
|
+
class EmbeddingUsage(LLMUsage):
    """A class representing the embedding model."""

    async def aembedding(
        self,
        input_text: List[str],
        model: Optional[str] = None,
        dimensions: Optional[int] = None,
        timeout: Optional[PositiveInt] = None,
        caching: Optional[bool] = False,
    ) -> EmbeddingResponse:
        """Asynchronously generates embeddings for the given input text.

        Each setting falls back in order: the explicit argument, the instance's ``embedding_*``
        field, and the global ``configs.embedding`` value. For model, timeout, API key and
        endpoint, it then falls back further to the ``llm_*`` instance field and the
        ``configs.llm`` value.

        Args:
            input_text (List[str]): A list of strings to generate embeddings for.
            model (Optional[str]): The embedding model. Defaults to ``embedding_model``, then the global embedding model, then the LLM model.
            dimensions (Optional[int]): The requested embedding dimensionality, forwarded to litellm. Defaults to ``embedding_dimensions`` or the global configuration.
            timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the embedding timeout, then the LLM timeout.
            caching (Optional[bool]): Whether to cache the embedding result. A falsy value falls back to ``embedding_caching`` or the global configuration.

        Returns:
            EmbeddingResponse: The response containing the embeddings.

        Raises:
            ValueError: If any input string is longer (in characters) than the maximum sequence length.
        """
        # Check sequence length; the length is measured in characters, not tokens.
        max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
        # The guard skips the check when no limit is configured (assumes the setting may be unset).
        if max_len is not None and any(len(t) > max_len for t in input_text):
            logger.error(err := f"Input text exceeds maximum sequence length {max_len}.")
            raise ValueError(err)

        return await litellm.aembedding(
            input=input_text,
            caching=caching or self.embedding_caching or configs.embedding.caching,
            dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
            model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
            timeout=timeout
            or self.embedding_timeout
            or configs.embedding.timeout
            or self.llm_timeout
            or configs.llm.timeout,
            api_key=(
                self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
            ).get_secret_value(),
            # The embedding call seems not to accept a base URL ending with a slash.
            api_base=(
                self.embedding_api_endpoint
                or configs.embedding.api_endpoint
                or self.llm_api_endpoint
                or configs.llm.api_endpoint
            )
            .unicode_string()
            .rstrip("/"),
        )

    @overload
    async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
    @overload
    async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...

    async def vectorize(
        self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
    ) -> List[List[float]] | List[float]:
        """Asynchronously generates vector embeddings for the given input text.

        Args:
            input_text (List[str] | str): A string or list of strings to generate embeddings for.
            **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments forwarded to ``aembedding``.

        Returns:
            List[List[float]] | List[float]: One embedding for a single string, or one per string
            for a list (an empty list yields an empty result without calling the model).
        """
        if isinstance(input_text, str):
            return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")

        if not input_text:
            # Nothing to embed: skip the network round-trip entirely.
            return []

        return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
|
551
|
+
|
552
|
+
|
553
|
+
class ToolBoxUsage(LLMUsage):
|
554
|
+
"""A class representing the usage of tools in a task."""
|
555
|
+
|
556
|
+
toolboxes: Set[ToolBox] = Field(default_factory=set)
|
557
|
+
"""A set of toolboxes used by the instance."""
|
558
|
+
|
559
|
+
@property
|
560
|
+
def available_toolbox_names(self) -> List[str]:
|
561
|
+
"""Return a list of available toolbox names."""
|
562
|
+
return [toolbox.name for toolbox in self.toolboxes]
|
563
|
+
|
564
|
+
async def choose_toolboxes(
|
565
|
+
self,
|
566
|
+
task: Task,
|
567
|
+
**kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
|
568
|
+
) -> List[ToolBox]:
|
569
|
+
"""Asynchronously executes a multi-choice decision-making process to choose toolboxes.
|
570
|
+
|
571
|
+
Args:
|
572
|
+
task (Task): The task for which to choose toolboxes.
|
573
|
+
system_message (str): Custom system-level prompt, defaults to an empty string.
|
574
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
575
|
+
|
576
|
+
Returns:
|
577
|
+
List[ToolBox]: The selected toolboxes.
|
578
|
+
"""
|
579
|
+
if not self.toolboxes:
|
580
|
+
logger.warning("No toolboxes available.")
|
581
|
+
return []
|
582
|
+
return await self.achoose(
|
583
|
+
instruction=task.briefing,
|
584
|
+
choices=list(self.toolboxes),
|
585
|
+
**kwargs,
|
586
|
+
)
|
587
|
+
|
588
|
+
async def choose_tools(
|
589
|
+
self,
|
590
|
+
task: Task,
|
591
|
+
toolbox: ToolBox,
|
592
|
+
**kwargs: Unpack[ChooseKwargs[List[Tool]]],
|
593
|
+
) -> List[Tool]:
|
594
|
+
"""Asynchronously executes a multi-choice decision-making process to choose tools.
|
595
|
+
|
596
|
+
Args:
|
597
|
+
task (Task): The task for which to choose tools.
|
598
|
+
toolbox (ToolBox): The toolbox from which to choose tools.
|
599
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
|
600
|
+
|
601
|
+
Returns:
|
602
|
+
List[Tool]: The selected tools.
|
603
|
+
"""
|
604
|
+
if not toolbox.tools:
|
605
|
+
logger.warning(f"No tools available in toolbox {toolbox.name}.")
|
606
|
+
return []
|
607
|
+
return await self.achoose(
|
608
|
+
instruction=task.briefing,
|
609
|
+
choices=toolbox.tools,
|
610
|
+
**kwargs,
|
611
|
+
)
|
612
|
+
|
613
|
+
async def gather_tools_fine_grind(
|
614
|
+
self,
|
615
|
+
task: Task,
|
616
|
+
box_choose_kwargs: Optional[ChooseKwargs] = None,
|
617
|
+
tool_choose_kwargs: Optional[ChooseKwargs] = None,
|
618
|
+
) -> List[Tool]:
|
619
|
+
"""Asynchronously gathers tools based on the provided task and toolbox and tool selection criteria.
|
620
|
+
|
621
|
+
Args:
|
622
|
+
task (Task): The task for which to gather tools.
|
623
|
+
box_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing toolboxes.
|
624
|
+
tool_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing tools.
|
625
|
+
|
626
|
+
Returns:
|
627
|
+
List[Tool]: A list of tools gathered based on the provided task and toolbox and tool selection criteria.
|
628
|
+
"""
|
629
|
+
box_choose_kwargs = box_choose_kwargs or {}
|
630
|
+
tool_choose_kwargs = tool_choose_kwargs or {}
|
631
|
+
|
632
|
+
# Choose the toolboxes
|
633
|
+
chosen_toolboxes = await self.choose_toolboxes(task, **box_choose_kwargs)
|
634
|
+
# Choose the tools
|
635
|
+
chosen_tools = []
|
636
|
+
for toolbox in chosen_toolboxes:
|
637
|
+
chosen_tools.extend(await self.choose_tools(task, toolbox, **tool_choose_kwargs))
|
638
|
+
return chosen_tools
|
639
|
+
|
640
|
+
async def gather_tools(self, task: Task, **kwargs: Unpack[ChooseKwargs]) -> List[Tool]:
|
641
|
+
"""Asynchronously gathers tools based on the provided task.
|
642
|
+
|
643
|
+
Args:
|
644
|
+
task (Task): The task for which to gather tools.
|
645
|
+
**kwargs (Unpack[ChooseKwargs]): Keyword arguments for choosing tools.
|
646
|
+
|
647
|
+
Returns:
|
648
|
+
List[Tool]: A list of tools gathered based on the provided task.
|
649
|
+
"""
|
650
|
+
return await self.gather_tools_fine_grind(task, kwargs, kwargs)
|
651
|
+
|
652
|
+
def supply_tools_from[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
|
653
|
+
"""Supplies tools from other ToolUsage instances to this instance.
|
654
|
+
|
655
|
+
Args:
|
656
|
+
others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolUsage instance or an iterable of ToolUsage instances
|
657
|
+
from which to take tools.
|
658
|
+
|
659
|
+
Returns:
|
660
|
+
Self: The current ToolUsage instance with updated tools.
|
661
|
+
"""
|
662
|
+
if isinstance(others, ToolBoxUsage):
|
663
|
+
others = [others]
|
664
|
+
for other in others:
|
665
|
+
self.toolboxes.update(other.toolboxes)
|
666
|
+
return self
|
667
|
+
|
668
|
+
def provide_tools_to[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
|
669
|
+
"""Provides tools from this instance to other ToolUsage instances.
|
670
|
+
|
671
|
+
Args:
|
672
|
+
others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolUsage instance or an iterable of ToolUsage instances
|
673
|
+
to which to provide tools.
|
674
|
+
|
675
|
+
Returns:
|
676
|
+
Self: The current ToolUsage instance.
|
677
|
+
"""
|
678
|
+
if isinstance(others, ToolBoxUsage):
|
679
|
+
others = [others]
|
680
|
+
for other in others:
|
681
|
+
other.toolboxes.update(self.toolboxes)
|
682
|
+
return self
|