langchain-0.3.27-py3-none-any.whl → langchain-0.4.0.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141):
  1. langchain/agents/agent.py +16 -20
  2. langchain/agents/agent_iterator.py +19 -12
  3. langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
  4. langchain/agents/chat/base.py +2 -0
  5. langchain/agents/conversational/base.py +2 -0
  6. langchain/agents/conversational_chat/base.py +2 -0
  7. langchain/agents/initialize.py +1 -1
  8. langchain/agents/json_chat/base.py +1 -0
  9. langchain/agents/mrkl/base.py +2 -0
  10. langchain/agents/openai_assistant/base.py +1 -1
  11. langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
  12. langchain/agents/openai_functions_agent/base.py +3 -2
  13. langchain/agents/openai_functions_multi_agent/base.py +1 -1
  14. langchain/agents/openai_tools/base.py +1 -0
  15. langchain/agents/output_parsers/json.py +2 -0
  16. langchain/agents/output_parsers/openai_functions.py +10 -3
  17. langchain/agents/output_parsers/openai_tools.py +8 -1
  18. langchain/agents/output_parsers/react_json_single_input.py +3 -0
  19. langchain/agents/output_parsers/react_single_input.py +3 -0
  20. langchain/agents/output_parsers/self_ask.py +2 -0
  21. langchain/agents/output_parsers/tools.py +16 -2
  22. langchain/agents/output_parsers/xml.py +3 -0
  23. langchain/agents/react/agent.py +1 -0
  24. langchain/agents/react/base.py +4 -0
  25. langchain/agents/react/output_parser.py +2 -0
  26. langchain/agents/schema.py +2 -0
  27. langchain/agents/self_ask_with_search/base.py +4 -0
  28. langchain/agents/structured_chat/base.py +5 -0
  29. langchain/agents/structured_chat/output_parser.py +13 -0
  30. langchain/agents/tool_calling_agent/base.py +1 -0
  31. langchain/agents/tools.py +3 -0
  32. langchain/agents/xml/base.py +7 -1
  33. langchain/callbacks/streaming_aiter.py +13 -2
  34. langchain/callbacks/streaming_aiter_final_only.py +11 -2
  35. langchain/callbacks/streaming_stdout_final_only.py +5 -0
  36. langchain/callbacks/tracers/logging.py +11 -0
  37. langchain/chains/api/base.py +5 -1
  38. langchain/chains/base.py +8 -2
  39. langchain/chains/combine_documents/base.py +7 -1
  40. langchain/chains/combine_documents/map_reduce.py +3 -0
  41. langchain/chains/combine_documents/map_rerank.py +6 -4
  42. langchain/chains/combine_documents/reduce.py +1 -0
  43. langchain/chains/combine_documents/refine.py +1 -0
  44. langchain/chains/combine_documents/stuff.py +5 -1
  45. langchain/chains/constitutional_ai/base.py +7 -0
  46. langchain/chains/conversation/base.py +4 -1
  47. langchain/chains/conversational_retrieval/base.py +67 -59
  48. langchain/chains/elasticsearch_database/base.py +2 -1
  49. langchain/chains/flare/base.py +2 -0
  50. langchain/chains/flare/prompts.py +2 -0
  51. langchain/chains/llm.py +7 -2
  52. langchain/chains/llm_bash/__init__.py +1 -1
  53. langchain/chains/llm_checker/base.py +12 -1
  54. langchain/chains/llm_math/base.py +9 -1
  55. langchain/chains/llm_summarization_checker/base.py +13 -1
  56. langchain/chains/llm_symbolic_math/__init__.py +1 -1
  57. langchain/chains/loading.py +4 -2
  58. langchain/chains/moderation.py +3 -0
  59. langchain/chains/natbot/base.py +3 -1
  60. langchain/chains/natbot/crawler.py +29 -0
  61. langchain/chains/openai_functions/base.py +2 -0
  62. langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
  63. langchain/chains/openai_functions/openapi.py +4 -0
  64. langchain/chains/openai_functions/qa_with_structure.py +3 -3
  65. langchain/chains/openai_functions/tagging.py +2 -0
  66. langchain/chains/qa_generation/base.py +4 -0
  67. langchain/chains/qa_with_sources/base.py +3 -0
  68. langchain/chains/qa_with_sources/retrieval.py +1 -1
  69. langchain/chains/qa_with_sources/vector_db.py +4 -2
  70. langchain/chains/query_constructor/base.py +4 -2
  71. langchain/chains/query_constructor/parser.py +64 -2
  72. langchain/chains/retrieval_qa/base.py +4 -0
  73. langchain/chains/router/base.py +14 -2
  74. langchain/chains/router/embedding_router.py +3 -0
  75. langchain/chains/router/llm_router.py +6 -4
  76. langchain/chains/router/multi_prompt.py +3 -0
  77. langchain/chains/router/multi_retrieval_qa.py +18 -0
  78. langchain/chains/sql_database/query.py +1 -0
  79. langchain/chains/structured_output/base.py +2 -0
  80. langchain/chains/transform.py +4 -0
  81. langchain/chat_models/base.py +55 -18
  82. langchain/document_loaders/blob_loaders/schema.py +1 -4
  83. langchain/embeddings/base.py +2 -0
  84. langchain/embeddings/cache.py +3 -3
  85. langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
  86. langchain/evaluation/comparison/eval_chain.py +1 -0
  87. langchain/evaluation/criteria/eval_chain.py +3 -0
  88. langchain/evaluation/embedding_distance/base.py +11 -0
  89. langchain/evaluation/exact_match/base.py +14 -1
  90. langchain/evaluation/loading.py +1 -0
  91. langchain/evaluation/parsing/base.py +16 -3
  92. langchain/evaluation/parsing/json_distance.py +19 -8
  93. langchain/evaluation/parsing/json_schema.py +1 -4
  94. langchain/evaluation/qa/eval_chain.py +8 -0
  95. langchain/evaluation/qa/generate_chain.py +2 -0
  96. langchain/evaluation/regex_match/base.py +9 -1
  97. langchain/evaluation/scoring/eval_chain.py +1 -0
  98. langchain/evaluation/string_distance/base.py +6 -0
  99. langchain/memory/buffer.py +5 -0
  100. langchain/memory/buffer_window.py +2 -0
  101. langchain/memory/combined.py +1 -1
  102. langchain/memory/entity.py +47 -0
  103. langchain/memory/simple.py +3 -0
  104. langchain/memory/summary.py +30 -0
  105. langchain/memory/summary_buffer.py +3 -0
  106. langchain/memory/token_buffer.py +2 -0
  107. langchain/output_parsers/combining.py +4 -2
  108. langchain/output_parsers/enum.py +5 -1
  109. langchain/output_parsers/fix.py +8 -1
  110. langchain/output_parsers/pandas_dataframe.py +16 -1
  111. langchain/output_parsers/regex.py +2 -0
  112. langchain/output_parsers/retry.py +21 -1
  113. langchain/output_parsers/structured.py +10 -0
  114. langchain/output_parsers/yaml.py +4 -0
  115. langchain/pydantic_v1/__init__.py +1 -1
  116. langchain/retrievers/document_compressors/chain_extract.py +4 -2
  117. langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
  118. langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
  119. langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
  120. langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
  121. langchain/retrievers/ensemble.py +2 -2
  122. langchain/retrievers/multi_query.py +3 -1
  123. langchain/retrievers/multi_vector.py +4 -1
  124. langchain/retrievers/parent_document_retriever.py +15 -0
  125. langchain/retrievers/self_query/base.py +19 -0
  126. langchain/retrievers/time_weighted_retriever.py +3 -0
  127. langchain/runnables/hub.py +12 -0
  128. langchain/runnables/openai_functions.py +6 -0
  129. langchain/smith/__init__.py +1 -0
  130. langchain/smith/evaluation/config.py +5 -22
  131. langchain/smith/evaluation/progress.py +12 -3
  132. langchain/smith/evaluation/runner_utils.py +240 -123
  133. langchain/smith/evaluation/string_run_evaluator.py +27 -0
  134. langchain/storage/encoder_backed.py +1 -0
  135. langchain/tools/python/__init__.py +1 -1
  136. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
  137. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
  138. langchain/smith/evaluation/utils.py +0 -0
  139. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
  140. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
  141. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -10,6 +10,7 @@ from langchain_core.callbacks import (
10
10
  CallbackManagerForChainRun,
11
11
  )
