fabricatio-0.3.15.dev5-cp313-cp313-win_amd64.whl → fabricatio-0.4.0-cp313-cp313-win_amd64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (56)
  1. fabricatio/__init__.py +9 -8
  2. fabricatio/actions/rules.py +83 -83
  3. fabricatio/rust.cp313-win_amd64.pyd +0 -0
  4. fabricatio/workflows/rag.py +2 -1
  5. fabricatio-0.4.0.data/scripts/tdown.exe +0 -0
  6. {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.0.dist-info}/METADATA +17 -16
  7. fabricatio-0.4.0.dist-info/RECORD +18 -0
  8. fabricatio/actions/article.py +0 -415
  9. fabricatio/actions/article_rag.py +0 -407
  10. fabricatio/capabilities/__init__.py +0 -1
  11. fabricatio/capabilities/advanced_judge.py +0 -20
  12. fabricatio/capabilities/advanced_rag.py +0 -61
  13. fabricatio/capabilities/censor.py +0 -105
  14. fabricatio/capabilities/check.py +0 -212
  15. fabricatio/capabilities/correct.py +0 -228
  16. fabricatio/capabilities/extract.py +0 -74
  17. fabricatio/capabilities/propose.py +0 -65
  18. fabricatio/capabilities/rag.py +0 -264
  19. fabricatio/capabilities/rating.py +0 -404
  20. fabricatio/capabilities/review.py +0 -114
  21. fabricatio/capabilities/task.py +0 -113
  22. fabricatio/decorators.py +0 -253
  23. fabricatio/emitter.py +0 -177
  24. fabricatio/fs/__init__.py +0 -35
  25. fabricatio/fs/curd.py +0 -153
  26. fabricatio/fs/readers.py +0 -61
  27. fabricatio/journal.py +0 -12
  28. fabricatio/models/action.py +0 -263
  29. fabricatio/models/adv_kwargs_types.py +0 -63
  30. fabricatio/models/extra/__init__.py +0 -1
  31. fabricatio/models/extra/advanced_judge.py +0 -32
  32. fabricatio/models/extra/aricle_rag.py +0 -286
  33. fabricatio/models/extra/article_base.py +0 -488
  34. fabricatio/models/extra/article_essence.py +0 -98
  35. fabricatio/models/extra/article_main.py +0 -285
  36. fabricatio/models/extra/article_outline.py +0 -45
  37. fabricatio/models/extra/article_proposal.py +0 -52
  38. fabricatio/models/extra/patches.py +0 -20
  39. fabricatio/models/extra/problem.py +0 -165
  40. fabricatio/models/extra/rag.py +0 -98
  41. fabricatio/models/extra/rule.py +0 -51
  42. fabricatio/models/generic.py +0 -904
  43. fabricatio/models/kwargs_types.py +0 -121
  44. fabricatio/models/role.py +0 -156
  45. fabricatio/models/task.py +0 -310
  46. fabricatio/models/tool.py +0 -328
  47. fabricatio/models/usages.py +0 -791
  48. fabricatio/parser.py +0 -114
  49. fabricatio/rust.pyi +0 -846
  50. fabricatio/utils.py +0 -156
  51. fabricatio/workflows/articles.py +0 -24
  52. fabricatio-0.3.15.dev5.data/scripts/tdown.exe +0 -0
  53. fabricatio-0.3.15.dev5.data/scripts/ttm.exe +0 -0
  54. fabricatio-0.3.15.dev5.dist-info/RECORD +0 -63
  55. {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.0.dist-info}/WHEEL +0 -0
  56. {fabricatio-0.3.15.dev5.dist-info → fabricatio-0.4.0.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/usages.py (deleted)
@@ -1,791 +0,0 @@
- """This module contains classes that manage the usage of language models and tools in tasks."""
-
- import traceback
- from abc import ABC
- from asyncio import gather
- from typing import Callable, Dict, Iterable, List, Literal, Optional, Self, Sequence, Set, Union, Unpack, overload
-
- import asyncstdlib
- from fabricatio.decorators import logging_exec_time
- from fabricatio.journal import logger
- from fabricatio.models.generic import ScopedConfig, WithBriefing
- from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
- from fabricatio.models.task import Task
- from fabricatio.models.tool import Tool, ToolBox
- from fabricatio.rust import CONFIG, TEMPLATE_MANAGER
- from fabricatio.utils import first_available, ok
- from litellm import (  # pyright: ignore [reportPrivateImportUsage]
-     RateLimitError,
-     Router,
-     aembedding,
-     stream_chunk_builder,
- )
- from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
- from litellm.types.utils import (
-     Choices,
-     EmbeddingResponse,
-     ModelResponse,
-     StreamingChoices,
-     TextChoices,
- )
- from litellm.utils import CustomStreamWrapper, token_counter  # pyright: ignore [reportPrivateImportUsage]
- from more_itertools import duplicates_everseen
- from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
-
- ROUTER = Router(
-     routing_strategy="usage-based-routing-v2",
-     default_max_parallel_requests=CONFIG.routing.max_parallel_requests,
-     allowed_fails=CONFIG.routing.allowed_fails,
-     retry_after=CONFIG.routing.retry_after,
-     cooldown_time=CONFIG.routing.cooldown_time,
- )
-
-
- class LLMUsage(ScopedConfig, ABC):
-     """Class that manages LLM (Large Language Model) usage parameters and methods.
-
-     This class provides methods to deploy LLMs, query them for responses, and handle various configurations
-     related to LLM usage such as API keys, endpoints, and rate limits.
-     """
-
-     def _deploy(self, deployment: Deployment) -> Router:
-         """Add a deployment to the router.
-
-         Args:
-             deployment (Deployment): The deployment to be added to the router.
-
-         Returns:
-             Router: The updated router with the added deployment.
-         """
-         self._added_deployment = ROUTER.upsert_deployment(deployment)
-         return ROUTER
-
-     # noinspection PyTypeChecker,PydanticTypeChecker
-     async def aquery(
-         self,
-         messages: List[Dict[str, str]],
-         n: PositiveInt | None = None,
-         **kwargs: Unpack[LLMKwargs],
-     ) -> ModelResponse | CustomStreamWrapper:
-         """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
-
-         Args:
-             messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
-             n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
-             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             ModelResponse | CustomStreamWrapper: An object containing the generated response and other metadata from the model.
-         """
-         # Call the underlying asynchronous completion function with the provided and default parameters
-         # noinspection PyTypeChecker,PydanticTypeChecker
-         return await self._deploy(
-             Deployment(
-                 model_name=(
-                     m_name := ok(
-                         kwargs.get("model") or self.llm_model or CONFIG.llm.model, "model name is not set at any place"
-                     )
-                 ),  # pyright: ignore [reportCallIssue]
-                 litellm_params=(
-                     p := LiteLLM_Params(
-                         api_key=ok(
-                             self.llm_api_key or CONFIG.llm.api_key, "llm api key is not set at any place"
-                         ).get_secret_value(),
-                         api_base=ok(
-                             self.llm_api_endpoint or CONFIG.llm.api_endpoint,
-                             "llm api endpoint is not set at any place",
-                         ),
-                         model=m_name,
-                         tpm=self.llm_tpm or CONFIG.llm.tpm,
-                         rpm=self.llm_rpm or CONFIG.llm.rpm,
-                         max_retries=kwargs.get("max_retries") or self.llm_max_retries or CONFIG.llm.max_retries,
-                         timeout=kwargs.get("timeout") or self.llm_timeout or CONFIG.llm.timeout,
-                     )
-                 ),
-                 model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
-             )
-         ).acompletion(
-             messages=messages,  # pyright: ignore [reportArgumentType]
-             n=n or self.llm_generation_count or CONFIG.llm.generation_count,
-             model=m_name,
-             temperature=kwargs.get("temperature") or self.llm_temperature or CONFIG.llm.temperature,
-             stop=kwargs.get("stop") or self.llm_stop_sign or CONFIG.llm.stop_sign,
-             top_p=kwargs.get("top_p") or self.llm_top_p or CONFIG.llm.top_p,
-             max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or CONFIG.llm.max_tokens,
-             stream=first_available(
-                 (kwargs.get("stream"), self.llm_stream, CONFIG.llm.stream), "stream is not set at any place"
-             ),
-             cache={
-                 "no-cache": kwargs.get("no_cache"),
-                 "no-store": kwargs.get("no_store"),
-                 "cache-ttl": kwargs.get("cache_ttl"),
-                 "s-maxage": kwargs.get("s_maxage"),
-             },
-             presence_penalty=kwargs.get("presence_penalty") or self.llm_presence_penalty or CONFIG.llm.presence_penalty,
-             frequency_penalty=kwargs.get("frequency_penalty")
-             or self.llm_frequency_penalty
-             or CONFIG.llm.frequency_penalty,
-         )
-
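Every sampling parameter above resolves through the same three-tier chain: an explicit keyword argument wins, then the instance-scoped setting, then the global CONFIG default. A minimal standalone sketch of that pattern (the names here are illustrative, not part of the module):

# Hypothetical illustration of the kwargs -> instance -> global fallback used in aquery.
def resolve(kwarg_value, instance_value, global_value):
    # `or` picks the first truthy value, so an explicit 0 or "" falls through
    # to the next tier -- the same caveat applies to the chains in aquery.
    return kwarg_value or instance_value or global_value

assert resolve(None, 0.7, 0.9) == 0.7  # instance setting beats the global default
assert resolve(0.2, 0.7, 0.9) == 0.2   # an explicit kwarg beats both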
-     async def ainvoke(
-         self,
-         question: str,
-         system_message: str = "",
-         n: PositiveInt | None = None,
-         **kwargs: Unpack[LLMKwargs],
-     ) -> Sequence[TextChoices | Choices | StreamingChoices]:
-         """Asynchronously invokes the language model with a question and optional system message.
-
-         Args:
-             question (str): The question to ask the model.
-             system_message (str): The system message to provide context to the model. Defaults to an empty string.
-             n (PositiveInt | None): The number of responses to generate. Defaults to the instance's `llm_generation_count` or the global configuration.
-             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Sequence[TextChoices | Choices | StreamingChoices]: A sequence of choices or streaming choices from the model response.
-         """
-         resp = await self.aquery(
-             messages=Messages().add_system_message(system_message).add_user_message(question).as_list(),
-             n=n,
-             **kwargs,
-         )
-         if isinstance(resp, ModelResponse):
-             return resp.choices
-         if isinstance(resp, CustomStreamWrapper) and (pack := stream_chunk_builder(await asyncstdlib.list(resp))):
-             return pack.choices
-         logger.critical(err := f"Unexpected response type: {type(resp)}")
-         raise ValueError(err)
-
-     @overload
-     async def aask(
-         self,
-         question: List[str],
-         system_message: List[str],
-         **kwargs: Unpack[LLMKwargs],
-     ) -> List[str]: ...
-
-     @overload
-     async def aask(
-         self,
-         question: str,
-         system_message: List[str],
-         **kwargs: Unpack[LLMKwargs],
-     ) -> List[str]: ...
-
-     @overload
-     async def aask(
-         self,
-         question: List[str],
-         system_message: Optional[str] = None,
-         **kwargs: Unpack[LLMKwargs],
-     ) -> List[str]: ...
-
-     @overload
-     async def aask(
-         self,
-         question: str,
-         system_message: Optional[str] = None,
-         **kwargs: Unpack[LLMKwargs],
-     ) -> str: ...
-
-     @logging_exec_time
-     async def aask(
-         self,
-         question: str | List[str],
-         system_message: Optional[str | List[str]] = None,
-         **kwargs: Unpack[LLMKwargs],
-     ) -> str | List[str]:
-         """Asynchronously asks the language model a question and returns the response content.
-
-         Args:
-             question (str | List[str]): The question to ask the model.
-             system_message (str | List[str] | None): The system message to provide context to the model. Defaults to None.
-             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             str | List[str]: The content of the model's response message.
-         """
-         match (question, system_message or ""):
-             case (list(q_seq), list(sm_seq)):
-                 res = await gather(
-                     *[
-                         self.ainvoke(n=1, question=q, system_message=sm, **kwargs)
-                         for q, sm in zip(q_seq, sm_seq, strict=True)
-                     ]
-                 )
-                 out = [r[0].message.content for r in res]  # pyright: ignore [reportAttributeAccessIssue]
-             case (list(q_seq), str(sm)):
-                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
-                 out = [r[0].message.content for r in res]  # pyright: ignore [reportAttributeAccessIssue]
-             case (str(q), list(sm_seq)):
-                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
-                 out = [r[0].message.content for r in res]  # pyright: ignore [reportAttributeAccessIssue]
-             case (str(q), str(sm)):
-                 out = ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content  # pyright: ignore [reportAttributeAccessIssue]
-             case _:
-                 raise RuntimeError("Should not reach here.")
-
-         logger.debug(
-             f"Response Token Count: {token_counter(text=out) if isinstance(out, str) else sum(token_counter(text=o) for o in out)}"  # pyright: ignore [reportOptionalIterable]
-         )
-         return out  # pyright: ignore [reportReturnType]
-
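aask broadcasts over its arguments: string-with-string returns a single string, while a list on either side fans out into parallel ainvoke calls and returns a list. A usage sketch, assuming `agent` is a configured instance of some LLMUsage subclass (hypothetical here):

# Hypothetical caller; run inside an async context.
async def demo(agent) -> None:
    one = await agent.aask("Summarize HTTP/1.1 in one sentence.")  # -> str
    many = await agent.aask(["What is TCP?", "What is UDP?"])  # -> List[str]
    paired = await agent.aask(
        ["Bonjour", "Hallo"],
        system_message=["Translate to English.", "Translate to English."],
    )  # question/system_message lists are zipped strictly, so lengths must match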
-     @overload
-     async def aask_validate[T](
-         self,
-         question: str,
-         validator: Callable[[str], T | None],
-         default: T = ...,
-         max_validations: PositiveInt = 3,
-         **kwargs: Unpack[GenerateKwargs],
-     ) -> T: ...
-
-     @overload
-     async def aask_validate[T](
-         self,
-         question: List[str],
-         validator: Callable[[str], T | None],
-         default: T = ...,
-         max_validations: PositiveInt = 3,
-         **kwargs: Unpack[GenerateKwargs],
-     ) -> List[T]: ...
-
-     @overload
-     async def aask_validate[T](
-         self,
-         question: str,
-         validator: Callable[[str], T | None],
-         default: None = None,
-         max_validations: PositiveInt = 3,
-         **kwargs: Unpack[GenerateKwargs],
-     ) -> Optional[T]: ...
-
-     @overload
-     async def aask_validate[T](
-         self,
-         question: List[str],
-         validator: Callable[[str], T | None],
-         default: None = None,
-         max_validations: PositiveInt = 3,
-         **kwargs: Unpack[GenerateKwargs],
-     ) -> List[Optional[T]]: ...
-
-     async def aask_validate[T](
-         self,
-         question: str | List[str],
-         validator: Callable[[str], T | None],
-         default: Optional[T] = None,
-         max_validations: PositiveInt = 3,
-         **kwargs: Unpack[GenerateKwargs],
-     ) -> Optional[T] | List[Optional[T]] | List[T] | T:
-         """Asynchronously asks a question and validates the response using a given validator.
-
-         Args:
-             question (str | List[str]): The question to ask.
-             validator (Callable[[str], T | None]): A function to validate the response.
-             default (T | None): Default value to return if validation fails. Defaults to None.
-             max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
-             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[T] | List[Optional[T]] | List[T] | T: The validated response.
-         """
-
-         async def _inner(q: str) -> Optional[T]:
-             for lap in range(max_validations):
-                 try:
-                     if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
-                         logger.debug(f"Successfully validated the response on attempt {lap}.")
-                         return validated
-
-                 except RateLimitError as e:
-                     logger.warning(f"Rate limit error:\n{e}")
-                     continue
-                 except Exception as e:  # noqa: BLE001
-                     logger.error(f"Error during validation:\n{e}")
-                     logger.debug(traceback.format_exc())
-                     break
-                 logger.error(f"Failed to validate the response on attempt {lap}:\n{response}")
-                 if not kwargs.get("no_cache"):
-                     kwargs["no_cache"] = True
-                     logger.debug("Disabled the cache for the next attempt")
-             if default is None:
-                 logger.error(f"Failed to validate the response after {max_validations} attempts.")
-             return default
-
-         return await (gather(*[_inner(q) for q in question]) if isinstance(question, list) else _inner(question))
-
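A validator returns the parsed value on success or None to request another lap; raising anything other than RateLimitError aborts the remaining laps. A minimal hypothetical validator to make that contract concrete:

# Hypothetical validator: accept only responses that parse as an integer.
def parse_int(resp: str) -> int | None:
    try:
        return int(resp.strip())
    except ValueError:
        return None  # None tells aask_validate to retry (with caching disabled)

async def demo(agent) -> None:
    n = await agent.aask_validate(
        "How many bits are in a byte? Reply with the number only.",
        validator=parse_int,
        default=8,  # returned if all validation laps fail
    )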
-     async def alist_str(
-         self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
-     ) -> Optional[List[str]]:
-         """Asynchronously generates a list of strings based on a given requirement.
-
-         Args:
-             requirement (str): The requirement for the list of strings.
-             k (NonNegativeInt): The number of items to generate, 0 means unlimited. Defaults to 0.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[List[str]]: The validated response as a list of strings.
-         """
-         from fabricatio.parser import JsonCapture
-
-         return await self.aask_validate(
-             TEMPLATE_MANAGER.render_template(
-                 CONFIG.templates.liststr_template,
-                 {"requirement": requirement, "k": k},
-             ),
-             lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
-             **kwargs,
-         )
-
-     async def apathstr(self, requirement: str, **kwargs: Unpack[ChooseKwargs[List[str]]]) -> Optional[List[str]]:
-         """Asynchronously generates a list of path strings based on a given requirement.
-
-         Args:
-             requirement (str): The requirement for the list of path strings.
-             **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[List[str]]: The validated response as a list of path strings.
-         """
-         return await self.alist_str(
-             TEMPLATE_MANAGER.render_template(
-                 CONFIG.templates.pathstr_template,
-                 {"requirement": requirement},
-             ),
-             **kwargs,
-         )
-
-     async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[ValidateKwargs[List[str]]]) -> Optional[str]:
-         """Asynchronously generates a single path string based on a given requirement.
-
-         Args:
-             requirement (str): The requirement for the path string.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[str]: The validated response as a single path string.
-         """
-         if paths := await self.apathstr(
-             requirement,
-             k=1,
-             **kwargs,
-         ):
-             return paths.pop()
-
-         return None
-
-     async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> Optional[str]:
-         """Asynchronously generates a generic string based on a given requirement.
-
-         Args:
-             requirement (str): The requirement for the string.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[str]: The generated string.
-         """
-         from fabricatio.parser import GenericCapture
-
-         return await self.aask_validate(  # pyright: ignore [reportReturnType]
-             TEMPLATE_MANAGER.render_template(
-                 CONFIG.templates.generic_string_template,
-                 {"requirement": requirement, "language": GenericCapture.capture_type},
-             ),
-             validator=lambda resp: GenericCapture.capture(resp),
-             **kwargs,
-         )
-
-     async def achoose[T: WithBriefing](
-         self,
-         instruction: str,
-         choices: List[T],
-         k: NonNegativeInt = 0,
-         **kwargs: Unpack[ValidateKwargs[List[T]]],
-     ) -> Optional[List[T]]:
-         """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
-
-         Args:
-             instruction (str): The user-provided instruction/question description.
-             choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
-             k (NonNegativeInt): The number of choices to select, 0 means unlimited. Defaults to 0.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[List[T]]: The final validated selection result list, with element types matching the input `choices`.
-         """
-         from fabricatio.parser import JsonCapture
-
-         if dup := list(duplicates_everseen(choices, key=lambda x: x.name)):
-             logger.error(err := f"Redundant choices: {dup}")
-             raise ValueError(err)
-         prompt = TEMPLATE_MANAGER.render_template(
-             CONFIG.templates.make_choice_template,
-             {
-                 "instruction": instruction,
-                 "options": [m.model_dump(include={"name", "briefing"}) for m in choices],
-                 "k": k,
-             },
-         )
-         names = {c.name for c in choices}
-
-         logger.debug(f"Start choosing between {names} with prompt:\n{prompt}")
-
-         def _validate(response: str) -> List[T] | None:
-             ret = JsonCapture.validate_with(response, target_type=List, elements_type=str, length=k)
-             if ret is None or set(ret) - names:
-                 return None
-             return [
-                 next(candidate for candidate in choices if candidate.name == candidate_name) for candidate_name in ret
-             ]
-
-         return await self.aask_validate(
-             question=prompt,
-             validator=_validate,
-             **kwargs,
-         )
-
-     async def apick[T: WithBriefing](
-         self,
-         instruction: str,
-         choices: List[T],
-         **kwargs: Unpack[ValidateKwargs[List[T]]],
-     ) -> T:
-         """Asynchronously picks a single choice from a list of options using AI validation.
-
-         Args:
-             instruction (str): The user-provided instruction/question description.
-             choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             T: The single selected item from the choices list.
-
-         Raises:
-             ValueError: If validation fails after maximum attempts or if no valid selection is made.
-         """
-         return ok(
-             await self.achoose(
-                 instruction=instruction,
-                 choices=choices,
-                 k=1,
-                 **kwargs,
-             ),
-         )[0]
-
-     async def ajudge(
-         self,
-         prompt: str,
-         affirm_case: str = "",
-         deny_case: str = "",
-         **kwargs: Unpack[ValidateKwargs[bool]],
-     ) -> Optional[bool]:
-         """Asynchronously judges a prompt using AI validation.
-
-         Args:
-             prompt (str): The input prompt to be judged.
-             affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
-             deny_case (str): The negative case for the AI model. Defaults to an empty string.
-             **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[bool]: The judgment result (True or False) based on the AI's response, or None if validation fails.
-         """
-         from fabricatio.parser import JsonCapture
-
-         return await self.aask_validate(
-             question=TEMPLATE_MANAGER.render_template(
-                 CONFIG.templates.make_judgment_template,
-                 {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
-             ),
-             validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
-             **kwargs,
-         )
-
-
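The helpers above all bottom out in aask_validate with a rendered template plus a JSON-capturing validator. A hedged usage sketch (`agent` and `tools` are hypothetical; any objects carrying `name` and `briefing` fields satisfy achoose/apick):

async def demo(agent, tools) -> None:
    best = await agent.apick("Pick the tool best suited to parse HTML.", choices=tools)
    verdict = await agent.ajudge(
        "SELECT name FROM users;",
        affirm_case="the statement only reads data",
        deny_case="the statement modifies data",
    )  # -> True, False, or None if validation never succeeds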
- class EmbeddingUsage(LLMUsage, ABC):
-     """A class representing the embedding model.
-
-     This class extends LLMUsage and provides methods to generate embeddings for input text using various models.
-     """
-
-     async def aembedding(
-         self,
-         input_text: List[str],
-         model: Optional[str] = None,
-         dimensions: Optional[int] = None,
-         timeout: Optional[PositiveInt] = None,
-         caching: Optional[bool] = False,
-     ) -> EmbeddingResponse:
-         """Asynchronously generates embeddings for the given input text.
-
-         Args:
-             input_text (List[str]): A list of strings to generate embeddings for.
-             model (Optional[str]): The model to use for embedding. Defaults to the instance's `embedding_model` or the global configuration, falling back to the LLM model.
-             dimensions (Optional[int]): The dimensions the embedding output should have, used to validate the result. Defaults to None.
-             timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `embedding_timeout` or the global configuration, falling back to the LLM timeout.
-             caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
-
-         Returns:
-             EmbeddingResponse: The response containing the embeddings.
-         """
-         # Check that no input exceeds the maximum sequence length.
-         max_len = self.embedding_max_sequence_length or CONFIG.embedding.max_sequence_length
-         if max_len and any((length := token_counter(text=t)) > max_len for t in input_text):
-             logger.error(err := f"Input text exceeds maximum sequence length {max_len}, got {length}.")
-             raise ValueError(err)
-
-         return await aembedding(
-             input=input_text,
-             caching=caching or self.embedding_caching or CONFIG.embedding.caching,
-             dimensions=dimensions or self.embedding_dimensions or CONFIG.embedding.dimensions,
-             model=model or self.embedding_model or CONFIG.embedding.model or self.llm_model or CONFIG.llm.model,
-             timeout=timeout
-             or self.embedding_timeout
-             or CONFIG.embedding.timeout
-             or self.llm_timeout
-             or CONFIG.llm.timeout,
-             api_key=ok(
-                 self.embedding_api_key or CONFIG.embedding.api_key or self.llm_api_key or CONFIG.llm.api_key
-             ).get_secret_value(),
-             api_base=ok(
-                 self.embedding_api_endpoint
-                 or CONFIG.embedding.api_endpoint
-                 or self.llm_api_endpoint
-                 or CONFIG.llm.api_endpoint
-             ).rstrip("/"),
-             # the embedding endpoint rejects a base_url that ends with a slash
-         )
-
-     @overload
-     async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
-
-     @overload
-     async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
-
-     async def vectorize(
-         self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
-     ) -> List[List[float]] | List[float]:
-         """Asynchronously generates vector embeddings for the given input text.
-
-         Args:
-             input_text (List[str] | str): A string or list of strings to generate embeddings for.
-             **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
-
-         Returns:
-             List[List[float]] | List[float]: The generated embeddings.
-         """
-         if isinstance(input_text, str):
-             return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
-
-         return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
-
-
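vectorize mirrors the input shape, as the overloads declare: a bare string yields one vector and a list yields one vector per element. Sketch (hypothetical `agent` with an embedding endpoint configured):

async def demo(agent) -> None:
    vec = await agent.vectorize("hello world")  # -> List[float]
    vecs = await agent.vectorize(["hello", "world"])  # -> List[List[float]]
    assert len(vecs) == 2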
- class ToolBoxUsage(LLMUsage, ABC):
-     """A class representing the usage of tools in a task.
-
-     This class extends LLMUsage and provides methods to manage and use toolboxes and tools within tasks.
-     """
-
-     toolboxes: Set[ToolBox] = Field(default_factory=set)
-     """A set of toolboxes used by the instance."""
-
-     @property
-     def available_toolbox_names(self) -> List[str]:
-         """Return a list of available toolbox names.
-
-         Returns:
-             List[str]: A list of names of the available toolboxes.
-         """
-         return [toolbox.name for toolbox in self.toolboxes]
-
-     async def choose_toolboxes(
-         self,
-         task: Task,
-         **kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
-     ) -> Optional[List[ToolBox]]:
-         """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
-
-         Args:
-             task (Task): The task for which to choose toolboxes.
-             **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[List[ToolBox]]: The selected toolboxes.
-         """
-         if not self.toolboxes:
-             logger.warning("No toolboxes available.")
-             return []
-         return await self.achoose(
-             instruction=task.briefing,
-             choices=list(self.toolboxes),
-             **kwargs,
-         )
-
-     async def choose_tools(
-         self,
-         task: Task,
-         toolbox: ToolBox,
-         **kwargs: Unpack[ChooseKwargs[List[Tool]]],
-     ) -> Optional[List[Tool]]:
-         """Asynchronously executes a multi-choice decision-making process to choose tools.
-
-         Args:
-             task (Task): The task for which to choose tools.
-             toolbox (ToolBox): The toolbox from which to choose tools.
-             **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the LLM usage.
-
-         Returns:
-             Optional[List[Tool]]: The selected tools.
-         """
-         if not toolbox.tools:
-             logger.warning(f"No tools available in toolbox {toolbox.name}.")
-             return []
-         return await self.achoose(
-             instruction=task.briefing,
-             choices=toolbox.tools,
-             **kwargs,
-         )
-
-     async def gather_tools_fine_grind(
-         self,
-         task: Task,
-         box_choose_kwargs: Optional[ChooseKwargs] = None,
-         tool_choose_kwargs: Optional[ChooseKwargs] = None,
-     ) -> List[Tool]:
-         """Asynchronously gathers tools based on the provided task and the toolbox and tool selection criteria.
-
-         Args:
-             task (Task): The task for which to gather tools.
-             box_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing toolboxes.
-             tool_choose_kwargs (Optional[ChooseKwargs]): Keyword arguments for choosing tools.
-
-         Returns:
-             List[Tool]: A list of tools gathered based on the provided task and selection criteria.
-         """
-         box_choose_kwargs = box_choose_kwargs or {}
-         tool_choose_kwargs = tool_choose_kwargs or {}
-
-         # Choose the toolboxes
-         chosen_toolboxes = ok(await self.choose_toolboxes(task, **box_choose_kwargs))
-         # Choose the tools
-         chosen_tools = []
-         for toolbox in chosen_toolboxes:
-             chosen_tools.extend(ok(await self.choose_tools(task, toolbox, **tool_choose_kwargs)))
-         return chosen_tools
-
-     async def gather_tools(self, task: Task, **kwargs: Unpack[ChooseKwargs]) -> List[Tool]:
-         """Asynchronously gathers tools based on the provided task.
-
-         Args:
-             task (Task): The task for which to gather tools.
-             **kwargs (Unpack[ChooseKwargs]): Keyword arguments for choosing tools.
-
-         Returns:
-             List[Tool]: A list of tools gathered based on the provided task.
-         """
-         return await self.gather_tools_fine_grind(task, kwargs, kwargs)
-
-     def supply_tools_from[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
-         """Supplies tools from other ToolBoxUsage instances to this instance.
-
-         Args:
-             others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolBoxUsage instance or an iterable of
-                 ToolBoxUsage instances from which to take tools.
-
-         Returns:
-             Self: The current ToolBoxUsage instance with updated toolboxes.
-         """
-         if isinstance(others, ToolBoxUsage):
-             others = [others]
-         for other in (x for x in others if isinstance(x, ToolBoxUsage)):
-             self.toolboxes.update(other.toolboxes)
-         return self
-
-     def provide_tools_to[S: "ToolBoxUsage"](self, others: Union[S, Iterable[S]]) -> Self:
-         """Provides tools from this instance to other ToolBoxUsage instances.
-
-         Args:
-             others (ToolBoxUsage | Iterable[ToolBoxUsage]): A single ToolBoxUsage instance or an iterable of
-                 ToolBoxUsage instances to which to provide tools.
-
-         Returns:
-             Self: The current ToolBoxUsage instance.
-         """
-         if isinstance(others, ToolBoxUsage):
-             others = [others]
-         for other in (x for x in others if isinstance(x, ToolBoxUsage)):
-             other.toolboxes.update(self.toolboxes)
-         return self
-
-
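supply_tools_from and provide_tools_to are symmetric set unions over `toolboxes`, and both return self so calls chain. Sketch with hypothetical ToolBoxUsage instances:

# planner, executor and auditor are hypothetical ToolBoxUsage subclass instances.
planner.supply_tools_from(executor)  # copy executor's toolboxes into planner
planner.provide_tools_to([executor, auditor])  # push planner's set back out; chainable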
- class Message(BaseModel):
-     """A class representing a message."""
-
-     model_config = ConfigDict(use_attribute_docstrings=True)
-     role: Literal["user", "system", "assistant"]
-     """The role of the message sender."""
-     content: str
-     """The content of the message."""
-
-
- class Messages(list):
-     """A list of messages."""
-
-     def add_message(self, role: Literal["user", "system", "assistant"], content: str) -> Self:
-         """Adds a message to the list with the specified role and content.
-
-         Args:
-             role (Literal["user", "system", "assistant"]): The role of the message sender.
-             content (str): The content of the message.
-
-         Returns:
-             Self: The current instance of Messages to allow method chaining.
-         """
-         if content:
-             self.append(Message(role=role, content=content))
-         return self
-
-     def add_user_message(self, content: str) -> Self:
-         """Adds a user message to the list with the specified content.
-
-         Args:
-             content (str): The content of the user message.
-
-         Returns:
-             Self: The current instance of Messages to allow method chaining.
-         """
-         return self.add_message("user", content)
-
-     def add_system_message(self, content: str) -> Self:
-         """Adds a system message to the list with the specified content.
-
-         Args:
-             content (str): The content of the system message.
-
-         Returns:
-             Self: The current instance of Messages to allow method chaining.
-         """
-         return self.add_message("system", content)
-
-     def add_assistant_message(self, content: str) -> Self:
-         """Adds an assistant message to the list with the specified content.
-
-         Args:
-             content (str): The content of the assistant message.
-
-         Returns:
-             Self: The current instance of Messages to allow method chaining.
-         """
-         return self.add_message("assistant", content)
-
-     def as_list(self) -> List[Dict[str, str]]:
-         """Converts the messages to a list of dictionaries.
-
-         Returns:
-             List[Dict[str, str]]: A list of dictionaries representing the messages.
-         """
-         return [message.model_dump() for message in self]
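Messages is a plain list subclass with a fluent builder; add_message silently skips empty content, which is what lets ainvoke pass its default empty system message through unconditionally. A quick sketch of the round trip:

msgs = Messages().add_system_message("You are terse.").add_user_message("Hi").as_list()
# -> [{"role": "system", "content": "You are terse."},
#     {"role": "user", "content": "Hi"}]
# An empty system message is simply dropped:
assert Messages().add_system_message("").add_user_message("Hi").as_list() == [
    {"role": "user", "content": "Hi"}
]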