langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +16 -20
- langchain/agents/agent_iterator.py +19 -12
- langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
- langchain/agents/chat/base.py +2 -0
- langchain/agents/conversational/base.py +2 -0
- langchain/agents/conversational_chat/base.py +2 -0
- langchain/agents/initialize.py +1 -1
- langchain/agents/json_chat/base.py +1 -0
- langchain/agents/mrkl/base.py +2 -0
- langchain/agents/openai_assistant/base.py +1 -1
- langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
- langchain/agents/openai_functions_agent/base.py +3 -2
- langchain/agents/openai_functions_multi_agent/base.py +1 -1
- langchain/agents/openai_tools/base.py +1 -0
- langchain/agents/output_parsers/json.py +2 -0
- langchain/agents/output_parsers/openai_functions.py +10 -3
- langchain/agents/output_parsers/openai_tools.py +8 -1
- langchain/agents/output_parsers/react_json_single_input.py +3 -0
- langchain/agents/output_parsers/react_single_input.py +3 -0
- langchain/agents/output_parsers/self_ask.py +2 -0
- langchain/agents/output_parsers/tools.py +16 -2
- langchain/agents/output_parsers/xml.py +3 -0
- langchain/agents/react/agent.py +1 -0
- langchain/agents/react/base.py +4 -0
- langchain/agents/react/output_parser.py +2 -0
- langchain/agents/schema.py +2 -0
- langchain/agents/self_ask_with_search/base.py +4 -0
- langchain/agents/structured_chat/base.py +5 -0
- langchain/agents/structured_chat/output_parser.py +13 -0
- langchain/agents/tool_calling_agent/base.py +1 -0
- langchain/agents/tools.py +3 -0
- langchain/agents/xml/base.py +7 -1
- langchain/callbacks/streaming_aiter.py +13 -2
- langchain/callbacks/streaming_aiter_final_only.py +11 -2
- langchain/callbacks/streaming_stdout_final_only.py +5 -0
- langchain/callbacks/tracers/logging.py +11 -0
- langchain/chains/api/base.py +5 -1
- langchain/chains/base.py +8 -2
- langchain/chains/combine_documents/base.py +7 -1
- langchain/chains/combine_documents/map_reduce.py +3 -0
- langchain/chains/combine_documents/map_rerank.py +6 -4
- langchain/chains/combine_documents/reduce.py +1 -0
- langchain/chains/combine_documents/refine.py +1 -0
- langchain/chains/combine_documents/stuff.py +5 -1
- langchain/chains/constitutional_ai/base.py +7 -0
- langchain/chains/conversation/base.py +4 -1
- langchain/chains/conversational_retrieval/base.py +67 -59
- langchain/chains/elasticsearch_database/base.py +2 -1
- langchain/chains/flare/base.py +2 -0
- langchain/chains/flare/prompts.py +2 -0
- langchain/chains/llm.py +7 -2
- langchain/chains/llm_bash/__init__.py +1 -1
- langchain/chains/llm_checker/base.py +12 -1
- langchain/chains/llm_math/base.py +9 -1
- langchain/chains/llm_summarization_checker/base.py +13 -1
- langchain/chains/llm_symbolic_math/__init__.py +1 -1
- langchain/chains/loading.py +4 -2
- langchain/chains/moderation.py +3 -0
- langchain/chains/natbot/base.py +3 -1
- langchain/chains/natbot/crawler.py +29 -0
- langchain/chains/openai_functions/base.py +2 -0
- langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
- langchain/chains/openai_functions/openapi.py +4 -0
- langchain/chains/openai_functions/qa_with_structure.py +3 -3
- langchain/chains/openai_functions/tagging.py +2 -0
- langchain/chains/qa_generation/base.py +4 -0
- langchain/chains/qa_with_sources/base.py +3 -0
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -2
- langchain/chains/query_constructor/base.py +4 -2
- langchain/chains/query_constructor/parser.py +64 -2
- langchain/chains/retrieval_qa/base.py +4 -0
- langchain/chains/router/base.py +14 -2
- langchain/chains/router/embedding_router.py +3 -0
- langchain/chains/router/llm_router.py +6 -4
- langchain/chains/router/multi_prompt.py +3 -0
- langchain/chains/router/multi_retrieval_qa.py +18 -0
- langchain/chains/sql_database/query.py +1 -0
- langchain/chains/structured_output/base.py +2 -0
- langchain/chains/transform.py +4 -0
- langchain/chat_models/base.py +55 -18
- langchain/document_loaders/blob_loaders/schema.py +1 -4
- langchain/embeddings/base.py +2 -0
- langchain/embeddings/cache.py +3 -3
- langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
- langchain/evaluation/comparison/eval_chain.py +1 -0
- langchain/evaluation/criteria/eval_chain.py +3 -0
- langchain/evaluation/embedding_distance/base.py +11 -0
- langchain/evaluation/exact_match/base.py +14 -1
- langchain/evaluation/loading.py +1 -0
- langchain/evaluation/parsing/base.py +16 -3
- langchain/evaluation/parsing/json_distance.py +19 -8
- langchain/evaluation/parsing/json_schema.py +1 -4
- langchain/evaluation/qa/eval_chain.py +8 -0
- langchain/evaluation/qa/generate_chain.py +2 -0
- langchain/evaluation/regex_match/base.py +9 -1
- langchain/evaluation/scoring/eval_chain.py +1 -0
- langchain/evaluation/string_distance/base.py +6 -0
- langchain/memory/buffer.py +5 -0
- langchain/memory/buffer_window.py +2 -0
- langchain/memory/combined.py +1 -1
- langchain/memory/entity.py +47 -0
- langchain/memory/simple.py +3 -0
- langchain/memory/summary.py +30 -0
- langchain/memory/summary_buffer.py +3 -0
- langchain/memory/token_buffer.py +2 -0
- langchain/output_parsers/combining.py +4 -2
- langchain/output_parsers/enum.py +5 -1
- langchain/output_parsers/fix.py +8 -1
- langchain/output_parsers/pandas_dataframe.py +16 -1
- langchain/output_parsers/regex.py +2 -0
- langchain/output_parsers/retry.py +21 -1
- langchain/output_parsers/structured.py +10 -0
- langchain/output_parsers/yaml.py +4 -0
- langchain/pydantic_v1/__init__.py +1 -1
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
- langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
- langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
- langchain/retrievers/ensemble.py +2 -2
- langchain/retrievers/multi_query.py +3 -1
- langchain/retrievers/multi_vector.py +4 -1
- langchain/retrievers/parent_document_retriever.py +15 -0
- langchain/retrievers/self_query/base.py +19 -0
- langchain/retrievers/time_weighted_retriever.py +3 -0
- langchain/runnables/hub.py +12 -0
- langchain/runnables/openai_functions.py +6 -0
- langchain/smith/__init__.py +1 -0
- langchain/smith/evaluation/config.py +5 -22
- langchain/smith/evaluation/progress.py +12 -3
- langchain/smith/evaluation/runner_utils.py +240 -123
- langchain/smith/evaluation/string_run_evaluator.py +27 -0
- langchain/storage/encoder_backed.py +1 -0
- langchain/tools/python/__init__.py +1 -1
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
- langchain/smith/evaluation/utils.py +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
langchain/chains/transform.py
CHANGED
|
@@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
|
|
10
10
|
CallbackManagerForChainRun,
|
|
11
11
|
)
|
|
12
12
|
from pydantic import Field
|
|
13
|
+
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from langchain.chains.base import Chain
|
|
15
16
|
|
|
@@ -25,6 +26,7 @@ class TransformChain(Chain):
|
|
|
25
26
|
from langchain.chains import TransformChain
|
|
26
27
|
transform_chain = TransformChain(input_variables=["text"],
|
|
27
28
|
output_variables["entities"], transform=func())
|
|
29
|
+
|
|
28
30
|
"""
|
|
29
31
|
|
|
30
32
|
input_variables: list[str]
|
|
@@ -63,6 +65,7 @@ class TransformChain(Chain):
|
|
|
63
65
|
"""
|
|
64
66
|
return self.output_variables
|
|
65
67
|
|
|
68
|
+
@override
|
|
66
69
|
def _call(
|
|
67
70
|
self,
|
|
68
71
|
inputs: dict[str, str],
|
|
@@ -70,6 +73,7 @@ class TransformChain(Chain):
|
|
|
70
73
|
) -> dict[str, str]:
|
|
71
74
|
return self.transform_cb(inputs)
|
|
72
75
|
|
|
76
|
+
@override
|
|
73
77
|
async def _acall(
|
|
74
78
|
self,
|
|
75
79
|
inputs: dict[str, Any],
|
langchain/chat_models/base.py
CHANGED
|
@@ -3,15 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import warnings
|
|
4
4
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
5
5
|
from importlib import util
|
|
6
|
-
from typing import (
|
|
7
|
-
Any,
|
|
8
|
-
Callable,
|
|
9
|
-
Literal,
|
|
10
|
-
Optional,
|
|
11
|
-
Union,
|
|
12
|
-
cast,
|
|
13
|
-
overload,
|
|
14
|
-
)
|
|
6
|
+
from typing import Any, Callable, Literal, Optional, Union, cast, overload
|
|
15
7
|
|
|
16
8
|
from langchain_core.language_models import (
|
|
17
9
|
BaseChatModel,
|
|
@@ -27,6 +19,7 @@ from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
|
|
27
19
|
from langchain_core.runnables.schema import StreamEvent
|
|
28
20
|
from langchain_core.tools import BaseTool
|
|
29
21
|
from langchain_core.tracers import RunLog, RunLogPatch
|
|
22
|
+
from langchain_core.v1.chat_models import BaseChatModel as BaseChatModelV1
|
|
30
23
|
from pydantic import BaseModel
|
|
31
24
|
from typing_extensions import TypeAlias, override
|
|
32
25
|
|
|
@@ -47,10 +40,23 @@ def init_chat_model(
|
|
|
47
40
|
model_provider: Optional[str] = None,
|
|
48
41
|
configurable_fields: Literal[None] = None,
|
|
49
42
|
config_prefix: Optional[str] = None,
|
|
43
|
+
message_version: Literal["v0"] = "v0",
|
|
50
44
|
**kwargs: Any,
|
|
51
45
|
) -> BaseChatModel: ...
|
|
52
46
|
|
|
53
47
|
|
|
48
|
+
@overload
|
|
49
|
+
def init_chat_model(
|
|
50
|
+
model: str,
|
|
51
|
+
*,
|
|
52
|
+
model_provider: Optional[str] = None,
|
|
53
|
+
configurable_fields: Literal[None] = None,
|
|
54
|
+
config_prefix: Optional[str] = None,
|
|
55
|
+
message_version: Literal["v1"] = "v1",
|
|
56
|
+
**kwargs: Any,
|
|
57
|
+
) -> BaseChatModelV1: ...
|
|
58
|
+
|
|
59
|
+
|
|
54
60
|
@overload
|
|
55
61
|
def init_chat_model(
|
|
56
62
|
model: Literal[None] = None,
|
|
@@ -58,6 +64,7 @@ def init_chat_model(
|
|
|
58
64
|
model_provider: Optional[str] = None,
|
|
59
65
|
configurable_fields: Literal[None] = None,
|
|
60
66
|
config_prefix: Optional[str] = None,
|
|
67
|
+
message_version: Literal["v0", "v1"] = "v0",
|
|
61
68
|
**kwargs: Any,
|
|
62
69
|
) -> _ConfigurableModel: ...
|
|
63
70
|
|
|
@@ -69,6 +76,7 @@ def init_chat_model(
|
|
|
69
76
|
model_provider: Optional[str] = None,
|
|
70
77
|
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
|
|
71
78
|
config_prefix: Optional[str] = None,
|
|
79
|
+
message_version: Literal["v0", "v1"] = "v0",
|
|
72
80
|
**kwargs: Any,
|
|
73
81
|
) -> _ConfigurableModel: ...
|
|
74
82
|
|
|
@@ -84,8 +92,9 @@ def init_chat_model(
|
|
|
84
92
|
Union[Literal["any"], list[str], tuple[str, ...]]
|
|
85
93
|
] = None,
|
|
86
94
|
config_prefix: Optional[str] = None,
|
|
95
|
+
message_version: Literal["v0", "v1"] = "v0",
|
|
87
96
|
**kwargs: Any,
|
|
88
|
-
) -> Union[BaseChatModel, _ConfigurableModel]:
|
|
97
|
+
) -> Union[BaseChatModel, BaseChatModelV1, _ConfigurableModel]:
|
|
89
98
|
"""Initialize a ChatModel in a single line using the model's name and provider.
|
|
90
99
|
|
|
91
100
|
.. note::
|
|
@@ -136,6 +145,20 @@ def init_chat_model(
|
|
|
136
145
|
- ``deepseek...`` -> ``deepseek``
|
|
137
146
|
- ``grok...`` -> ``xai``
|
|
138
147
|
- ``sonar...`` -> ``perplexity``
|
|
148
|
+
|
|
149
|
+
message_version: The version of the BaseChatModel to return. Either ``"v0"`` for
|
|
150
|
+
a v0 :class:`~langchain_core.language_models.chat_models.BaseChatModel` or
|
|
151
|
+
``"v1"`` for a v1 :class:`~langchain_core.v1.chat_models.BaseChatModel`. The
|
|
152
|
+
output version determines what type of message objects the model will
|
|
153
|
+
generate.
|
|
154
|
+
|
|
155
|
+
.. note::
|
|
156
|
+
Currently supported for these providers:
|
|
157
|
+
|
|
158
|
+
- ``openai``
|
|
159
|
+
|
|
160
|
+
.. versionadded:: 0.4.0
|
|
161
|
+
|
|
139
162
|
configurable_fields: Which model parameters are configurable:
|
|
140
163
|
|
|
141
164
|
- None: No configurable fields.
|
|
@@ -188,7 +211,7 @@ def init_chat_model(
|
|
|
188
211
|
|
|
189
212
|
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
|
|
190
213
|
claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
|
|
191
|
-
gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
|
|
214
|
+
gemini_2_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)
|
|
192
215
|
|
|
193
216
|
o3_mini.invoke("what's your name")
|
|
194
217
|
claude_sonnet.invoke("what's your name")
|
|
@@ -322,8 +345,9 @@ def init_chat_model(
|
|
|
322
345
|
|
|
323
346
|
if not configurable_fields:
|
|
324
347
|
return _init_chat_model_helper(
|
|
325
|
-
cast(str, model),
|
|
348
|
+
cast("str", model),
|
|
326
349
|
model_provider=model_provider,
|
|
350
|
+
message_version=message_version,
|
|
327
351
|
**kwargs,
|
|
328
352
|
)
|
|
329
353
|
if model:
|
|
@@ -341,14 +365,27 @@ def _init_chat_model_helper(
|
|
|
341
365
|
model: str,
|
|
342
366
|
*,
|
|
343
367
|
model_provider: Optional[str] = None,
|
|
368
|
+
message_version: Literal["v0", "v1"] = "v0",
|
|
344
369
|
**kwargs: Any,
|
|
345
|
-
) -> BaseChatModel:
|
|
370
|
+
) -> Union[BaseChatModel, BaseChatModelV1]:
|
|
346
371
|
model, model_provider = _parse_model(model, model_provider)
|
|
372
|
+
if message_version != "v0" and model_provider not in ("openai",):
|
|
373
|
+
warnings.warn(
|
|
374
|
+
f"Model provider {model_provider} does not support "
|
|
375
|
+
f"message_version={message_version}. Defaulting to v0.",
|
|
376
|
+
stacklevel=2,
|
|
377
|
+
)
|
|
347
378
|
if model_provider == "openai":
|
|
348
379
|
_check_pkg("langchain_openai")
|
|
349
|
-
|
|
380
|
+
if message_version == "v0":
|
|
381
|
+
from langchain_openai import ChatOpenAI
|
|
382
|
+
|
|
383
|
+
return ChatOpenAI(model=model, **kwargs)
|
|
384
|
+
# v1
|
|
385
|
+
from langchain_openai.v1 import ChatOpenAI as ChatOpenAIV1
|
|
386
|
+
|
|
387
|
+
return ChatOpenAIV1(model=model, **kwargs)
|
|
350
388
|
|
|
351
|
-
return ChatOpenAI(model=model, **kwargs)
|
|
352
389
|
if model_provider == "anthropic":
|
|
353
390
|
_check_pkg("langchain_anthropic")
|
|
354
391
|
from langchain_anthropic import ChatAnthropic
|
|
@@ -632,7 +669,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
632
669
|
**kwargs: Any,
|
|
633
670
|
) -> _ConfigurableModel:
|
|
634
671
|
"""Bind config to a Runnable, returning a new Runnable."""
|
|
635
|
-
config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
|
|
672
|
+
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
|
|
636
673
|
model_params = self._model_params(config)
|
|
637
674
|
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
|
638
675
|
remaining_config["configurable"] = {
|
|
@@ -781,7 +818,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
781
818
|
if config is None or isinstance(config, dict) or len(config) <= 1:
|
|
782
819
|
if isinstance(config, list):
|
|
783
820
|
config = config[0]
|
|
784
|
-
yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload]
|
|
821
|
+
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
|
|
785
822
|
inputs,
|
|
786
823
|
config=config,
|
|
787
824
|
return_exceptions=return_exceptions,
|
|
@@ -811,7 +848,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
811
848
|
if isinstance(config, list):
|
|
812
849
|
config = config[0]
|
|
813
850
|
async for x in self._model(
|
|
814
|
-
cast(RunnableConfig, config),
|
|
851
|
+
cast("RunnableConfig", config),
|
|
815
852
|
).abatch_as_completed( # type: ignore[call-overload]
|
|
816
853
|
inputs,
|
|
817
854
|
config=config,
|
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
from langchain_core.document_loaders import Blob, BlobLoader
|
|
4
4
|
|
|
5
5
|
from langchain._api import create_importer
|
|
6
6
|
|
|
7
|
-
if TYPE_CHECKING:
|
|
8
|
-
pass
|
|
9
|
-
|
|
10
7
|
# Create a way to dynamically look up deprecated imports.
|
|
11
8
|
# Used to consolidate logic for raising deprecation warnings and
|
|
12
9
|
# handling optional imports.
|
langchain/embeddings/base.py
CHANGED
|
@@ -47,6 +47,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
|
|
47
47
|
Raises:
|
|
48
48
|
ValueError: If the model string is not in the correct format or
|
|
49
49
|
the provider is unsupported
|
|
50
|
+
|
|
50
51
|
"""
|
|
51
52
|
if ":" not in model_name:
|
|
52
53
|
providers = _SUPPORTED_PROVIDERS
|
|
@@ -177,6 +178,7 @@ def init_embeddings(
|
|
|
177
178
|
)
|
|
178
179
|
|
|
179
180
|
.. versionadded:: 0.3.9
|
|
181
|
+
|
|
180
182
|
"""
|
|
181
183
|
if not model:
|
|
182
184
|
providers = _SUPPORTED_PROVIDERS.keys()
|
langchain/embeddings/cache.py
CHANGED
|
@@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
|
|
|
80
80
|
|
|
81
81
|
def _value_deserializer(serialized_value: bytes) -> list[float]:
|
|
82
82
|
"""Deserialize a value."""
|
|
83
|
-
return cast(list[float], json.loads(serialized_value.decode()))
|
|
83
|
+
return cast("list[float]", json.loads(serialized_value.decode()))
|
|
84
84
|
|
|
85
85
|
|
|
86
86
|
# The warning is global; track emission, so it appears only once.
|
|
@@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
|
|
|
192
192
|
vectors[index] = updated_vector
|
|
193
193
|
|
|
194
194
|
return cast(
|
|
195
|
-
list[list[float]],
|
|
195
|
+
"list[list[float]]",
|
|
196
196
|
vectors,
|
|
197
197
|
) # Nones should have been resolved by now
|
|
198
198
|
|
|
@@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
|
|
|
230
230
|
vectors[index] = updated_vector
|
|
231
231
|
|
|
232
232
|
return cast(
|
|
233
|
-
list[list[float]],
|
|
233
|
+
"list[list[float]]",
|
|
234
234
|
vectors,
|
|
235
235
|
) # Nones should have been resolved by now
|
|
236
236
|
|
|
@@ -140,6 +140,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
|
|
|
140
140
|
)
|
|
141
141
|
print(result["score"]) # noqa: T201
|
|
142
142
|
# 0
|
|
143
|
+
|
|
143
144
|
"""
|
|
144
145
|
|
|
145
146
|
agent_tools: Optional[list[BaseTool]] = None
|
|
@@ -301,7 +302,7 @@ The following is the expected answer. Use this to measure correctness:
|
|
|
301
302
|
chain_input,
|
|
302
303
|
callbacks=_run_manager.get_child(),
|
|
303
304
|
)
|
|
304
|
-
return cast(dict, self.output_parser.parse(raw_output))
|
|
305
|
+
return cast("dict", self.output_parser.parse(raw_output))
|
|
305
306
|
|
|
306
307
|
async def _acall(
|
|
307
308
|
self,
|
|
@@ -326,7 +327,7 @@ The following is the expected answer. Use this to measure correctness:
|
|
|
326
327
|
chain_input,
|
|
327
328
|
callbacks=_run_manager.get_child(),
|
|
328
329
|
)
|
|
329
|
-
return cast(dict, self.output_parser.parse(raw_output))
|
|
330
|
+
return cast("dict", self.output_parser.parse(raw_output))
|
|
330
331
|
|
|
331
332
|
@override
|
|
332
333
|
def _evaluate_agent_trajectory(
|
|
@@ -236,6 +236,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
|
|
236
236
|
output_key: str = "results" #: :meta private:
|
|
237
237
|
|
|
238
238
|
@classmethod
|
|
239
|
+
@override
|
|
239
240
|
def is_lc_serializable(cls) -> bool:
|
|
240
241
|
return False
|
|
241
242
|
|
|
@@ -249,6 +250,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
|
|
249
250
|
return False
|
|
250
251
|
|
|
251
252
|
@property
|
|
253
|
+
@override
|
|
252
254
|
def requires_input(self) -> bool:
|
|
253
255
|
return True
|
|
254
256
|
|
|
@@ -520,6 +522,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
|
|
|
520
522
|
"""Criteria evaluation chain that requires references."""
|
|
521
523
|
|
|
522
524
|
@classmethod
|
|
525
|
+
@override
|
|
523
526
|
def is_lc_serializable(cls) -> bool:
|
|
524
527
|
return False
|
|
525
528
|
|
|
@@ -14,6 +14,7 @@ from langchain_core.callbacks.manager import (
|
|
|
14
14
|
from langchain_core.embeddings import Embeddings
|
|
15
15
|
from langchain_core.utils import pre_init
|
|
16
16
|
from pydantic import ConfigDict, Field
|
|
17
|
+
from typing_extensions import override
|
|
17
18
|
|
|
18
19
|
from langchain.chains.base import Chain
|
|
19
20
|
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
|
@@ -317,6 +318,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|
|
317
318
|
return True
|
|
318
319
|
|
|
319
320
|
@property
|
|
321
|
+
@override
|
|
320
322
|
def evaluation_name(self) -> str:
|
|
321
323
|
return f"embedding_{self.distance_metric.value}_distance"
|
|
322
324
|
|
|
@@ -329,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|
|
329
331
|
"""
|
|
330
332
|
return ["prediction", "reference"]
|
|
331
333
|
|
|
334
|
+
@override
|
|
332
335
|
def _call(
|
|
333
336
|
self,
|
|
334
337
|
inputs: dict[str, Any],
|
|
@@ -353,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|
|
353
356
|
score = self._compute_score(vectors)
|
|
354
357
|
return {"score": score}
|
|
355
358
|
|
|
359
|
+
@override
|
|
356
360
|
async def _acall(
|
|
357
361
|
self,
|
|
358
362
|
inputs: dict[str, Any],
|
|
@@ -380,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|
|
380
384
|
score = self._compute_score(vectors)
|
|
381
385
|
return {"score": score}
|
|
382
386
|
|
|
387
|
+
@override
|
|
383
388
|
def _evaluate_strings(
|
|
384
389
|
self,
|
|
385
390
|
*,
|
|
@@ -414,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|
|
414
419
|
)
|
|
415
420
|
return self._prepare_output(result)
|
|
416
421
|
|
|
422
|
+
@override
|
|
417
423
|
async def _aevaluate_strings(
|
|
418
424
|
self,
|
|
419
425
|
*,
|
|
@@ -473,8 +479,10 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|
|
473
479
|
|
|
474
480
|
@property
|
|
475
481
|
def evaluation_name(self) -> str:
|
|
482
|
+
"""Return the evaluation name."""
|
|
476
483
|
return f"pairwise_embedding_{self.distance_metric.value}_distance"
|
|
477
484
|
|
|
485
|
+
@override
|
|
478
486
|
def _call(
|
|
479
487
|
self,
|
|
480
488
|
inputs: dict[str, Any],
|
|
@@ -502,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|
|
502
510
|
score = self._compute_score(vectors)
|
|
503
511
|
return {"score": score}
|
|
504
512
|
|
|
513
|
+
@override
|
|
505
514
|
async def _acall(
|
|
506
515
|
self,
|
|
507
516
|
inputs: dict[str, Any],
|
|
@@ -529,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|
|
529
538
|
score = self._compute_score(vectors)
|
|
530
539
|
return {"score": score}
|
|
531
540
|
|
|
541
|
+
@override
|
|
532
542
|
def _evaluate_string_pairs(
|
|
533
543
|
self,
|
|
534
544
|
*,
|
|
@@ -564,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|
|
564
574
|
)
|
|
565
575
|
return self._prepare_output(result)
|
|
566
576
|
|
|
577
|
+
@override
|
|
567
578
|
async def _aevaluate_string_pairs(
|
|
568
579
|
self,
|
|
569
580
|
*,
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import string
|
|
2
2
|
from typing import Any
|
|
3
3
|
|
|
4
|
+
from typing_extensions import override
|
|
5
|
+
|
|
4
6
|
from langchain.evaluation.schema import StringEvaluator
|
|
5
7
|
|
|
6
8
|
|
|
@@ -27,8 +29,18 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
|
|
27
29
|
ignore_case: bool = False,
|
|
28
30
|
ignore_punctuation: bool = False,
|
|
29
31
|
ignore_numbers: bool = False,
|
|
30
|
-
**kwargs: Any,
|
|
32
|
+
**_: Any,
|
|
31
33
|
):
|
|
34
|
+
"""Initialize the ExactMatchStringEvaluator.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
ignore_case: Whether to ignore case when comparing strings.
|
|
38
|
+
Defaults to False.
|
|
39
|
+
ignore_punctuation: Whether to ignore punctuation when comparing strings.
|
|
40
|
+
Defaults to False.
|
|
41
|
+
ignore_numbers: Whether to ignore numbers when comparing strings.
|
|
42
|
+
Defaults to False.
|
|
43
|
+
"""
|
|
32
44
|
super().__init__()
|
|
33
45
|
self.ignore_case = ignore_case
|
|
34
46
|
self.ignore_punctuation = ignore_punctuation
|
|
@@ -68,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
|
|
68
80
|
"""
|
|
69
81
|
return "exact_match"
|
|
70
82
|
|
|
83
|
+
@override
|
|
71
84
|
def _evaluate_strings( # type: ignore[override]
|
|
72
85
|
self,
|
|
73
86
|
*,
|
langchain/evaluation/loading.py
CHANGED
|
@@ -35,18 +35,22 @@ class JsonValidityEvaluator(StringEvaluator):
|
|
|
35
35
|
{'score': 0, 'reasoning': 'Expecting property name enclosed in double quotes'}
|
|
36
36
|
"""
|
|
37
37
|
|
|
38
|
-
def __init__(self, **kwargs: Any) -> None:
|
|
38
|
+
def __init__(self, **_: Any) -> None:
|
|
39
|
+
"""Initialize the JsonValidityEvaluator."""
|
|
39
40
|
super().__init__()
|
|
40
41
|
|
|
41
42
|
@property
|
|
43
|
+
@override
|
|
42
44
|
def requires_input(self) -> bool:
|
|
43
45
|
return False
|
|
44
46
|
|
|
45
47
|
@property
|
|
48
|
+
@override
|
|
46
49
|
def requires_reference(self) -> bool:
|
|
47
50
|
return False
|
|
48
51
|
|
|
49
52
|
@property
|
|
53
|
+
@override
|
|
50
54
|
def evaluation_name(self) -> str:
|
|
51
55
|
return "json_validity"
|
|
52
56
|
|
|
@@ -110,19 +114,28 @@ class JsonEqualityEvaluator(StringEvaluator):
|
|
|
110
114
|
|
|
111
115
|
"""
|
|
112
116
|
|
|
113
|
-
def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
|
|
117
|
+
def __init__(self, operator: Optional[Callable] = None, **_: Any) -> None:
|
|
118
|
+
"""Initialize the JsonEqualityEvaluator.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
operator: A custom operator to compare the parsed JSON objects.
|
|
122
|
+
Defaults to equality (`eq`).
|
|
123
|
+
"""
|
|
114
124
|
super().__init__()
|
|
115
125
|
self.operator = operator or eq
|
|
116
126
|
|
|
117
127
|
@property
|
|
128
|
+
@override
|
|
118
129
|
def requires_input(self) -> bool:
|
|
119
130
|
return False
|
|
120
131
|
|
|
121
132
|
@property
|
|
133
|
+
@override
|
|
122
134
|
def requires_reference(self) -> bool:
|
|
123
135
|
return True
|
|
124
136
|
|
|
125
137
|
@property
|
|
138
|
+
@override
|
|
126
139
|
def evaluation_name(self) -> str:
|
|
127
140
|
return "json_equality"
|
|
128
141
|
|
|
@@ -153,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
|
|
|
153
166
|
dict: A dictionary containing the evaluation score.
|
|
154
167
|
"""
|
|
155
168
|
parsed = self._parse_json(prediction)
|
|
156
|
-
label = self._parse_json(cast(str, reference))
|
|
169
|
+
label = self._parse_json(cast("str", reference))
|
|
157
170
|
if isinstance(label, list):
|
|
158
171
|
if not isinstance(parsed, list):
|
|
159
172
|
return {"score": 0}
|
|
@@ -15,13 +15,6 @@ class JsonEditDistanceEvaluator(StringEvaluator):
|
|
|
15
15
|
after parsing them and converting them to a canonical format (i.e., whitespace and key order are normalized).
|
|
16
16
|
It can be customized with alternative distance and canonicalization functions.
|
|
17
17
|
|
|
18
|
-
Args:
|
|
19
|
-
string_distance (Optional[Callable[[str, str], float]]): A callable that computes the distance between two strings.
|
|
20
|
-
If not provided, a Damerau-Levenshtein distance from the `rapidfuzz` package will be used.
|
|
21
|
-
canonicalize (Optional[Callable[[Any], Any]]): A callable that converts a parsed JSON object into its canonical string form.
|
|
22
|
-
If not provided, the default behavior is to serialize the JSON with sorted keys and no extra whitespace.
|
|
23
|
-
**kwargs (Any): Additional keyword arguments.
|
|
24
|
-
|
|
25
18
|
Attributes:
|
|
26
19
|
_string_distance (Callable[[str, str], float]): The internal distance computation function.
|
|
27
20
|
_canonicalize (Callable[[Any], Any]): The internal canonicalization function.
|
|
@@ -40,8 +33,23 @@ class JsonEditDistanceEvaluator(StringEvaluator):
|
|
|
40
33
|
self,
|
|
41
34
|
string_distance: Optional[Callable[[str, str], float]] = None,
|
|
42
35
|
canonicalize: Optional[Callable[[Any], Any]] = None,
|
|
43
|
-
**kwargs: Any,
|
|
36
|
+
**_: Any,
|
|
44
37
|
) -> None:
|
|
38
|
+
"""Initialize the JsonEditDistanceEvaluator.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
string_distance: A callable that computes the distance between two strings.
|
|
42
|
+
If not provided, a Damerau-Levenshtein distance from the `rapidfuzz`
|
|
43
|
+
package will be used.
|
|
44
|
+
canonicalize: A callable that converts a parsed JSON object into its
|
|
45
|
+
canonical string form.
|
|
46
|
+
If not provided, the default behavior is to serialize the JSON with
|
|
47
|
+
sorted keys and no extra whitespace.
|
|
48
|
+
|
|
49
|
+
Raises:
|
|
50
|
+
ImportError: If the `rapidfuzz` package is not installed and no
|
|
51
|
+
`string_distance` function is provided.
|
|
52
|
+
"""
|
|
45
53
|
super().__init__()
|
|
46
54
|
if string_distance is not None:
|
|
47
55
|
self._string_distance = string_distance
|
|
@@ -67,14 +75,17 @@ class JsonEditDistanceEvaluator(StringEvaluator):
|
|
|
67
75
|
)
|
|
68
76
|
|
|
69
77
|
@property
|
|
78
|
+
@override
|
|
70
79
|
def requires_input(self) -> bool:
|
|
71
80
|
return False
|
|
72
81
|
|
|
73
82
|
@property
|
|
83
|
+
@override
|
|
74
84
|
def requires_reference(self) -> bool:
|
|
75
85
|
return True
|
|
76
86
|
|
|
77
87
|
@property
|
|
88
|
+
@override
|
|
78
89
|
def evaluation_name(self) -> str:
|
|
79
90
|
return "json_edit_distance"
|
|
80
91
|
|
|
@@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
|
|
|
33
33
|
|
|
34
34
|
""" # noqa: E501
|
|
35
35
|
|
|
36
|
-
def __init__(self, **kwargs: Any) -> None:
|
|
36
|
+
def __init__(self, **_: Any) -> None:
|
|
37
37
|
"""Initializes the JsonSchemaEvaluator.
|
|
38
38
|
|
|
39
|
-
Args:
|
|
40
|
-
kwargs: Additional keyword arguments.
|
|
41
|
-
|
|
42
39
|
Raises:
|
|
43
40
|
ImportError: If the jsonschema package is not installed.
|
|
44
41
|
"""
|
|
@@ -80,18 +80,22 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
|
|
|
80
80
|
)
|
|
81
81
|
|
|
82
82
|
@classmethod
|
|
83
|
+
@override
|
|
83
84
|
def is_lc_serializable(cls) -> bool:
|
|
84
85
|
return False
|
|
85
86
|
|
|
86
87
|
@property
|
|
88
|
+
@override
|
|
87
89
|
def evaluation_name(self) -> str:
|
|
88
90
|
return "correctness"
|
|
89
91
|
|
|
90
92
|
@property
|
|
93
|
+
@override
|
|
91
94
|
def requires_reference(self) -> bool:
|
|
92
95
|
return True
|
|
93
96
|
|
|
94
97
|
@property
|
|
98
|
+
@override
|
|
95
99
|
def requires_input(self) -> bool:
|
|
96
100
|
return True
|
|
97
101
|
|
|
@@ -214,6 +218,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
|
|
|
214
218
|
"""LLM Chain for evaluating QA w/o GT based on context"""
|
|
215
219
|
|
|
216
220
|
@classmethod
|
|
221
|
+
@override
|
|
217
222
|
def is_lc_serializable(cls) -> bool:
|
|
218
223
|
return False
|
|
219
224
|
|
|
@@ -242,6 +247,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
|
|
|
242
247
|
raise ValueError(msg)
|
|
243
248
|
|
|
244
249
|
@property
|
|
250
|
+
@override
|
|
245
251
|
def evaluation_name(self) -> str:
|
|
246
252
|
return "Contextual Accuracy"
|
|
247
253
|
|
|
@@ -344,10 +350,12 @@ class CotQAEvalChain(ContextQAEvalChain):
|
|
|
344
350
|
"""LLM Chain for evaluating QA using chain of thought reasoning."""
|
|
345
351
|
|
|
346
352
|
@classmethod
|
|
353
|
+
@override
|
|
347
354
|
def is_lc_serializable(cls) -> bool:
|
|
348
355
|
return False
|
|
349
356
|
|
|
350
357
|
@property
|
|
358
|
+
@override
|
|
351
359
|
def evaluation_name(self) -> str:
|
|
352
360
|
return "COT Contextual Accuracy"
|
|
353
361
|
|
|
@@ -7,6 +7,7 @@ from typing import Any
|
|
|
7
7
|
from langchain_core.language_models import BaseLanguageModel
|
|
8
8
|
from langchain_core.output_parsers import BaseLLMOutputParser
|
|
9
9
|
from pydantic import Field
|
|
10
|
+
from typing_extensions import override
|
|
10
11
|
|
|
11
12
|
from langchain.chains.llm import LLMChain
|
|
12
13
|
from langchain.evaluation.qa.generate_prompt import PROMPT
|
|
@@ -25,6 +26,7 @@ class QAGenerateChain(LLMChain):
|
|
|
25
26
|
output_key: str = "qa_pairs"
|
|
26
27
|
|
|
27
28
|
@classmethod
|
|
29
|
+
@override
|
|
28
30
|
def is_lc_serializable(cls) -> bool:
|
|
29
31
|
return False
|
|
30
32
|
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from typing import Any
|
|
3
3
|
|
|
4
|
+
from typing_extensions import override
|
|
5
|
+
|
|
4
6
|
from langchain.evaluation.schema import StringEvaluator
|
|
5
7
|
|
|
6
8
|
|
|
@@ -27,7 +29,12 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
|
|
27
29
|
) # This will return {'score': 1.0} as the prediction matches the second pattern in the union
|
|
28
30
|
""" # noqa: E501
|
|
29
31
|
|
|
30
|
-
def __init__(self, *, flags: int = 0, **kwargs: Any): # Default is no flags
|
|
32
|
+
def __init__(self, *, flags: int = 0, **_: Any): # Default is no flags
|
|
33
|
+
"""Initialize the RegexMatchStringEvaluator.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
flags: Flags to use for the regex match. Defaults to 0 (no flags).
|
|
37
|
+
"""
|
|
31
38
|
super().__init__()
|
|
32
39
|
self.flags = flags
|
|
33
40
|
|
|
@@ -65,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
|
|
65
72
|
"""
|
|
66
73
|
return "regex_match"
|
|
67
74
|
|
|
75
|
+
@override
|
|
68
76
|
def _evaluate_strings( # type: ignore[override]
|
|
69
77
|
self,
|
|
70
78
|
*,
|