langchain-core 0.3.79__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langchain-core might be problematic.
- langchain_core/__init__.py +1 -1
- langchain_core/_api/__init__.py +3 -4
- langchain_core/_api/beta_decorator.py +23 -26
- langchain_core/_api/deprecation.py +52 -65
- langchain_core/_api/path.py +3 -6
- langchain_core/_import_utils.py +3 -4
- langchain_core/agents.py +19 -19
- langchain_core/caches.py +53 -63
- langchain_core/callbacks/__init__.py +1 -8
- langchain_core/callbacks/base.py +323 -334
- langchain_core/callbacks/file.py +44 -44
- langchain_core/callbacks/manager.py +441 -507
- langchain_core/callbacks/stdout.py +29 -30
- langchain_core/callbacks/streaming_stdout.py +32 -32
- langchain_core/callbacks/usage.py +60 -57
- langchain_core/chat_history.py +48 -63
- langchain_core/document_loaders/base.py +23 -23
- langchain_core/document_loaders/langsmith.py +37 -37
- langchain_core/documents/__init__.py +0 -1
- langchain_core/documents/base.py +62 -65
- langchain_core/documents/compressor.py +4 -4
- langchain_core/documents/transformers.py +28 -29
- langchain_core/embeddings/fake.py +50 -54
- langchain_core/example_selectors/length_based.py +1 -1
- langchain_core/example_selectors/semantic_similarity.py +21 -25
- langchain_core/exceptions.py +10 -11
- langchain_core/globals.py +3 -151
- langchain_core/indexing/api.py +61 -66
- langchain_core/indexing/base.py +58 -58
- langchain_core/indexing/in_memory.py +3 -3
- langchain_core/language_models/__init__.py +14 -27
- langchain_core/language_models/_utils.py +270 -84
- langchain_core/language_models/base.py +55 -162
- langchain_core/language_models/chat_models.py +442 -402
- langchain_core/language_models/fake.py +11 -11
- langchain_core/language_models/fake_chat_models.py +61 -39
- langchain_core/language_models/llms.py +123 -231
- langchain_core/load/dump.py +4 -5
- langchain_core/load/load.py +18 -28
- langchain_core/load/mapping.py +2 -4
- langchain_core/load/serializable.py +39 -40
- langchain_core/messages/__init__.py +61 -22
- langchain_core/messages/ai.py +368 -163
- langchain_core/messages/base.py +214 -43
- langchain_core/messages/block_translators/__init__.py +111 -0
- langchain_core/messages/block_translators/anthropic.py +470 -0
- langchain_core/messages/block_translators/bedrock.py +94 -0
- langchain_core/messages/block_translators/bedrock_converse.py +297 -0
- langchain_core/messages/block_translators/google_genai.py +530 -0
- langchain_core/messages/block_translators/google_vertexai.py +21 -0
- langchain_core/messages/block_translators/groq.py +143 -0
- langchain_core/messages/block_translators/langchain_v0.py +301 -0
- langchain_core/messages/block_translators/openai.py +1010 -0
- langchain_core/messages/chat.py +2 -6
- langchain_core/messages/content.py +1423 -0
- langchain_core/messages/function.py +6 -10
- langchain_core/messages/human.py +41 -38
- langchain_core/messages/modifier.py +2 -2
- langchain_core/messages/system.py +38 -28
- langchain_core/messages/tool.py +96 -103
- langchain_core/messages/utils.py +478 -504
- langchain_core/output_parsers/__init__.py +1 -14
- langchain_core/output_parsers/base.py +58 -61
- langchain_core/output_parsers/json.py +7 -8
- langchain_core/output_parsers/list.py +5 -7
- langchain_core/output_parsers/openai_functions.py +49 -47
- langchain_core/output_parsers/openai_tools.py +14 -19
- langchain_core/output_parsers/pydantic.py +12 -13
- langchain_core/output_parsers/string.py +2 -2
- langchain_core/output_parsers/transform.py +15 -17
- langchain_core/output_parsers/xml.py +8 -10
- langchain_core/outputs/__init__.py +1 -1
- langchain_core/outputs/chat_generation.py +18 -18
- langchain_core/outputs/chat_result.py +1 -3
- langchain_core/outputs/generation.py +8 -8
- langchain_core/outputs/llm_result.py +10 -10
- langchain_core/prompt_values.py +12 -12
- langchain_core/prompts/__init__.py +3 -27
- langchain_core/prompts/base.py +45 -55
- langchain_core/prompts/chat.py +254 -313
- langchain_core/prompts/dict.py +5 -5
- langchain_core/prompts/few_shot.py +81 -88
- langchain_core/prompts/few_shot_with_templates.py +11 -13
- langchain_core/prompts/image.py +12 -14
- langchain_core/prompts/loading.py +6 -8
- langchain_core/prompts/message.py +3 -3
- langchain_core/prompts/prompt.py +24 -39
- langchain_core/prompts/string.py +4 -4
- langchain_core/prompts/structured.py +42 -50
- langchain_core/rate_limiters.py +51 -60
- langchain_core/retrievers.py +49 -190
- langchain_core/runnables/base.py +1484 -1709
- langchain_core/runnables/branch.py +45 -61
- langchain_core/runnables/config.py +80 -88
- langchain_core/runnables/configurable.py +117 -134
- langchain_core/runnables/fallbacks.py +83 -79
- langchain_core/runnables/graph.py +85 -95
- langchain_core/runnables/graph_ascii.py +27 -28
- langchain_core/runnables/graph_mermaid.py +38 -50
- langchain_core/runnables/graph_png.py +15 -16
- langchain_core/runnables/history.py +135 -148
- langchain_core/runnables/passthrough.py +124 -150
- langchain_core/runnables/retry.py +46 -51
- langchain_core/runnables/router.py +25 -30
- langchain_core/runnables/schema.py +79 -74
- langchain_core/runnables/utils.py +62 -68
- langchain_core/stores.py +81 -115
- langchain_core/structured_query.py +8 -8
- langchain_core/sys_info.py +27 -29
- langchain_core/tools/__init__.py +1 -14
- langchain_core/tools/base.py +179 -187
- langchain_core/tools/convert.py +131 -139
- langchain_core/tools/render.py +10 -10
- langchain_core/tools/retriever.py +11 -11
- langchain_core/tools/simple.py +19 -24
- langchain_core/tools/structured.py +30 -39
- langchain_core/tracers/__init__.py +1 -9
- langchain_core/tracers/base.py +97 -99
- langchain_core/tracers/context.py +29 -52
- langchain_core/tracers/core.py +50 -60
- langchain_core/tracers/evaluation.py +11 -11
- langchain_core/tracers/event_stream.py +115 -70
- langchain_core/tracers/langchain.py +21 -21
- langchain_core/tracers/log_stream.py +43 -43
- langchain_core/tracers/memory_stream.py +3 -3
- langchain_core/tracers/root_listeners.py +16 -16
- langchain_core/tracers/run_collector.py +2 -4
- langchain_core/tracers/schemas.py +0 -129
- langchain_core/tracers/stdout.py +3 -3
- langchain_core/utils/__init__.py +1 -4
- langchain_core/utils/_merge.py +46 -8
- langchain_core/utils/aiter.py +57 -61
- langchain_core/utils/env.py +9 -9
- langchain_core/utils/function_calling.py +89 -191
- langchain_core/utils/html.py +7 -8
- langchain_core/utils/input.py +6 -6
- langchain_core/utils/interactive_env.py +1 -1
- langchain_core/utils/iter.py +37 -42
- langchain_core/utils/json.py +4 -3
- langchain_core/utils/json_schema.py +8 -8
- langchain_core/utils/mustache.py +9 -11
- langchain_core/utils/pydantic.py +33 -35
- langchain_core/utils/strings.py +5 -5
- langchain_core/utils/usage.py +1 -1
- langchain_core/utils/utils.py +80 -54
- langchain_core/vectorstores/base.py +129 -164
- langchain_core/vectorstores/in_memory.py +99 -174
- langchain_core/vectorstores/utils.py +5 -5
- langchain_core/version.py +1 -1
- {langchain_core-0.3.79.dist-info → langchain_core-1.0.0.dist-info}/METADATA +28 -27
- langchain_core-1.0.0.dist-info/RECORD +172 -0
- {langchain_core-0.3.79.dist-info → langchain_core-1.0.0.dist-info}/WHEEL +1 -1
- langchain_core/beta/__init__.py +0 -1
- langchain_core/beta/runnables/__init__.py +0 -1
- langchain_core/beta/runnables/context.py +0 -447
- langchain_core/memory.py +0 -120
- langchain_core/messages/content_blocks.py +0 -176
- langchain_core/prompts/pipeline.py +0 -138
- langchain_core/pydantic_v1/__init__.py +0 -30
- langchain_core/pydantic_v1/dataclasses.py +0 -23
- langchain_core/pydantic_v1/main.py +0 -23
- langchain_core/tracers/langchain_v1.py +0 -31
- langchain_core/utils/loading.py +0 -35
- langchain_core-0.3.79.dist-info/RECORD +0 -174
- langchain_core-0.3.79.dist-info/entry_points.txt +0 -4
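The expanded hunks below are from `langchain_core/language_models/llms.py` (+123 -231 in the listing above). Most of the churn is mechanical: `typing.Optional`/`typing.Union` annotations are rewritten as PEP 604 unions, `Callable` now comes from `collections.abc`, `zip(...)` calls gain an explicit `strict=False`, docstring examples switch to fenced code blocks, and the long-deprecated 0.1-era surface (`__call__`, `predict`, `predict_messages`, `apredict`, `apredict_messages`, the `callback_manager` field) is deleted outright. A minimal sketch of the annotation change only (toy functions for illustration, not taken from the package):

```python
from __future__ import annotations  # lets the PEP 604 form parse on older interpreters

from typing import Optional


# Pre-1.0 spelling, as on the "-" side of the hunks below.
def stop_words_old(stop: Optional[list[str]] = None) -> list[str]:
    return list(stop) if stop else []


# 1.0 spelling, as on the "+" side; behavior is identical.
def stop_words_new(stop: list[str] | None = None) -> list[str]:
    return list(stop) if stop else []


assert stop_words_old(["###"]) == stop_words_new(["###"]) == ["###"]
```

Both spellings run and type-check the same way; dropping `Optional`/`Union` is what empties those names out of the `typing` import block in the first hunk.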
@@ -7,21 +7,17 @@ import functools
 import inspect
 import json
 import logging
-import warnings
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Sequence
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Optional,
-    Union,
     cast,
 )

 import yaml
-from pydantic import ConfigDict, Field, model_validator
+from pydantic import ConfigDict
 from tenacity import (
     RetryCallState,
     before_sleep_log,
@@ -33,7 +29,6 @@ from tenacity import (
 )
 from typing_extensions import override

-from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
     AsyncCallbackManager,
@@ -51,10 +46,7 @@ from langchain_core.language_models.base import (
 )
 from langchain_core.load import dumpd
 from langchain_core.messages import (
-    AIMessage,
-    BaseMessage,
     convert_to_messages,
-    get_buffer_string,
 )
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
@@ -76,16 +68,14 @@ def _log_error_once(msg: str) -> None:
 def create_base_retry_decorator(
     error_types: list[type[BaseException]],
     max_retries: int = 1,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
+    run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
 ) -> Callable[[Any], Any]:
     """Create a retry decorator for a given LLM and provided a list of error types.

     Args:
         error_types: List of error types to retry on.
-        max_retries: Number of retries.
-        run_manager: Callback manager for the run.
+        max_retries: Number of retries.
+        run_manager: Callback manager for the run.

     Returns:
         A retry decorator.
@@ -101,13 +91,17 @@ def create_base_retry_decorator(
             if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
                 coro = run_manager.on_retry(retry_state)
                 try:
-                    loop = asyncio.get_event_loop()
-                    if loop.is_running():
-                        # TODO: Fix RUF006 - this task should have a reference
-                        # and be awaited somewhere
-                        loop.create_task(coro)  # noqa: RUF006
-                    else:
+                    try:
+                        loop = asyncio.get_event_loop()
+                    except RuntimeError:
                         asyncio.run(coro)
+                    else:
+                        if loop.is_running():
+                            # TODO: Fix RUF006 - this task should have a reference
+                            # and be awaited somewhere
+                            loop.create_task(coro)  # noqa: RUF006
+                        else:
+                            asyncio.run(coro)
             except Exception as e:
                 _log_error_once(f"Error in on_retry: {e}")
         else:
@@ -129,9 +123,9 @@ def create_base_retry_decorator(
     )


-def _resolve_cache(*, cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
+def _resolve_cache(*, cache: BaseCache | bool | None) -> BaseCache | None:
     """Resolve the cache."""
-    llm_cache: Optional[BaseCache]
+    llm_cache: BaseCache | None
     if isinstance(cache, BaseCache):
         llm_cache = cache
     elif cache is None:
@@ -156,14 +150,14 @@ def _resolve_cache(*, cache: Union[BaseCache, bool, None]) -> Optional[BaseCache
 def get_prompts(
     params: dict[str, Any],
     prompts: list[str],
-    cache: Union[BaseCache, bool, None] = None,
+    cache: BaseCache | bool | None = None,  # noqa: FBT001
 ) -> tuple[dict[int, list], str, list[int], list[str]]:
     """Get prompts that are already cached.

     Args:
         params: Dictionary of parameters.
         prompts: List of prompts.
-        cache: Cache object.
+        cache: Cache object.

     Returns:
         A tuple of existing prompts, llm_string, missing prompt indexes,
@@ -192,14 +186,14 @@ def get_prompts(
 async def aget_prompts(
     params: dict[str, Any],
     prompts: list[str],
-    cache: Union[BaseCache, bool, None] = None,
+    cache: BaseCache | bool | None = None,  # noqa: FBT001
 ) -> tuple[dict[int, list], str, list[int], list[str]]:
     """Get prompts that are already cached. Async version.

     Args:
         params: Dictionary of parameters.
         prompts: List of prompts.
-        cache: Cache object.
+        cache: Cache object.

     Returns:
         A tuple of existing prompts, llm_string, missing prompt indexes,
@@ -225,13 +219,13 @@ async def aget_prompts(


 def update_cache(
-    cache: Union[BaseCache, bool, None],
+    cache: BaseCache | bool | None,  # noqa: FBT001
     existing_prompts: dict[int, list],
     llm_string: str,
     missing_prompt_idxs: list[int],
     new_results: LLMResult,
     prompts: list[str],
-) -> Optional[dict]:
+) -> dict | None:
     """Update the cache and get the LLM output.

     Args:
@@ -258,13 +252,13 @@ def update_cache(


 async def aupdate_cache(
-    cache: Union[BaseCache, bool, None],
+    cache: BaseCache | bool | None,  # noqa: FBT001
     existing_prompts: dict[int, list],
     llm_string: str,
     missing_prompt_idxs: list[int],
     new_results: LLMResult,
     prompts: list[str],
-) -> Optional[dict]:
+) -> dict | None:
     """Update the cache and get the LLM output. Async version.

     Args:
@@ -296,26 +290,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     It should take in a prompt and return a string.
     """

-    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
-    """[DEPRECATED]"""
-
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
     )

-    @model_validator(mode="before")
-    @classmethod
-    def raise_deprecation(cls, values: dict) -> Any:
-        """Raise deprecation warning if callback_manager is used."""
-        if values.get("callback_manager") is not None:
-            warnings.warn(
-                "callback_manager is deprecated. Please use callbacks instead.",
-                DeprecationWarning,
-                stacklevel=5,
-            )
-            values["callbacks"] = values.pop("callback_manager", None)
-        return values
-
     @functools.cached_property
     def _serialized(self) -> dict[str, Any]:
         return dumpd(self)
@@ -325,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     @property
     @override
     def OutputType(self) -> type[str]:
-        """Get the input type for this runnable."""
+        """Get the input type for this `Runnable`."""
         return str

     def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
@@ -343,7 +321,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):

     def _get_ls_params(
         self,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         **kwargs: Any,
     ) -> LangSmithParams:
         """Get standard params for tracing."""
@@ -382,9 +360,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def invoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         **kwargs: Any,
     ) -> str:
         config = ensure_config(config)
@@ -407,9 +385,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def ainvoke(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         **kwargs: Any,
     ) -> str:
         config = ensure_config(config)
@@ -429,7 +407,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -476,7 +454,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def abatch(
         self,
         inputs: list[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -522,9 +500,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def stream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         **kwargs: Any,
     ) -> Iterator[str]:
         if type(self)._stream == BaseLLM._stream:  # noqa: SLF001
@@ -559,7 +537,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[GenerationChunk] = None
+        generation: GenerationChunk | None = None
         try:
             for chunk in self._stream(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
@@ -589,9 +567,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def astream(
         self,
         input: LanguageModelInput,
-        config: Optional[RunnableConfig] = None,
+        config: RunnableConfig | None = None,
         *,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         if (
@@ -629,7 +607,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[GenerationChunk] = None
+        generation: GenerationChunk | None = None
         try:
             async for chunk in self._astream(
                 prompt,
@@ -662,8 +640,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def _generate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts.
@@ -682,8 +660,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def _agenerate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts.
@@ -710,8 +688,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def _stream(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         """Stream the LLM on the given prompt.
@@ -738,8 +716,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def _astream(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
         """An async version of the _stream method.
@@ -783,8 +761,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def generate_prompt(
         self,
         prompts: list[PromptValue],
-        stop: Optional[list[str]] = None,
-        callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
+        stop: list[str] | None = None,
+        callbacks: Callbacks | list[Callbacks] | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -794,8 +772,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def agenerate_prompt(
         self,
         prompts: list[PromptValue],
-        stop: Optional[list[str]] = None,
-        callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
+        stop: list[str] | None = None,
+        callbacks: Callbacks | list[Callbacks] | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -806,7 +784,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def _generate_helper(
         self,
         prompts: list[str],
-        stop: Optional[list[str]],
+        stop: list[str] | None,
         run_managers: list[CallbackManagerForLLMRun],
         *,
         new_arg_supported: bool,
@@ -829,7 +807,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 run_manager.on_llm_error(e, response=LLMResult(generations=[]))
             raise
         flattened_outputs = output.flatten()
-        for manager, flattened_output in zip(run_managers, flattened_outputs):
+        for manager, flattened_output in zip(
+            run_managers, flattened_outputs, strict=False
+        ):
             manager.on_llm_end(flattened_output)
         if run_managers:
             output.run = [
@@ -840,13 +820,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def generate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
+        stop: list[str] | None = None,
+        callbacks: Callbacks | list[Callbacks] | None = None,
         *,
-        tags: Optional[Union[list[str], list[list[str]]]] = None,
-        metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
-        run_name: Optional[Union[str, list[str]]] = None,
-        run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
+        tags: list[str] | list[list[str]] | None = None,
+        metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
+        run_name: str | list[str] | None = None,
+        run_id: uuid.UUID | list[uuid.UUID | None] | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Pass a sequence of prompts to a model and return generations.
@@ -859,13 +839,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             1. Take advantage of batched calls,
             2. Need more output from the model than just the top generated value,
             3. Are building chains that are agnostic to the underlying language model
-                type (e.g., pure text completion models vs chat models).
+                type (e.g., pure text completion models vs chat models).

         Args:
             prompts: List of string prompts.
             stop: Stop words to use when generating. Model output is cut off at the
                 first occurrence of any of these substrings.
-            callbacks: Callbacks to pass through. Used for executing additional
+            callbacks: `Callbacks` to pass through. Used for executing additional
                 functionality, such as logging or streaming, throughout generation.
             tags: List of tags to associate with each prompt. If provided, the length
                 of the list must match the length of the prompts list.
@@ -881,12 +861,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):

         Raises:
             ValueError: If prompts is not a list.
-            ValueError: If the length of callbacks, tags, metadata, or
-                run_name (if provided) does not match the length of prompts.
+            ValueError: If the length of `callbacks`, `tags`, `metadata`, or
+                `run_name` (if provided) does not match the length of prompts.

         Returns:
-            An LLMResult, which contains a list of candidate Generations for each
-                prompt and additional model provider-specific output.
+            An `LLMResult`, which contains a list of candidate `Generations` for each
+                input prompt and additional model provider-specific output.
         """
         if not isinstance(prompts, list):
             msg = (
@@ -936,14 +916,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 msg = "run_name must be a list of the same length as prompts"
                 raise ValueError(msg)
             callbacks = cast("list[Callbacks]", callbacks)
-            tags_list = cast(
-                "list[Optional[list[str]]]", tags or ([None] * len(prompts))
-            )
+            tags_list = cast("list[list[str] | None]", tags or ([None] * len(prompts)))
             metadata_list = cast(
-                "list[Optional[dict[str, Any]]]", metadata or ([{}] * len(prompts))
+                "list[dict[str, Any] | None]", metadata or ([{}] * len(prompts))
             )
             run_name_list = run_name or cast(
-                "list[Optional[str]]", ([None] * len(prompts))
+                "list[str | None]", ([None] * len(prompts))
             )
             callback_managers = [
                 CallbackManager.configure(
@@ -955,7 +933,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     meta,
                     self.metadata,
                 )
-                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
+                for callback, tag, meta in zip(
+                    callbacks, tags_list, metadata_list, strict=False
+                )
             ]
         else:
             # We've received a single callbacks arg to apply to all inputs
@@ -970,7 +950,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
-            run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
+            run_name_list = [cast("str | None", run_name)] * len(prompts)
         run_ids_list = self._get_run_ids_list(run_id, prompts)
         params = self.dict()
         params["stop"] = stop
@@ -996,7 +976,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 run_id=run_id_,
             )[0]
             for callback_manager, prompt, run_name, run_id_ in zip(
-                callback_managers, prompts, run_name_list, run_ids_list
+                callback_managers,
+                prompts,
+                run_name_list,
+                run_ids_list,
+                strict=False,
             )
         ]
         return self._generate_helper(
@@ -1046,7 +1030,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):

     @staticmethod
     def _get_run_ids_list(
-        run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list
+        run_id: uuid.UUID | list[uuid.UUID | None] | None, prompts: list
     ) -> list:
         if run_id is None:
             return [None] * len(prompts)
@@ -1063,7 +1047,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def _agenerate_helper(
         self,
         prompts: list[str],
-        stop: Optional[list[str]],
+        stop: list[str] | None,
         run_managers: list[AsyncCallbackManagerForLLMRun],
         *,
         new_arg_supported: bool,
@@ -1093,7 +1077,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             *[
                 run_manager.on_llm_end(flattened_output)
                 for run_manager, flattened_output in zip(
-                    run_managers, flattened_outputs
+                    run_managers, flattened_outputs, strict=False
                 )
             ]
         )
@@ -1106,13 +1090,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
    async def agenerate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
+        stop: list[str] | None = None,
+        callbacks: Callbacks | list[Callbacks] | None = None,
         *,
-        tags: Optional[Union[list[str], list[list[str]]]] = None,
-        metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
-        run_name: Optional[Union[str, list[str]]] = None,
-        run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
+        tags: list[str] | list[list[str]] | None = None,
+        metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
+        run_name: str | list[str] | None = None,
+        run_id: uuid.UUID | list[uuid.UUID | None] | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Asynchronously pass a sequence of prompts to a model and return generations.
@@ -1125,13 +1109,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             1. Take advantage of batched calls,
             2. Need more output from the model than just the top generated value,
             3. Are building chains that are agnostic to the underlying language model
-                type (e.g., pure text completion models vs chat models).
+                type (e.g., pure text completion models vs chat models).

         Args:
             prompts: List of string prompts.
             stop: Stop words to use when generating. Model output is cut off at the
                 first occurrence of any of these substrings.
-            callbacks: Callbacks to pass through. Used for executing additional
+            callbacks: `Callbacks` to pass through. Used for executing additional
                 functionality, such as logging or streaming, throughout generation.
             tags: List of tags to associate with each prompt. If provided, the length
                 of the list must match the length of the prompts list.
@@ -1146,12 +1130,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 to the model provider API call.

         Raises:
-            ValueError: If the length of callbacks, tags, metadata, or
-                run_name (if provided) does not match the length of prompts.
+            ValueError: If the length of `callbacks`, `tags`, `metadata`, or
+                `run_name` (if provided) does not match the length of prompts.

         Returns:
-            An LLMResult, which contains a list of candidate Generations for each
-                prompt and additional model provider-specific output.
+            An `LLMResult`, which contains a list of candidate `Generations` for each
+                input prompt and additional model provider-specific output.
         """
         if isinstance(metadata, list):
             metadata = [
@@ -1191,14 +1175,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 msg = "run_name must be a list of the same length as prompts"
                 raise ValueError(msg)
             callbacks = cast("list[Callbacks]", callbacks)
-            tags_list = cast(
-                "list[Optional[list[str]]]", tags or ([None] * len(prompts))
-            )
+            tags_list = cast("list[list[str] | None]", tags or ([None] * len(prompts)))
             metadata_list = cast(
-                "list[Optional[dict[str, Any]]]", metadata or ([{}] * len(prompts))
+                "list[dict[str, Any] | None]", metadata or ([{}] * len(prompts))
             )
             run_name_list = run_name or cast(
-                "list[Optional[str]]", ([None] * len(prompts))
+                "list[str | None]", ([None] * len(prompts))
             )
             callback_managers = [
                 AsyncCallbackManager.configure(
@@ -1210,7 +1192,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     meta,
                     self.metadata,
                 )
-                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
+                for callback, tag, meta in zip(
+                    callbacks, tags_list, metadata_list, strict=False
+                )
             ]
         else:
             # We've received a single callbacks arg to apply to all inputs
@@ -1225,7 +1209,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
-            run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
+            run_name_list = [cast("str | None", run_name)] * len(prompts)
         run_ids_list = self._get_run_ids_list(run_id, prompts)
         params = self.dict()
         params["stop"] = stop
@@ -1255,7 +1239,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     run_id=run_id_,
                 )
                 for callback_manager, prompt, run_name, run_id_ in zip(
-                    callback_managers, prompts, run_name_list, run_ids_list
+                    callback_managers,
+                    prompts,
+                    run_name_list,
+                    run_ids_list,
+                    strict=False,
                 )
             ]
         )
@@ -1308,64 +1296,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

-    @deprecated("0.1.7", alternative="invoke", removal="1.0")
-    def __call__(
-        self,
-        prompt: str,
-        stop: Optional[list[str]] = None,
-        callbacks: Callbacks = None,
-        *,
-        tags: Optional[list[str]] = None,
-        metadata: Optional[dict[str, Any]] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Check Cache and run the LLM on the given prompt and input.
-
-        Args:
-            prompt: The prompt to generate from.
-            stop: Stop words to use when generating. Model output is cut off at the
-                first occurrence of any of these substrings.
-            callbacks: Callbacks to pass through. Used for executing additional
-                functionality, such as logging or streaming, throughout generation.
-            tags: List of tags to associate with the prompt.
-            metadata: Metadata to associate with the prompt.
-            **kwargs: Arbitrary additional keyword arguments. These are usually passed
-                to the model provider API call.
-
-        Returns:
-            The generated text.
-
-        Raises:
-            ValueError: If the prompt is not a string.
-        """
-        if not isinstance(prompt, str):
-            msg = (
-                "Argument `prompt` is expected to be a string. Instead found "
-                f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
-                "`generate` instead."
-            )
-            raise ValueError(msg)  # noqa: TRY004
-        return (
-            self.generate(
-                [prompt],
-                stop=stop,
-                callbacks=callbacks,
-                tags=tags,
-                metadata=metadata,
-                **kwargs,
-            )
-            .generations[0][0]
-            .text
-        )
-
     async def _call_async(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
+        stop: list[str] | None = None,
         callbacks: Callbacks = None,
         *,
-        tags: Optional[list[str]] = None,
-        metadata: Optional[dict[str, Any]] = None,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
@@ -1379,50 +1317,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         )
         return result.generations[0][0].text

-    @deprecated("0.1.7", alternative="invoke", removal="1.0")
-    @override
-    def predict(
-        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
-    ) -> str:
-        stop_ = None if stop is None else list(stop)
-        return self(text, stop=stop_, **kwargs)
-
-    @deprecated("0.1.7", alternative="invoke", removal="1.0")
-    @override
-    def predict_messages(
-        self,
-        messages: list[BaseMessage],
-        *,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> BaseMessage:
-        text = get_buffer_string(messages)
-        stop_ = None if stop is None else list(stop)
-        content = self(text, stop=stop_, **kwargs)
-        return AIMessage(content=content)
-
-    @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
-    @override
-    async def apredict(
-        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
-    ) -> str:
-        stop_ = None if stop is None else list(stop)
-        return await self._call_async(text, stop=stop_, **kwargs)
-
-    @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
-    @override
-    async def apredict_messages(
-        self,
-        messages: list[BaseMessage],
-        *,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> BaseMessage:
-        text = get_buffer_string(messages)
-        stop_ = None if stop is None else list(stop)
-        content = await self._call_async(text, stop=stop_, **kwargs)
-        return AIMessage(content=content)
-
     def __str__(self) -> str:
         """Return a string representation of the object for printing."""
         cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
@@ -1440,7 +1334,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         starter_dict["_type"] = self._llm_type
         return starter_dict

-    def save(self, file_path: Union[Path, str]) -> None:
+    def save(self, file_path: Path | str) -> None:
         """Save the LLM.

         Args:
@@ -1450,11 +1344,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             ValueError: If the file path is not a string or Path object.

         Example:
-
-
-
-                llm.save(file_path="path/llm.yaml")
-
+            ```python
+            llm.save(file_path="path/llm.yaml")
+            ```
         """
         # Convert file to Path object.
         save_path = Path(file_path)
@@ -1466,10 +1358,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         prompt_dict = self.dict()

         if save_path.suffix == ".json":
-            with save_path.open("w") as f:
+            with save_path.open("w", encoding="utf-8") as f:
                 json.dump(prompt_dict, f, indent=4)
         elif save_path.suffix.endswith((".yaml", ".yml")):
-            with save_path.open("w") as f:
+            with save_path.open("w", encoding="utf-8") as f:
                 yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             msg = f"{save_path} must be json or yaml"
@@ -1510,8 +1402,8 @@ class LLM(BaseLLM):
     def _call(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> str:
         """Run the LLM on the given input.
@@ -1534,8 +1426,8 @@ class LLM(BaseLLM):
     async def _acall(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> str:
         """Async version of the _call method.
@@ -1568,8 +1460,8 @@ class LLM(BaseLLM):
     def _generate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         # TODO: add caching here.
@@ -1587,8 +1479,8 @@ class LLM(BaseLLM):
     async def _agenerate(
         self,
         prompts: list[str],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> LLMResult:
         generations = []