langchain-core 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-core might be problematic; consult the package registry's advisory page for more details.

Files changed (131):
  1. langchain_core/_api/beta_decorator.py +17 -40
  2. langchain_core/_api/deprecation.py +20 -7
  3. langchain_core/_api/path.py +19 -2
  4. langchain_core/_import_utils.py +7 -0
  5. langchain_core/agents.py +10 -6
  6. langchain_core/callbacks/base.py +28 -15
  7. langchain_core/callbacks/manager.py +81 -69
  8. langchain_core/callbacks/usage.py +4 -2
  9. langchain_core/chat_history.py +29 -21
  10. langchain_core/document_loaders/base.py +34 -9
  11. langchain_core/document_loaders/langsmith.py +3 -0
  12. langchain_core/documents/base.py +35 -10
  13. langchain_core/documents/transformers.py +4 -2
  14. langchain_core/embeddings/fake.py +8 -5
  15. langchain_core/env.py +2 -3
  16. langchain_core/example_selectors/base.py +12 -0
  17. langchain_core/exceptions.py +7 -0
  18. langchain_core/globals.py +17 -28
  19. langchain_core/indexing/api.py +57 -45
  20. langchain_core/indexing/base.py +5 -8
  21. langchain_core/indexing/in_memory.py +23 -3
  22. langchain_core/language_models/__init__.py +6 -2
  23. langchain_core/language_models/_utils.py +28 -4
  24. langchain_core/language_models/base.py +33 -21
  25. langchain_core/language_models/chat_models.py +103 -29
  26. langchain_core/language_models/fake_chat_models.py +5 -7
  27. langchain_core/language_models/llms.py +54 -20
  28. langchain_core/load/dump.py +2 -3
  29. langchain_core/load/load.py +15 -1
  30. langchain_core/load/serializable.py +38 -43
  31. langchain_core/memory.py +7 -3
  32. langchain_core/messages/__init__.py +7 -17
  33. langchain_core/messages/ai.py +41 -34
  34. langchain_core/messages/base.py +16 -7
  35. langchain_core/messages/block_translators/__init__.py +10 -8
  36. langchain_core/messages/block_translators/anthropic.py +3 -1
  37. langchain_core/messages/block_translators/bedrock.py +3 -1
  38. langchain_core/messages/block_translators/bedrock_converse.py +3 -1
  39. langchain_core/messages/block_translators/google_genai.py +3 -1
  40. langchain_core/messages/block_translators/google_vertexai.py +3 -1
  41. langchain_core/messages/block_translators/groq.py +3 -1
  42. langchain_core/messages/block_translators/langchain_v0.py +3 -136
  43. langchain_core/messages/block_translators/ollama.py +3 -1
  44. langchain_core/messages/block_translators/openai.py +252 -10
  45. langchain_core/messages/content.py +26 -124
  46. langchain_core/messages/human.py +2 -13
  47. langchain_core/messages/system.py +2 -6
  48. langchain_core/messages/tool.py +34 -14
  49. langchain_core/messages/utils.py +189 -74
  50. langchain_core/output_parsers/base.py +5 -2
  51. langchain_core/output_parsers/json.py +4 -4
  52. langchain_core/output_parsers/list.py +7 -22
  53. langchain_core/output_parsers/openai_functions.py +3 -0
  54. langchain_core/output_parsers/openai_tools.py +6 -1
  55. langchain_core/output_parsers/pydantic.py +4 -0
  56. langchain_core/output_parsers/string.py +5 -1
  57. langchain_core/output_parsers/xml.py +19 -19
  58. langchain_core/outputs/chat_generation.py +18 -7
  59. langchain_core/outputs/generation.py +14 -3
  60. langchain_core/outputs/llm_result.py +8 -1
  61. langchain_core/prompt_values.py +10 -4
  62. langchain_core/prompts/base.py +6 -11
  63. langchain_core/prompts/chat.py +88 -60
  64. langchain_core/prompts/dict.py +16 -8
  65. langchain_core/prompts/few_shot.py +9 -11
  66. langchain_core/prompts/few_shot_with_templates.py +5 -1
  67. langchain_core/prompts/image.py +12 -5
  68. langchain_core/prompts/loading.py +2 -2
  69. langchain_core/prompts/message.py +5 -6
  70. langchain_core/prompts/pipeline.py +13 -8
  71. langchain_core/prompts/prompt.py +22 -8
  72. langchain_core/prompts/string.py +18 -10
  73. langchain_core/prompts/structured.py +7 -2
  74. langchain_core/rate_limiters.py +2 -2
  75. langchain_core/retrievers.py +7 -6
  76. langchain_core/runnables/base.py +387 -246
  77. langchain_core/runnables/branch.py +11 -28
  78. langchain_core/runnables/config.py +20 -17
  79. langchain_core/runnables/configurable.py +34 -19
  80. langchain_core/runnables/fallbacks.py +20 -13
  81. langchain_core/runnables/graph.py +48 -38
  82. langchain_core/runnables/graph_ascii.py +40 -17
  83. langchain_core/runnables/graph_mermaid.py +54 -25
  84. langchain_core/runnables/graph_png.py +27 -31
  85. langchain_core/runnables/history.py +55 -58
  86. langchain_core/runnables/passthrough.py +44 -21
  87. langchain_core/runnables/retry.py +44 -23
  88. langchain_core/runnables/router.py +9 -8
  89. langchain_core/runnables/schema.py +9 -0
  90. langchain_core/runnables/utils.py +53 -90
  91. langchain_core/stores.py +19 -31
  92. langchain_core/sys_info.py +9 -8
  93. langchain_core/tools/base.py +36 -27
  94. langchain_core/tools/convert.py +25 -14
  95. langchain_core/tools/simple.py +36 -8
  96. langchain_core/tools/structured.py +25 -12
  97. langchain_core/tracers/base.py +2 -2
  98. langchain_core/tracers/context.py +5 -1
  99. langchain_core/tracers/core.py +110 -46
  100. langchain_core/tracers/evaluation.py +22 -26
  101. langchain_core/tracers/event_stream.py +97 -42
  102. langchain_core/tracers/langchain.py +12 -3
  103. langchain_core/tracers/langchain_v1.py +10 -2
  104. langchain_core/tracers/log_stream.py +56 -17
  105. langchain_core/tracers/root_listeners.py +4 -20
  106. langchain_core/tracers/run_collector.py +6 -16
  107. langchain_core/tracers/schemas.py +5 -1
  108. langchain_core/utils/aiter.py +14 -6
  109. langchain_core/utils/env.py +3 -0
  110. langchain_core/utils/function_calling.py +46 -20
  111. langchain_core/utils/interactive_env.py +6 -2
  112. langchain_core/utils/iter.py +12 -5
  113. langchain_core/utils/json.py +12 -3
  114. langchain_core/utils/json_schema.py +156 -40
  115. langchain_core/utils/loading.py +5 -1
  116. langchain_core/utils/mustache.py +25 -16
  117. langchain_core/utils/pydantic.py +38 -9
  118. langchain_core/utils/utils.py +25 -9
  119. langchain_core/vectorstores/base.py +7 -20
  120. langchain_core/vectorstores/in_memory.py +20 -14
  121. langchain_core/vectorstores/utils.py +18 -12
  122. langchain_core/version.py +1 -1
  123. langchain_core-1.0.0a3.dist-info/METADATA +77 -0
  124. langchain_core-1.0.0a3.dist-info/RECORD +181 -0
  125. langchain_core/beta/__init__.py +0 -1
  126. langchain_core/beta/runnables/__init__.py +0 -1
  127. langchain_core/beta/runnables/context.py +0 -448
  128. langchain_core-1.0.0a1.dist-info/METADATA +0 -106
  129. langchain_core-1.0.0a1.dist-info/RECORD +0 -184
  130. {langchain_core-1.0.0a1.dist-info → langchain_core-1.0.0a3.dist-info}/WHEEL +0 -0
  131. {langchain_core-1.0.0a1.dist-info → langchain_core-1.0.0a3.dist-info}/entry_points.txt +0 -0