12
12
  from pydantic import Field
13
+ from typing_extensions import override
13
14
 
14
15
  from langchain.chains.base import Chain
15
16
 
@@ -25,6 +26,7 @@ class TransformChain(Chain):
25
26
  from langchain.chains import TransformChain
26
27
  transform_chain = TransformChain(input_variables=["text"],
27
28
  output_variables["entities"], transform=func())
29
+
28
30
  """
29
31
 
30
32
  input_variables: list[str]
@@ -63,6 +65,7 @@ class TransformChain(Chain):
63
65
  """
64
66
  return self.output_variables
65
67
 
68
+ @override
66
69
  def _call(
67
70
  self,
68
71
  inputs: dict[str, str],
@@ -70,6 +73,7 @@ class TransformChain(Chain):
70
73
  ) -> dict[str, str]:
71
74
  return self.transform_cb(inputs)
72
75
 
76
+ @override
73
77
  async def _acall(
74
78
  self,
75
79
  inputs: dict[str, Any],
@@ -3,15 +3,7 @@ from __future__ import annotations
3
3
  import warnings
4
4
  from collections.abc import AsyncIterator, Iterator, Sequence
5
5
  from importlib import util
6
- from typing import (
7
- Any,
8
- Callable,
9
- Literal,
10
- Optional,
11
- Union,
12
- cast,
13
- overload,
14
- )
6
+ from typing import Any, Callable, Literal, Optional, Union, cast, overload
15
7
 
16
8
  from langchain_core.language_models import (
17
9
  BaseChatModel,
@@ -27,6 +19,7 @@ from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
27
19
  from langchain_core.runnables.schema import StreamEvent
28
20
  from langchain_core.tools import BaseTool
29
21
  from langchain_core.tracers import RunLog, RunLogPatch
22
+ from langchain_core.v1.chat_models import BaseChatModel as BaseChatModelV1
30
23
  from pydantic import BaseModel
31
24
  from typing_extensions import TypeAlias, override
32
25
 
@@ -47,10 +40,23 @@ def init_chat_model(
47
40
  model_provider: Optional[str] = None,
48
41
  configurable_fields: Literal[None] = None,
49
42
  config_prefix: Optional[str] = None,
43
+ message_version: Literal["v0"] = "v0",
50
44
  **kwargs: Any,
51
45
  ) -> BaseChatModel: ...
52
46
 
53
47
 
48
+ @overload
49
+ def init_chat_model(
50
+ model: str,
51
+ *,
52
+ model_provider: Optional[str] = None,
53
+ configurable_fields: Literal[None] = None,
54
+ config_prefix: Optional[str] = None,
55
+ message_version: Literal["v1"] = "v1",
56
+ **kwargs: Any,
57
+ ) -> BaseChatModelV1: ...
58
+
59
+
54
60
  @overload
55
61
  def init_chat_model(
56
62
  model: Literal[None] = None,
@@ -58,6 +64,7 @@ def init_chat_model(
58
64
  model_provider: Optional[str] = None,
59
65
  configurable_fields: Literal[None] = None,
60
66
  config_prefix: Optional[str] = None,
67
+ message_version: Literal["v0", "v1"] = "v0",
61
68
  **kwargs: Any,
62
69
  ) -> _ConfigurableModel: ...
63
70
 
@@ -69,6 +76,7 @@ def init_chat_model(
69
76
  model_provider: Optional[str] = None,
70
77
  configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
71
78
  config_prefix: Optional[str] = None,
79
+ message_version: Literal["v0", "v1"] = "v0",
72
80
  **kwargs: Any,
73
81
  ) -> _ConfigurableModel: ...
74
82
 
@@ -84,8 +92,9 @@ def init_chat_model(
84
92
  Union[Literal["any"], list[str], tuple[str, ...]]
85
93
  ] = None,
86
94
  config_prefix: Optional[str] = None,
95
+ message_version: Literal["v0", "v1"] = "v0",
87
96
  **kwargs: Any,
88
- ) -> Union[BaseChatModel, _ConfigurableModel]:
97
+ ) -> Union[BaseChatModel, BaseChatModelV1, _ConfigurableModel]:
89
98
  """Initialize a ChatModel in a single line using the model's name and provider.
90
99
 
91
100
  .. note::
@@ -136,6 +145,20 @@ def init_chat_model(
136
145
  - ``deepseek...`` -> ``deepseek``
137
146
  - ``grok...`` -> ``xai``
138
147
  - ``sonar...`` -> ``perplexity``
148
+
149
+ message_version: The version of the BaseChatModel to return. Either ``"v0"`` for
150
+ a v0 :class:`~langchain_core.language_models.chat_models.BaseChatModel` or
151
+ ``"v1"`` for a v1 :class:`~langchain_core.v1.chat_models.BaseChatModel`. The
152
+ output version determines what type of message objects the model will
153
+ generate.
154
+
155
+ .. note::
156
+ Currently supported for these providers:
157
+
158
+ - ``openai``
159
+
160
+ .. versionadded:: 0.4.0
161
+
139
162
  configurable_fields: Which model parameters are configurable:
140
163
 
141
164
  - None: No configurable fields.
@@ -188,7 +211,7 @@ def init_chat_model(
188
211
 
189
212
  o3_mini = init_chat_model("openai:o3-mini", temperature=0)
190
213
  claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
191
- gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
214
+ gemini_2_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)
192
215
 
193
216
  o3_mini.invoke("what's your name")
194
217
  claude_sonnet.invoke("what's your name")
@@ -322,8 +345,9 @@ def init_chat_model(
322
345
 
323
346
  if not configurable_fields:
324
347
  return _init_chat_model_helper(
325
- cast(str, model),
348
+ cast("str", model),
326
349
  model_provider=model_provider,
350
+ message_version=message_version,
327
351
  **kwargs,
328
352
  )
329
353
  if model:
@@ -341,14 +365,27 @@ def _init_chat_model_helper(
341
365
  model: str,
342
366
  *,
343
367
  model_provider: Optional[str] = None,
368
+ message_version: Literal["v0", "v1"] = "v0",
344
369
  **kwargs: Any,
345
- ) -> BaseChatModel:
370
+ ) -> Union[BaseChatModel, BaseChatModelV1]:
346
371
  model, model_provider = _parse_model(model, model_provider)
372
+ if message_version != "v0" and model_provider not in ("openai",):
373
+ warnings.warn(
374
+ f"Model provider {model_provider} does not support "
375
+ f"message_version={message_version}. Defaulting to v0.",
376
+ stacklevel=2,
377
+ )
347
378
  if model_provider == "openai":
348
379
  _check_pkg("langchain_openai")
349
- from langchain_openai import ChatOpenAI
380
+ if message_version == "v0":
381
+ from langchain_openai import ChatOpenAI
382
+
383
+ return ChatOpenAI(model=model, **kwargs)
384
+ # v1
385
+ from langchain_openai.v1 import ChatOpenAI as ChatOpenAIV1
386
+
387
+ return ChatOpenAIV1(model=model, **kwargs)
350
388
 
351
- return ChatOpenAI(model=model, **kwargs)
352
389
  if model_provider == "anthropic":
353
390
  _check_pkg("langchain_anthropic")
354
391
  from langchain_anthropic import ChatAnthropic
@@ -632,7 +669,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
632
669
  **kwargs: Any,
633
670
  ) -> _ConfigurableModel:
634
671
  """Bind config to a Runnable, returning a new Runnable."""
635
- config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
672
+ config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
636
673
  model_params = self._model_params(config)
637
674
  remaining_config = {k: v for k, v in config.items() if k != "configurable"}
638
675
  remaining_config["configurable"] = {
@@ -781,7 +818,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
781
818
  if config is None or isinstance(config, dict) or len(config) <= 1:
782
819
  if isinstance(config, list):
783
820
  config = config[0]
784
- yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload]
821
+ yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
785
822
  inputs,
786
823
  config=config,
787
824
  return_exceptions=return_exceptions,
@@ -811,7 +848,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
811
848
  if isinstance(config, list):
812
849
  config = config[0]
813
850
  async for x in self._model(
814
- cast(RunnableConfig, config),
851
+ cast("RunnableConfig", config),
815
852
  ).abatch_as_completed( # type: ignore[call-overload]
816
853
  inputs,
817
854
  config=config,
@@ -1,12 +1,9 @@
1
- from typing import TYPE_CHECKING, Any
1
+ from typing import Any
2
2
 
3
3
  from langchain_core.document_loaders import Blob, BlobLoader
4
4
 
5
5
  from langchain._api import create_importer
6
6
 
7
- if TYPE_CHECKING:
8
- pass
9
-
10
7
  # Create a way to dynamically look up deprecated imports.
11
8
  # Used to consolidate logic for raising deprecation warnings and
12
9
  # handling optional imports.
@@ -47,6 +47,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
47
47
  Raises:
48
48
  ValueError: If the model string is not in the correct format or
49
49
  the provider is unsupported
50
+
50
51
  """
51
52
  if ":" not in model_name:
52
53
  providers = _SUPPORTED_PROVIDERS
@@ -177,6 +178,7 @@ def init_embeddings(
177
178
  )
178
179
 
179
180
  .. versionadded:: 0.3.9
181
+
180
182
  """
181
183
  if not model:
182
184
  providers = _SUPPORTED_PROVIDERS.keys()
@@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
80
80
 
81
81
  def _value_deserializer(serialized_value: bytes) -> list[float]:
82
82
  """Deserialize a value."""
83
- return cast(list[float], json.loads(serialized_value.decode()))
83
+ return cast("list[float]", json.loads(serialized_value.decode()))
84
84
 
85
85
 
86
86
  # The warning is global; track emission, so it appears only once.
@@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
192
192
  vectors[index] = updated_vector
193
193
 
194
194
  return cast(
195
- list[list[float]],
195
+ "list[list[float]]",
196
196
  vectors,
197
197
  ) # Nones should have been resolved by now
198
198
 
@@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
230
230
  vectors[index] = updated_vector
231
231
 
232
232
  return cast(
233
- list[list[float]],
233
+ "list[list[float]]",
234
234
  vectors,
235
235
  ) # Nones should have been resolved by now
236
236
 
@@ -140,6 +140,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
140
140
  )
141
141
  print(result["score"]) # noqa: T201
142
142
  # 0
143
+
143
144
  """
144
145
 
145
146
  agent_tools: Optional[list[BaseTool]] = None
@@ -301,7 +302,7 @@ The following is the expected answer. Use this to measure correctness:
301
302
  chain_input,
302
303
  callbacks=_run_manager.get_child(),
303
304
  )