@@ -117,7 +117,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
117
117
 
118
118
  @classmethod
119
119
  def is_lc_serializable(cls) -> bool:
120
- """Return whether or not the class is serializable."""
120
+ """Return False as this class is not serializable."""
121
121
  return False
122
122
 
123
123
  validate_template: bool = False
@@ -153,7 +153,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
153
153
  self.template_format,
154
154
  self.input_variables + list(self.partial_variables),
155
155
  )
156
- elif self.template_format or None:
156
+ elif self.template_format:
157
157
  self.input_variables = [
158
158
  var
159
159
  for var in get_template_variables(
@@ -272,7 +272,7 @@ class FewShotChatMessagePromptTemplate(
272
272
 
273
273
  from langchain_core.prompts import (
274
274
  FewShotChatMessagePromptTemplate,
275
- ChatPromptTemplate
275
+ ChatPromptTemplate,
276
276
  )
277
277
 
278
278
  examples = [
@@ -281,7 +281,7 @@ class FewShotChatMessagePromptTemplate(
281
281
  ]
282
282
 
283
283
  example_prompt = ChatPromptTemplate.from_messages(
284
- [('human', 'What is {input}?'), ('ai', '{output}')]
284
+ [("human", "What is {input}?"), ("ai", "{output}")]
285
285
  )
286
286
 
287
287
  few_shot_prompt = FewShotChatMessagePromptTemplate(
@@ -292,9 +292,9 @@ class FewShotChatMessagePromptTemplate(
292
292
 
293
293
  final_prompt = ChatPromptTemplate.from_messages(
294
294
  [
295
- ('system', 'You are a helpful AI Assistant'),
295
+ ("system", "You are a helpful AI Assistant"),
296
296
  few_shot_prompt,
297
- ('human', '{input}'),
297
+ ("human", "{input}"),
298
298
  ]
299
299
  )
300
300
  final_prompt.format(input="What is 4+4?")
@@ -314,10 +314,7 @@ class FewShotChatMessagePromptTemplate(
314
314
  # ...
315
315
  ]
316
316
 
317
- to_vectorize = [
318
- " ".join(example.values())
319
- for example in examples
320
- ]
317
+ to_vectorize = [" ".join(example.values()) for example in examples]
321
318
  embeddings = OpenAIEmbeddings()
322
319
  vectorstore = Chroma.from_texts(
323
320
  to_vectorize, embeddings, metadatas=examples
@@ -355,6 +352,7 @@ class FewShotChatMessagePromptTemplate(
355
352
 
356
353
  # Use within an LLM
357
354
  from langchain_core.chat_models import ChatAnthropic
355
+
358
356
  chain = final_prompt | ChatAnthropic(model="claude-3-haiku-20240307")
359
357
  chain.invoke({"input": "What's 3+3?"})
360
358
 
@@ -369,7 +367,7 @@ class FewShotChatMessagePromptTemplate(
369
367
 
370
368
  @classmethod
371
369
  def is_lc_serializable(cls) -> bool:
372
- """Return whether or not the class is serializable."""
370
+ """Return False as this class is not serializable."""
373
371
  return False
374
372
 
375
373
  model_config = ConfigDict(
@@ -46,7 +46,11 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
46
46
 
47
47
  @classmethod
48
48
  def get_lc_namespace(cls) -> list[str]:
49
- """Get the namespace of the langchain object."""
49
+ """Get the namespace of the langchain object.
50
+
51
+ Returns:
52
+ ``["langchain", "prompts", "few_shot_with_templates"]``
53
+ """
50
54
  return ["langchain", "prompts", "few_shot_with_templates"]
51
55
 
52
56
  @model_validator(mode="before")
@@ -23,7 +23,12 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
23
23
  Options are: 'f-string', 'mustache', 'jinja2'."""
24
24
 
25
25
  def __init__(self, **kwargs: Any) -> None:
26
- """Create an image prompt template."""
26
+ """Create an image prompt template.
27
+
28
+ Raises:
29
+ ValueError: If the input variables contain ``'url'``, ``'path'``, or
30
+ ``'detail'``.
31
+ """
27
32
  if "input_variables" not in kwargs:
28
33
  kwargs["input_variables"] = []
29
34
 
@@ -44,7 +49,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
44
49
 
45
50
  @classmethod
46
51
  def get_lc_namespace(cls) -> list[str]:
47
- """Get the namespace of the langchain object."""
52
+ """Get the namespace of the langchain object.
53
+
54
+ Returns:
55
+ ``["langchain", "prompts", "image"]``
56
+ """
48
57
  return ["langchain", "prompts", "image"]
49
58
 
50
59
  def format_prompt(self, **kwargs: Any) -> PromptValue:
@@ -84,6 +93,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
84
93
  Raises:
85
94
  ValueError: If the url is not provided.
86
95
  ValueError: If the url is not a string.
96
+ ValueError: If ``'path'`` is provided in the template or kwargs.
87
97
 
88
98
  Example:
89
99
 
@@ -128,9 +138,6 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
128
138
 
129
139
  Returns:
130
140
  A formatted string.
131
-
132
- Raises:
133
- ValueError: If the path or url is not a string.
134
141
  """
135
142
  return await run_in_executor(None, self.format, **kwargs)
136
143
 
@@ -53,7 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
53
53
  template_path = Path(config.pop(f"{var_name}_path"))
54
54
  # Load the template.
55
55
  if template_path.suffix == ".txt":
56
- template = template_path.read_text()
56
+ template = template_path.read_text(encoding="utf-8")
57
57
  else:
58
58
  raise ValueError
59
59
  # Set the template variable to the extracted variable.
@@ -67,7 +67,7 @@ def _load_examples(config: dict) -> dict:
67
67
  pass
68
68
  elif isinstance(config["examples"], str):
69
69
  path = Path(config["examples"])
70
- with path.open() as f:
70
+ with path.open(encoding="utf-8") as f:
71
71
  if path.suffix == ".json":
72
72
  examples = json.load(f)
73
73
  elif path.suffix in {".yaml", ".yml"}:
@@ -18,17 +18,15 @@ class BaseMessagePromptTemplate(Serializable, ABC):
18
18
 
19
19
  @classmethod
20
20
  def is_lc_serializable(cls) -> bool:
21
- """Return whether or not the class is serializable.
22
-
23
- Returns: True.
24
- """
21
+ """Return True as this class is serializable."""
25
22
  return True
26
23
 
27
24
  @classmethod
28
25
  def get_lc_namespace(cls) -> list[str]:
29
26
  """Get the namespace of the langchain object.
30
27
 
31
- Default namespace is ["langchain", "prompts", "chat"].
28
+ Returns:
29
+ ``["langchain", "prompts", "chat"]``
32
30
  """
33
31
  return ["langchain", "prompts", "chat"]
34
32
 
@@ -90,7 +88,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
90
88
  Returns:
91
89
  Combined prompt template.
92
90
  """
93
- from langchain_core.prompts.chat import ChatPromptTemplate
91
+ # Import locally to avoid circular import.
92
+ from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
94
93
 
95
94
  prompt = ChatPromptTemplate(messages=[self])
96
95
  return prompt + other
@@ -39,23 +39,28 @@ class PipelinePromptTemplate(BasePromptTemplate):
39
39
  This can be useful when you want to reuse parts of prompts.
40
40
 
41
41
  A PipelinePrompt consists of two main parts:
42
- - final_prompt: This is the final prompt that is returned
43
- - pipeline_prompts: This is a list of tuples, consisting
44
- of a string (`name`) and a Prompt Template.
45
- Each PromptTemplate will be formatted and then passed
46
- to future prompt templates as a variable with
47
- the same name as `name`
42
+
43
+ - final_prompt: This is the final prompt that is returned
44
+ - pipeline_prompts: This is a list of tuples, consisting
45
+ of a string (``name``) and a Prompt Template.
46
+ Each PromptTemplate will be formatted and then passed
47
+ to future prompt templates as a variable with
48
+ the same name as ``name``
48
49
 
49
50
  """
50
51
 
51
52
  final_prompt: BasePromptTemplate
52
53
  """The final prompt that is returned."""
53
54
  pipeline_prompts: list[tuple[str, BasePromptTemplate]]
54
- """A list of tuples, consisting of a string (`name`) and a Prompt Template."""
55
+ """A list of tuples, consisting of a string (``name``) and a Prompt Template."""
55
56
 
56
57
  @classmethod
57
58
  def get_lc_namespace(cls) -> list[str]:
58
- """Get the namespace of the langchain object."""
59
+ """Get the namespace of the langchain object.
60
+
61
+ Returns:
62
+ ``["langchain", "prompts", "pipeline"]``
63
+ """
59
64
  return ["langchain", "prompts", "pipeline"]
60
65
 
61
66
  @model_validator(mode="before")
@@ -69,6 +69,11 @@ class PromptTemplate(StringPromptTemplate):
69
69
  @classmethod
70
70
  @override
71
71
  def get_lc_namespace(cls) -> list[str]:
72
+ """Get the namespace of the langchain object.
73
+
74
+ Returns:
75
+ ``["langchain", "prompts", "prompt"]``
76
+ """
72
77
  return ["langchain", "prompts", "prompt"]
73
78
 
74
79
  template: str
@@ -135,14 +140,20 @@ class PromptTemplate(StringPromptTemplate):
135
140
  return mustache_schema(self.template)
136
141
 
137
142
  def __add__(self, other: Any) -> PromptTemplate:
138
- """Override the + operator to allow for combining prompt templates."""
143
+ """Override the + operator to allow for combining prompt templates.
144
+
145
+ Raises:
146
+ ValueError: If the template formats are not f-string or if there are
147
+ conflicting partial variables.
148
+ NotImplementedError: If the other object is not a ``PromptTemplate`` or str.
149
+
150
+ Returns:
151
+ A new ``PromptTemplate`` that is the combination of the two.
152
+ """
139
153
  # Allow for easy combining
140
154
  if isinstance(other, PromptTemplate):
141
- if self.template_format != "f-string":
142
- msg = "Adding prompt templates only supported for f-strings."
143
- raise ValueError(msg)
144
- if other.template_format != "f-string":
145
- msg = "Adding prompt templates only supported for f-strings."
155
+ if self.template_format != other.template_format:
156
+ msg = "Cannot add templates of different formats"
146
157
  raise ValueError(msg)
147
158
  input_variables = list(
148
159
  set(self.input_variables) | set(other.input_variables)
@@ -160,11 +171,14 @@ class PromptTemplate(StringPromptTemplate):
160
171
  template=template,
161
172
  input_variables=input_variables,
162
173
  partial_variables=partial_variables,
163
- template_format="f-string",
174
+ template_format=self.template_format,
164
175
  validate_template=validate_template,
165
176
  )
166
177
  if isinstance(other, str):
167
- prompt = PromptTemplate.from_template(other)
178
+ prompt = PromptTemplate.from_template(
179
+ other,
180
+ template_format=self.template_format,
181
+ )
168
182
  return self + prompt
169
183
  msg = f"Unsupported operand type for +: {type(other)}"
170
184
  raise NotImplementedError(msg)
@@ -15,6 +15,14 @@ from langchain_core.utils import get_colored_text, mustache
15
15
  from langchain_core.utils.formatting import formatter
16
16
  from langchain_core.utils.interactive_env import is_interactive_env
17
17
 
18
+ try:
19
+ from jinja2 import Environment, meta
20
+ from jinja2.sandbox import SandboxedEnvironment
21
+
22
+ _HAS_JINJA2 = True
23
+ except ImportError:
24
+ _HAS_JINJA2 = False
25
+
18
26
  PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
19
27
 
20
28
 
@@ -40,9 +48,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
40
48
  Raises:
41
49
  ImportError: If jinja2 is not installed.
42
50
  """
43
- try:
44
- from jinja2.sandbox import SandboxedEnvironment
45
- except ImportError as e:
51
+ if not _HAS_JINJA2:
46
52
  msg = (
47
53
  "jinja2 not installed, which is needed to use the jinja2_formatter. "
48
54
  "Please install it with `pip install jinja2`."
@@ -50,7 +56,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
50
56
  "Do not expand jinja2 templates using unverified or user-controlled "
51
57
  "inputs as that can result in arbitrary Python code execution."
52
58
  )
53
- raise ImportError(msg) from e
59
+ raise ImportError(msg)
54
60
 
55
61
  # This uses a sandboxed environment to prevent arbitrary code execution.
56
62
  # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -88,14 +94,12 @@ def validate_jinja2(template: str, input_variables: list[str]) -> None:
88
94
 
89
95
 
90
96
  def _get_jinja2_variables_from_template(template: str) -> set[str]:
91
- try:
92
- from jinja2 import Environment, meta
93
- except ImportError as e:
97
+ if not _HAS_JINJA2:
94
98
  msg = (
95
99
  "jinja2 not installed, which is needed to use the jinja2_formatter. "
96
100
  "Please install it with `pip install jinja2`."
97
101
  )
98
- raise ImportError(msg) from e
102
+ raise ImportError(msg)
99
103
  env = Environment() # noqa: S701
100
104
  ast = env.parse(template)
101
105
  return meta.find_undeclared_variables(ast)
@@ -166,7 +170,7 @@ def mustache_schema(
166
170
  prefix = section_stack.pop()
167
171
  elif type_ in {"section", "inverted section"}:
168
172
  section_stack.append(prefix)
169
- prefix = prefix + tuple(key.split("."))
173
+ prefix += tuple(key.split("."))
170
174
  fields[prefix] = False
171
175
  elif type_ in {"variable", "no escape"}:
172
176
  fields[prefix + tuple(key.split("."))] = True
@@ -268,7 +272,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
268
272
 
269
273
  @classmethod
270
274
  def get_lc_namespace(cls) -> list[str]:
271
- """Get the namespace of the langchain object."""
275
+ """Get the namespace of the langchain object.
276
+
277
+ Returns:
278
+ ``["langchain", "prompts", "base"]``
279
+ """
272
280
  return ["langchain", "prompts", "base"]
273
281
 
274
282
  def format_prompt(self, **kwargs: Any) -> PromptValue:
@@ -68,8 +68,11 @@ class StructuredPrompt(ChatPromptTemplate):
68
68
  def get_lc_namespace(cls) -> list[str]:
69
69
  """Get the namespace of the langchain object.
70
70
 
71
- For example, if the class is `langchain.llms.openai.OpenAI`, then the
72
- namespace is ["langchain", "llms", "openai"]
71
+ For example, if the class is ``langchain.llms.openai.OpenAI``, then the
72
+ namespace is ``["langchain", "llms", "openai"]``
73
+
74
+ Returns:
75
+ The namespace of the langchain object.
73
76
  """
74
77
  return cls.__module__.split(".")
75
78
 
@@ -89,10 +92,12 @@ class StructuredPrompt(ChatPromptTemplate):
89
92
 
90
93
  from langchain_core.prompts import StructuredPrompt
91
94
 
95
+
92
96
  class OutputSchema(BaseModel):
93
97
  name: str
94
98
  value: int
95
99
 
100
+
96
101
  template = StructuredPrompt(
97
102
  [
98
103
  ("human", "Hello, how are you?"),
@@ -110,9 +110,9 @@ class InMemoryRateLimiter(BaseRateLimiter):
110
110
  )
111
111
 
112
112
  from langchain_anthropic import ChatAnthropic
113
+
113
114
  model = ChatAnthropic(
114
- model_name="claude-3-opus-20240229",
115
- rate_limiter=rate_limiter
115
+ model_name="claude-3-opus-20240229", rate_limiter=rate_limiter
116
116
  )
117
117
 
118
118
  for _ in range(5):
@@ -31,6 +31,7 @@ from typing_extensions import Self, TypedDict, override
31
31
 
32
32
  from langchain_core._api import deprecated
33
33
  from langchain_core.callbacks import Callbacks
34
+ from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
34
35
  from langchain_core.documents import Document
35
36
  from langchain_core.runnables import (
36
37
  Runnable,
@@ -109,6 +110,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
109
110
 
110
111
  from sklearn.metrics.pairwise import cosine_similarity
111
112
 
113
+
112
114
  class TFIDFRetriever(BaseRetriever, BaseModel):
113
115
  vectorizer: Any
114
116
  docs: list[Document]
@@ -122,10 +124,12 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
122
124
  # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
123
125
  query_vec = self.vectorizer.transform([query])
124
126
  # Op -- (n_docs,1) -- Cosine Sim with each doc
125
- results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
127
+ results = cosine_similarity(self.tfidf_array, query_vec).reshape(
128
+ (-1,)
129
+ )
126
130
  return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
127
131
 
128
- """ # noqa: E501
132
+ """
129
133
 
130
134
  model_config = ConfigDict(
131
135
  arbitrary_types_allowed=True,
@@ -233,8 +237,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
233
237
  retriever.invoke("query")
234
238
 
235
239
  """
236
- from langchain_core.callbacks.manager import CallbackManager
237
-
238
240
  config = ensure_config(config)
239
241
  inheritable_metadata = {
240
242
  **(config.get("metadata") or {}),
@@ -298,8 +300,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
298
300
  await retriever.ainvoke("query")
299
301
 
300
302
  """
301
- from langchain_core.callbacks.manager import AsyncCallbackManager
302
-
303
303
  config = ensure_config(config)
304
304
  inheritable_metadata = {
305
305
  **(config.get("metadata") or {}),
@@ -359,6 +359,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
359
359
  Args:
360
360
  query: String to find relevant documents for
361
361
  run_manager: The callback handler to use
362
+
362
363
  Returns:
363
364
  List of relevant documents
364
365
  """