304
- return cast(dict, self.output_parser.parse(raw_output))
305
+ return cast("dict", self.output_parser.parse(raw_output))
305
306
 
306
307
  async def _acall(
307
308
  self,
@@ -326,7 +327,7 @@ The following is the expected answer. Use this to measure correctness:
326
327
  chain_input,
327
328
  callbacks=_run_manager.get_child(),
328
329
  )
329
- return cast(dict, self.output_parser.parse(raw_output))
330
+ return cast("dict", self.output_parser.parse(raw_output))
330
331
 
331
332
  @override
332
333
  def _evaluate_agent_trajectory(
@@ -191,6 +191,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
191
191
  )
192
192
 
193
193
  @classmethod
194
+ @override
194
195
  def is_lc_serializable(cls) -> bool:
195
196
  return False
196
197
 
@@ -236,6 +236,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
236
236
  output_key: str = "results" #: :meta private:
237
237
 
238
238
  @classmethod
239
+ @override
239
240
  def is_lc_serializable(cls) -> bool:
240
241
  return False
241
242
 
@@ -249,6 +250,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
249
250
  return False
250
251
 
251
252
  @property
253
+ @override
252
254
  def requires_input(self) -> bool:
253
255
  return True
254
256
 
@@ -520,6 +522,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
520
522
  """Criteria evaluation chain that requires references."""
521
523
 
522
524
  @classmethod
525
+ @override
523
526
  def is_lc_serializable(cls) -> bool:
524
527
  return False
525
528
 
@@ -14,6 +14,7 @@ from langchain_core.callbacks.manager import (
14
14
  from langchain_core.embeddings import Embeddings
15
15
  from langchain_core.utils import pre_init
16
16
  from pydantic import ConfigDict, Field
17
+ from typing_extensions import override
17
18
 
18
19
  from langchain.chains.base import Chain
19
20
  from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@@ -317,6 +318,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
317
318
  return True
318
319
 
319
320
  @property
321
+ @override
320
322
  def evaluation_name(self) -> str:
321
323
  return f"embedding_{self.distance_metric.value}_distance"
322
324
 
@@ -329,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
329
331
  """
330
332
  return ["prediction", "reference"]
331
333
 
334
+ @override
332
335
  def _call(
333
336
  self,
334
337
  inputs: dict[str, Any],
@@ -353,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
353
356
  score = self._compute_score(vectors)
354
357
  return {"score": score}
355
358
 
359
+ @override
356
360
  async def _acall(
357
361
  self,
358
362
  inputs: dict[str, Any],
@@ -380,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
380
384
  score = self._compute_score(vectors)
381
385
  return {"score": score}
382
386
 
387
+ @override
383
388
  def _evaluate_strings(
384
389
  self,
385
390
  *,
@@ -414,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
414
419
  )
415
420
  return self._prepare_output(result)
416
421
 
422
+ @override
417
423
  async def _aevaluate_strings(
418
424
  self,
419
425
  *,
@@ -473,8 +479,10 @@ class PairwiseEmbeddingDistanceEvalChain(
473
479
 
474
480
  @property
475
481
  def evaluation_name(self) -> str:
482
+ """Return the evaluation name."""
476
483
  return f"pairwise_embedding_{self.distance_metric.value}_distance"
477
484
 
485
+ @override
478
486
  def _call(
479
487
  self,
480
488
  inputs: dict[str, Any],
@@ -502,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
502
510
  score = self._compute_score(vectors)
503
511
  return {"score": score}
504
512
 
513
+ @override
505
514
  async def _acall(
506
515
  self,
507
516
  inputs: dict[str, Any],
@@ -529,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
529
538
  score = self._compute_score(vectors)
530
539
  return {"score": score}
531
540
 
541
+ @override
532
542
  def _evaluate_string_pairs(
533
543
  self,
534
544
  *,
@@ -564,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
564
574
  )
565
575
  return self._prepare_output(result)
566
576
 
577
+ @override
567
578
  async def _aevaluate_string_pairs(
568
579
  self,
569
580
  *,
@@ -1,6 +1,8 @@
1
1
  import string
2
2
  from typing import Any
3
3
 
4
+ from typing_extensions import override
5
+
4
6
  from langchain.evaluation.schema import StringEvaluator
5
7
 
6
8
 
@@ -27,8 +29,18 @@ class ExactMatchStringEvaluator(StringEvaluator):
27
29
  ignore_case: bool = False,
28
30
  ignore_punctuation: bool = False,
29
31
  ignore_numbers: bool = False,
30
- **kwargs: Any,
32
+ **_: Any,
31
33
  ):
34
+ """Initialize the ExactMatchStringEvaluator.
35
+
36
+ Args:
37
+ ignore_case: Whether to ignore case when comparing strings.
38
+ Defaults to False.
39
+ ignore_punctuation: Whether to ignore punctuation when comparing strings.
40
+ Defaults to False.
41
+ ignore_numbers: Whether to ignore numbers when comparing strings.
42
+ Defaults to False.
43
+ """
32
44
  super().__init__()
33
45
  self.ignore_case = ignore_case
34
46
  self.ignore_punctuation = ignore_punctuation
@@ -68,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
68
80
  """
69
81
  return "exact_match"
70
82
 
83
+ @override
71
84
  def _evaluate_strings( # type: ignore[override]
72
85
  self,
73
86
  *,
@@ -58,6 +58,7 @@ def load_dataset(uri: str) -> list[dict]:
58
58
 
59
59
  from langchain.evaluation import load_dataset
60
60
  ds = load_dataset("llm-math")
61
+
61
62
  """
62
63
  try:
63
64
  from datasets import load_dataset
@@ -35,18 +35,22 @@ class JsonValidityEvaluator(StringEvaluator):
35
35
  {'score': 0, 'reasoning': 'Expecting property name enclosed in double quotes'}
36
36
  """
37
37
 
38
- def __init__(self, **kwargs: Any) -> None:
38
+ def __init__(self, **_: Any) -> None:
39
+ """Initialize the JsonValidityEvaluator."""
39
40
  super().__init__()
40
41
 
41
42
  @property
43
+ @override
42
44
  def requires_input(self) -> bool:
43
45
  return False
44
46
 
45
47
  @property
48
+ @override
46
49
  def requires_reference(self) -> bool:
47
50
  return False
48
51
 
49
52
  @property
53
+ @override
50
54
  def evaluation_name(self) -> str:
51
55
  return "json_validity"
52
56
 
@@ -110,19 +114,28 @@ class JsonEqualityEvaluator(StringEvaluator):
110
114
 
111
115
  """
112
116
 
113
- def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
117
+ def __init__(self, operator: Optional[Callable] = None, **_: Any) -> None:
118
+ """Initialize the JsonEqualityEvaluator.
119
+
120
+ Args:
121
+ operator: A custom operator to compare the parsed JSON objects.
122
+ Defaults to equality (`eq`).
123
+ """
114
124
  super().__init__()
115
125
  self.operator = operator or eq
116
126
 
117
127
  @property
128
+ @override
118
129
  def requires_input(self) -> bool:
119
130
  return False
120
131
 
121
132
  @property
133
+ @override
122
134
  def requires_reference(self) -> bool:
123
135
  return True
124
136
 
125
137
  @property
138
+ @override
126
139
  def evaluation_name(self) -> str:
127
140
  return "json_equality"
128
141
 
@@ -153,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
153
166
  dict: A dictionary containing the evaluation score.
154
167
  """
155
168
  parsed = self._parse_json(prediction)
156
- label = self._parse_json(cast(str, reference))
169
+ label = self._parse_json(cast("str", reference))
157
170
  if isinstance(label, list):
158
171
  if not isinstance(parsed, list):
159
172
  return {"score": 0}
@@ -15,13 +15,6 @@ class JsonEditDistanceEvaluator(StringEvaluator):
15
15
  after parsing them and converting them to a canonical format (i.e., whitespace and key order are normalized).
16
16
  It can be customized with alternative distance and canonicalization functions.
17
17
 
18
- Args:
19
- string_distance (Optional[Callable[[str, str], float]]): A callable that computes the distance between two strings.
20
- If not provided, a Damerau-Levenshtein distance from the `rapidfuzz` package will be used.
21
- canonicalize (Optional[Callable[[Any], Any]]): A callable that converts a parsed JSON object into its canonical string form.
22
- If not provided, the default behavior is to serialize the JSON with sorted keys and no extra whitespace.
23
- **kwargs (Any): Additional keyword arguments.
24
-
25
18
  Attributes:
26
19
  _string_distance (Callable[[str, str], float]): The internal distance computation function.
27
20
  _canonicalize (Callable[[Any], Any]): The internal canonicalization function.
@@ -40,8 +33,23 @@ class JsonEditDistanceEvaluator(StringEvaluator):
40
33
  self,
41
34
  string_distance: Optional[Callable[[str, str], float]] = None,
42
35
  canonicalize: Optional[Callable[[Any], Any]] = None,
43
- **kwargs: Any,
36
+ **_: Any,
44
37
  ) -> None:
38
+ """Initialize the JsonEditDistanceEvaluator.
39
+
40
+ Args:
41
+ string_distance: A callable that computes the distance between two strings.
42
+ If not provided, a Damerau-Levenshtein distance from the `rapidfuzz`
43
+ package will be used.
44
+ canonicalize: A callable that converts a parsed JSON object into its
45
+ canonical string form.
46
+ If not provided, the default behavior is to serialize the JSON with
47
+ sorted keys and no extra whitespace.
48
+
49
+ Raises:
50
+ ImportError: If the `rapidfuzz` package is not installed and no
51
+ `string_distance` function is provided.
52
+ """
45
53
  super().__init__()
46
54
  if string_distance is not None:
47
55
  self._string_distance = string_distance
@@ -67,14 +75,17 @@ class JsonEditDistanceEvaluator(StringEvaluator):
67
75
  )
68
76
 
69
77
  @property
78
+ @override
70
79
  def requires_input(self) -> bool:
71
80
  return False
72
81
 
73
82
  @property
83
+ @override
74
84
  def requires_reference(self) -> bool:
75
85
  return True
76
86
 
77
87
  @property
88
+ @override
78
89
  def evaluation_name(self) -> str:
79
90
  return "json_edit_distance"
80
91
 
@@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
33
33
 
34
34
  """ # noqa: E501
35
35
 
36
- def __init__(self, **kwargs: Any) -> None:
36
+ def __init__(self, **_: Any) -> None:
37
37
  """Initializes the JsonSchemaEvaluator.
38
38
 
39
- Args:
40
- kwargs: Additional keyword arguments.
41
-
42
39
  Raises:
43
40
  ImportError: If the jsonschema package is not installed.
44
41
  """
@@ -80,18 +80,22 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
80
80
  )
81
81
 
82
82
  @classmethod
83
+ @override
83
84
  def is_lc_serializable(cls) -> bool:
84
85
  return False
85
86
 
86
87
  @property
88
+ @override
87
89
  def evaluation_name(self) -> str:
88
90
  return "correctness"
89
91
 
90
92
  @property
93
+ @override
91
94
  def requires_reference(self) -> bool:
92
95
  return True
93
96
 
94
97
  @property
98
+ @override
95
99
  def requires_input(self) -> bool:
96
100
  return True
97
101
 
@@ -214,6 +218,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
214
218
  """LLM Chain for evaluating QA w/o GT based on context"""
215
219
 
216
220
  @classmethod
221
+ @override
217
222
  def is_lc_serializable(cls) -> bool:
218
223
  return False
219
224
 
@@ -242,6 +247,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
242
247
  raise ValueError(msg)
243
248
 
244
249
  @property
250
+ @override
245
251
  def evaluation_name(self) -> str:
246
252
  return "Contextual Accuracy"
247
253
 
@@ -344,10 +350,12 @@ class CotQAEvalChain(ContextQAEvalChain):
344
350
  """LLM Chain for evaluating QA using chain of thought reasoning."""
345
351
 
346
352
  @classmethod
353
+ @override
347
354
  def is_lc_serializable(cls) -> bool:
348
355
  return False
349
356
 
350
357
  @property
358
+ @override
351
359
  def evaluation_name(self) -> str:
352
360
  return "COT Contextual Accuracy"
353
361
 
@@ -7,6 +7,7 @@ from typing import Any
7
7
  from langchain_core.language_models import BaseLanguageModel
8
8
  from langchain_core.output_parsers import BaseLLMOutputParser
9
9
  from pydantic import Field
10
+ from typing_extensions import override
10
11
 
11
12
  from langchain.chains.llm import LLMChain
12
13
  from langchain.evaluation.qa.generate_prompt import PROMPT
@@ -25,6 +26,7 @@ class QAGenerateChain(LLMChain):
25
26
  output_key: str = "qa_pairs"
26
27
 
27
28
  @classmethod
29
+ @override
28
30
  def is_lc_serializable(cls) -> bool:
29
31
  return False
30
32
 
@@ -1,6 +1,8 @@
1
1
  import re
2
2
  from typing import Any
3
3
 
4
+ from typing_extensions import override
5
+
4
6
  from langchain.evaluation.schema import StringEvaluator
5
7
 
6
8
 
@@ -27,7 +29,12 @@ class RegexMatchStringEvaluator(StringEvaluator):
27
29
  ) # This will return {'score': 1.0} as the prediction matches the second pattern in the union
28
30
  """ # noqa: E501
29
31
 
30
- def __init__(self, *, flags: int = 0, **kwargs: Any): # Default is no flags
32
+ def __init__(self, *, flags: int = 0, **_: Any): # Default is no flags
33
+ """Initialize the RegexMatchStringEvaluator.
34
+
35
+ Args:
36
+ flags: Flags to use for the regex match. Defaults to 0 (no flags).
37
+ """
31
38
  super().__init__()
32
39
  self.flags = flags
33
40
 
@@ -65,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
65
72
  """
66
73
  return "regex_match"
67
74
 
75
+ @override
68
76
  def _evaluate_strings( # type: ignore[override]
69
77
  self,
70
78
  *,