langchain-core 1.0.5__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (46)
  1. langchain_core/callbacks/manager.py +14 -14
  2. langchain_core/callbacks/usage.py +1 -1
  3. langchain_core/indexing/api.py +2 -0
  4. langchain_core/language_models/__init__.py +15 -5
  5. langchain_core/language_models/_utils.py +1 -0
  6. langchain_core/language_models/chat_models.py +74 -94
  7. langchain_core/language_models/llms.py +5 -3
  8. langchain_core/language_models/model_profile.py +84 -0
  9. langchain_core/load/load.py +14 -1
  10. langchain_core/messages/ai.py +12 -4
  11. langchain_core/messages/base.py +6 -6
  12. langchain_core/messages/block_translators/anthropic.py +27 -8
  13. langchain_core/messages/block_translators/bedrock_converse.py +18 -8
  14. langchain_core/messages/block_translators/google_genai.py +25 -10
  15. langchain_core/messages/content.py +1 -1
  16. langchain_core/messages/tool.py +28 -27
  17. langchain_core/messages/utils.py +45 -18
  18. langchain_core/output_parsers/openai_tools.py +9 -7
  19. langchain_core/output_parsers/pydantic.py +1 -1
  20. langchain_core/output_parsers/string.py +27 -1
  21. langchain_core/prompts/chat.py +22 -17
  22. langchain_core/prompts/string.py +29 -9
  23. langchain_core/prompts/structured.py +7 -1
  24. langchain_core/runnables/base.py +174 -160
  25. langchain_core/runnables/branch.py +1 -1
  26. langchain_core/runnables/config.py +25 -20
  27. langchain_core/runnables/fallbacks.py +1 -2
  28. langchain_core/runnables/graph.py +3 -2
  29. langchain_core/runnables/graph_mermaid.py +5 -1
  30. langchain_core/runnables/passthrough.py +2 -2
  31. langchain_core/tools/base.py +46 -2
  32. langchain_core/tools/convert.py +16 -0
  33. langchain_core/tools/retriever.py +29 -58
  34. langchain_core/tools/structured.py +14 -0
  35. langchain_core/tracers/event_stream.py +9 -4
  36. langchain_core/utils/aiter.py +3 -1
  37. langchain_core/utils/function_calling.py +7 -2
  38. langchain_core/utils/json_schema.py +29 -21
  39. langchain_core/utils/mustache.py +24 -9
  40. langchain_core/utils/pydantic.py +7 -7
  41. langchain_core/utils/uuid.py +54 -0
  42. langchain_core/vectorstores/base.py +26 -18
  43. langchain_core/version.py +1 -1
  44. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/METADATA +2 -1
  45. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/RECORD +46 -44
  46. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/WHEEL +1 -1
@@ -51,22 +51,22 @@ class InputTokenDetails(TypedDict, total=False):
  May also hold extra provider-specific keys.

  !!! version-added "Added in `langchain-core` 0.3.9"
-
  """

  audio: int
  """Audio input tokens."""
+
  cache_creation: int
  """Input tokens that were cached and there was a cache miss.

  Since there was a cache miss, the cache was created from these tokens.
  """
+
  cache_read: int
  """Input tokens that were cached and there was a cache hit.

  Since there was a cache hit, the tokens were read from the cache. More precisely,
  the model state given these tokens was read from the cache.
-
  """

@@ -91,12 +91,12 @@ class OutputTokenDetails(TypedDict, total=False):

  audio: int
  """Audio output tokens."""
+
  reasoning: int
  """Reasoning output tokens.

  Tokens generated by the model in a chain of thought process (i.e. by OpenAI's o1
  models) that are not returned as part of model output.
-
  """

@@ -124,9 +124,11 @@ class UsageMetadata(TypedDict):
  ```

  !!! warning "Behavior changed in `langchain-core` 0.3.9"
+
  Added `input_token_details` and `output_token_details`.

  !!! note "LangSmith SDK"
+
  The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
  LangSmith's `UsageMetadata` has additional fields to capture cost information
  used by the LangSmith platform.
@@ -134,15 +136,19 @@ class UsageMetadata(TypedDict):

  input_tokens: int
  """Count of input (or prompt) tokens. Sum of all input token types."""
+
  output_tokens: int
  """Count of output (or completion) tokens. Sum of all output token types."""
+
  total_tokens: int
  """Total token count. Sum of `input_tokens` + `output_tokens`."""
+
  input_token_details: NotRequired[InputTokenDetails]
  """Breakdown of input token counts.

  Does *not* need to sum to full input token count. Does *not* need to have all keys.
  """
+
  output_token_details: NotRequired[OutputTokenDetails]
  """Breakdown of output token counts.
@@ -162,8 +168,10 @@ class AIMessage(BaseMessage):

  tool_calls: list[ToolCall] = []
  """If present, tool calls associated with the message."""
+
  invalid_tool_calls: list[InvalidToolCall] = []
  """If present, tool calls with parsing errors associated with the message."""
+
  usage_metadata: UsageMetadata | None = None
  """If present, usage metadata for a message, such as token counts.

@@ -555,7 +563,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):

  @model_validator(mode="after")
  def init_server_tool_calls(self) -> Self:
- """Parse `server_tool_call_chunks`."""
+ """Parse `server_tool_call_chunks` from [`ServerToolCallChunk`][langchain.messages.ServerToolCallChunk] objects."""  # noqa: E501
  if (
  self.chunk_position == "last"
  and self.response_metadata.get("output_version") == "v1"
@@ -391,12 +391,12 @@ class BaseMessageChunk(BaseMessage):
  Raises:
  TypeError: If the other object is not a message chunk.

- For example,
-
- `AIMessageChunk(content="Hello") + AIMessageChunk(content=" World")`
-
- will give `AIMessageChunk(content="Hello World")`
-
+ Example:
+ ```txt
+ AIMessageChunk(content="Hello", ...)
+ + AIMessageChunk(content=" World", ...)
+ = AIMessageChunk(content="Hello World", ...)
+ ```
  """
  if isinstance(other, BaseMessageChunk):
  # If both are (subclasses of) BaseMessageChunk,
@@ -245,11 +245,20 @@ def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock
  and message.chunk_position != "last"
  ):
  # Isolated chunk
- tool_call_chunk: types.ToolCallChunk = (
- message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+ chunk = message.tool_call_chunks[0]
+
+ tool_call_chunk = types.ToolCallChunk(
+ name=chunk.get("name"),
+ id=chunk.get("id"),
+ args=chunk.get("args"),
+ type="tool_call_chunk",
  )
- if "type" not in tool_call_chunk:
- tool_call_chunk["type"] = "tool_call_chunk"
+ if "caller" in block:
+ tool_call_chunk["extras"] = {"caller": block["caller"]}
+
+ index = chunk.get("index")
+ if index is not None:
+ tool_call_chunk["index"] = index
  yield tool_call_chunk
  else:
  tool_call_block: types.ToolCall | None = None
@@ -282,17 +291,27 @@ def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock
  }
  if "index" in block:
  tool_call_block["index"] = block["index"]
+ if "caller" in block:
+ if "extras" not in tool_call_block:
+ tool_call_block["extras"] = {}
+ tool_call_block["extras"]["caller"] = block["caller"]
+
  yield tool_call_block

  elif block_type == "input_json_delta" and isinstance(
  message, AIMessageChunk
  ):
  if len(message.tool_call_chunks) == 1:
- tool_call_chunk = (
- message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+ chunk = message.tool_call_chunks[0]
+ tool_call_chunk = types.ToolCallChunk(
+ name=chunk.get("name"),
+ id=chunk.get("id"),
+ args=chunk.get("args"),
+ type="tool_call_chunk",
  )
- if "type" not in tool_call_chunk:
- tool_call_chunk["type"] = "tool_call_chunk"
+ index = chunk.get("index")
+ if index is not None:
+ tool_call_chunk["index"] = index
  yield tool_call_chunk

  else:
@@ -209,11 +209,16 @@ def _convert_to_v1_from_converse(message: AIMessage) -> list[types.ContentBlock]
  and message.chunk_position != "last"
  ):
  # Isolated chunk
- tool_call_chunk: types.ToolCallChunk = (
- message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+ chunk = message.tool_call_chunks[0]
+ tool_call_chunk = types.ToolCallChunk(
+ name=chunk.get("name"),
+ id=chunk.get("id"),
+ args=chunk.get("args"),
+ type="tool_call_chunk",
  )
- if "type" not in tool_call_chunk:
- tool_call_chunk["type"] = "tool_call_chunk"
+ index = chunk.get("index")
+ if index is not None:
+ tool_call_chunk["index"] = index
  yield tool_call_chunk
  else:
  tool_call_block: types.ToolCall | None = None
@@ -253,11 +258,16 @@ def _convert_to_v1_from_converse(message: AIMessage) -> list[types.ContentBlock]
  and isinstance(message, AIMessageChunk)
  and len(message.tool_call_chunks) == 1
  ):
- tool_call_chunk = (
- message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+ chunk = message.tool_call_chunks[0]
+ tool_call_chunk = types.ToolCallChunk(
+ name=chunk.get("name"),
+ id=chunk.get("id"),
+ args=chunk.get("args"),
+ type="tool_call_chunk",
  )
- if "type" not in tool_call_chunk:
- tool_call_chunk["type"] = "tool_call_chunk"
+ index = chunk.get("index")
+ if index is not None:
+ tool_call_chunk["index"] = index
  yield tool_call_chunk

  else:
@@ -76,21 +76,36 @@ def translate_grounding_metadata_to_citations(
  for chunk_index in chunk_indices:
  if chunk_index < len(grounding_chunks):
  chunk = grounding_chunks[chunk_index]
- web_info = chunk.get("web", {})
+
+ # Handle web and maps grounding
+ web_info = chunk.get("web") or {}
+ maps_info = chunk.get("maps") or {}
+
+ # Extract citation info depending on source
+ url = maps_info.get("uri") or web_info.get("uri")
+ title = maps_info.get("title") or web_info.get("title")
+
+ # Note: confidence_scores is a legacy field from Gemini 2.0 and earlier
+ # that indicated confidence (0.0-1.0) for each grounding chunk.
+ #
+ # In Gemini 2.5+, this field is always None/empty and should be ignored.
+ extras_metadata = {
+ "web_search_queries": web_search_queries,
+ "grounding_chunk_index": chunk_index,
+ "confidence_scores": support.get("confidence_scores") or [],
+ }
+
+ # Add maps-specific metadata if present
+ if maps_info.get("placeId"):
+ extras_metadata["place_id"] = maps_info["placeId"]

  citation = create_citation(
- url=web_info.get("uri"),
- title=web_info.get("title"),
+ url=url,
+ title=title,
  start_index=start_index,
  end_index=end_index,
  cited_text=cited_text,
- extras={
- "google_ai_metadata": {
- "web_search_queries": web_search_queries,
- "grounding_chunk_index": chunk_index,
- "confidence_scores": support.get("confidence_scores", []),
- }
- },
+ google_ai_metadata=extras_metadata,
  )
  citations.append(citation)

@@ -654,7 +654,7 @@ class PlainTextContentBlock(TypedDict):

  !!! note
  Title and context are optional fields that may be passed to the model. See
- Anthropic [example](https://docs.claude.com/en/docs/build-with-claude/citations#citable-vs-non-citable-content).
+ Anthropic [example](https://platform.claude.com/docs/en/build-with-claude/citations#citable-vs-non-citable-content).

  !!! note "Factory function"
  `create_plaintext_block` may also be used as a factory to create a
@@ -29,38 +29,39 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
  `ToolMessage` objects contain the result of a tool invocation. Typically, the result
  is encoded inside the `content` field.

- Example: A `ToolMessage` representing a result of `42` from a tool call with id
+ `tool_call_id` is used to associate the tool call request with the tool call
+ response. Useful in situations where a chat model is able to request multiple tool
+ calls in parallel.

- ```python
- from langchain_core.messages import ToolMessage
-
- ToolMessage(content="42", tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL")
- ```
+ Example:
+ A `ToolMessage` representing a result of `42` from a tool call with id

- Example: A `ToolMessage` where only part of the tool output is sent to the model
- and the full output is passed in to artifact.
+ ```python
+ from langchain_core.messages import ToolMessage

- ```python
- from langchain_core.messages import ToolMessage
-
- tool_output = {
- "stdout": "From the graph we can see that the correlation between "
- "x and y is ...",
- "stderr": None,
- "artifacts": {"type": "image", "base64_data": "/9j/4gIcSU..."},
- }
-
- ToolMessage(
- content=tool_output["stdout"],
- artifact=tool_output,
- tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
- )
- ```
+ ToolMessage(content="42", tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL")
+ ```

- The `tool_call_id` field is used to associate the tool call request with the
- tool call response. Useful in situations where a chat model is able
- to request multiple tool calls in parallel.
+ Example:
+ A `ToolMessage` where only part of the tool output is sent to the model
+ and the full output is passed in to artifact.

+ ```python
+ from langchain_core.messages import ToolMessage
+
+ tool_output = {
+ "stdout": "From the graph we can see that the correlation between "
+ "x and y is ...",
+ "stderr": None,
+ "artifacts": {"type": "image", "base64_data": "/9j/4gIcSU..."},
+ }
+
+ ToolMessage(
+ content=tool_output["stdout"],
+ artifact=tool_output,
+ tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+ )
+ ```
  """

  tool_call_id: str
@@ -15,12 +15,16 @@ import json
  import logging
  import math
  from collections.abc import Callable, Iterable, Sequence
- from functools import partial
+ from functools import partial, wraps
  from typing import (
  TYPE_CHECKING,
  Annotated,
  Any,
+ Concatenate,
  Literal,
+ ParamSpec,
+ Protocol,
+ TypeVar,
  cast,
  overload,
  )
@@ -384,33 +388,54 @@ def convert_to_messages(
  return [_convert_to_message(m) for m in messages]


- def _runnable_support(func: Callable) -> Callable:
- @overload
- def wrapped(
- messages: None = None, **kwargs: Any
- ) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
+ _P = ParamSpec("_P")
+ _R_co = TypeVar("_R_co", covariant=True)

+
+ class _RunnableSupportCallable(Protocol[_P, _R_co]):
  @overload
- def wrapped(
- messages: Sequence[MessageLikeRepresentation], **kwargs: Any
- ) -> list[BaseMessage]: ...
+ def __call__(
+ self,
+ messages: None = None,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...

+ @overload
+ def __call__(
+ self,
+ messages: Sequence[MessageLikeRepresentation] | PromptValue,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> _R_co: ...
+
+ def __call__(
+ self,
+ messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
+
+
+ def _runnable_support(
+ func: Callable[
+ Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
+ ],
+ ) -> _RunnableSupportCallable[_P, _R_co]:
+ @wraps(func)
  def wrapped(
- messages: Sequence[MessageLikeRepresentation] | None = None,
- **kwargs: Any,
- ) -> (
- list[BaseMessage]
- | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
- ):
+ messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
  # Import locally to prevent circular import.
  from langchain_core.runnables.base import RunnableLambda  # noqa: PLC0415

  if messages is not None:
- return func(messages, **kwargs)
+ return func(messages, *args, **kwargs)
  return RunnableLambda(partial(func, **kwargs), name=func.__name__)

- wrapped.__doc__ = func.__doc__
- return wrapped
+ return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)


  @_runnable_support
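The decorator preserves the dual calling convention for the helpers it wraps (e.g. `trim_messages`): called with messages it runs immediately; called with only configuration it returns a `Runnable` for use in a chain. A minimal sketch, assuming the public `trim_messages` keyword arguments are unchanged:

```python
from langchain_core.messages import AIMessage, HumanMessage, trim_messages

messages = [
    HumanMessage("hi"),
    AIMessage("hello!"),
    HumanMessage("what's the weather like?"),
]

# Called with messages: executes immediately and returns the trimmed list.
last_two = trim_messages(messages, strategy="last", token_counter=len, max_tokens=2)

# Called without messages: returns a Runnable that can be composed in a chain.
trimmer = trim_messages(strategy="last", token_counter=len, max_tokens=2)
assert trimmer.invoke(messages) == last_two
```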
@@ -738,8 +763,10 @@ def trim_messages(
  Set to `len` to count the number of **messages** in the chat history.

  !!! note
+
  Use `count_tokens_approximately` to get fast, approximate token
  counts.
+
  This is recommended for using `trim_messages` on the hot path, where
  exact token counting is not necessary.

@@ -47,22 +47,24 @@ def parse_tool_call(
  """
  if "function" not in raw_tool_call:
  return None
+
+ arguments = raw_tool_call["function"]["arguments"]
+
  if partial:
  try:
- function_args = parse_partial_json(
- raw_tool_call["function"]["arguments"], strict=strict
- )
+ function_args = parse_partial_json(arguments, strict=strict)
  except (JSONDecodeError, TypeError):  # None args raise TypeError
  return None
+ # Handle None or empty string arguments for parameter-less tools
+ elif not arguments:
+ function_args = {}
  else:
  try:
- function_args = json.loads(
- raw_tool_call["function"]["arguments"], strict=strict
- )
+ function_args = json.loads(arguments, strict=strict)
  except JSONDecodeError as e:
  msg = (
  f"Function {raw_tool_call['function']['name']} arguments:\n\n"
- f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
+ f"{arguments}\n\nare not valid JSON. "
  f"Received JSONDecodeError {e}"
  )
  raise OutputParserException(msg) from e
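With this change, a tool call whose `arguments` is `None` or an empty string (common for parameter-less tools) parses to empty args instead of failing JSON decoding. A small sketch, assuming the public `parse_tool_call` signature is otherwise unchanged:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

# Raw OpenAI-style tool call for a parameter-less tool: arguments is empty.
raw_tool_call = {
    "id": "call_abc123",
    "function": {"name": "get_current_time", "arguments": ""},
}

# Empty or None arguments now map to {} rather than raising a parsing error.
tool_call = parse_tool_call(raw_tool_call)
print(tool_call["name"], tool_call["args"])  # get_current_time {}
```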
@@ -37,7 +37,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
  def _parser_exception(
  self, e: Exception, json_object: dict
  ) -> OutputParserException:
- json_string = json.dumps(json_object)
+ json_string = json.dumps(json_object, ensure_ascii=False)
  name = self.pydantic_object.__name__
  msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
  return OutputParserException(msg, llm_output=json_string)
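For context, `ensure_ascii=False` keeps non-ASCII characters readable in the error message instead of escaping them:

```python
import json

completion = {"name": "café", "city": "München"}

print(json.dumps(completion))                      # {"name": "caf\u00e9", "city": "M\u00fcnchen"}
print(json.dumps(completion, ensure_ascii=False))  # {"name": "café", "city": "München"}
```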
@@ -6,7 +6,33 @@ from langchain_core.output_parsers.transform import BaseTransformOutputParser


  class StrOutputParser(BaseTransformOutputParser[str]):
- """OutputParser that parses `LLMResult` into the top likely string."""
+ """Extract text content from model outputs as a string.
+
+ Converts model outputs (such as `AIMessage` or `AIMessageChunk` objects) into plain
+ text strings. It's the simplest output parser and is useful when you need string
+ responses for downstream processing, display, or storage.
+
+ Supports streaming, yielding text chunks as they're generated by the model.
+
+ Example:
+ ```python
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_openai import ChatOpenAI
+
+ model = ChatOpenAI(model="gpt-4o")
+ parser = StrOutputParser()
+
+ # Get string output from a model
+ message = model.invoke("Tell me a joke")
+ result = parser.invoke(message)
+ print(result)  # plain string
+
+ # With streaming - use transform() to process a stream
+ stream = model.stream("Tell me a story")
+ for chunk in parser.transform(stream):
+ print(chunk, end="", flush=True)
+ ```
+ """

  @classmethod
  def is_lc_serializable(cls) -> bool:
@@ -903,23 +903,28 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
  5. A string which is shorthand for `("human", template)`; e.g.,
  `"{user_input}"`
  template_format: Format of the template.
- input_variables: A list of the names of the variables whose values are
- required as inputs to the prompt.
- optional_variables: A list of the names of the variables for placeholder
- or MessagePlaceholder that are optional.
-
- These variables are auto inferred from the prompt and user need not
- provide them.
- partial_variables: A dictionary of the partial variables the prompt
- template carries.
-
- Partial variables populate the template so that you don't need to pass
- them in every time you call the prompt.
- validate_template: Whether to validate the template.
- input_types: A dictionary of the types of the variables the prompt template
- expects.
-
- If not provided, all variables are assumed to be strings.
+ **kwargs: Additional keyword arguments passed to `BasePromptTemplate`,
+ including (but not limited to):
+
+ - `input_variables`: A list of the names of the variables whose values
+ are required as inputs to the prompt.
+ - `optional_variables`: A list of the names of the variables for
+ placeholder or `MessagePlaceholder` that are optional.
+
+ These variables are auto inferred from the prompt and user need not
+ provide them.
+
+ - `partial_variables`: A dictionary of the partial variables the prompt
+ template carries.
+
+ Partial variables populate the template so that you don't need to
+ pass them in every time you call the prompt.
+
+ - `validate_template`: Whether to validate the template.
+ - `input_types`: A dictionary of the types of the variables the prompt
+ template expects.
+
+ If not provided, all variables are assumed to be strings.

  Examples:
  Instantiation from a list of message templates:
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
  from collections.abc import Callable, Sequence

  try:
- from jinja2 import Environment, meta
+ from jinja2 import meta
  from jinja2.sandbox import SandboxedEnvironment

  _HAS_JINJA2 = True
@@ -61,13 +61,9 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
  )
  raise ImportError(msg)

- # This uses a sandboxed environment to prevent arbitrary code execution.
- # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
- # Please treat this sand-boxing as a best-effort approach rather than
- # a guarantee of security.
- # We recommend to never use jinja2 templates with untrusted inputs.
- # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
- # approach not a guarantee of security.
+ # Use a restricted sandbox that blocks ALL attribute/method access
+ # Only simple variable lookups like {{variable}} are allowed
+ # Attribute access like {{variable.attr}} or {{variable.method()}} is blocked
  return SandboxedEnvironment().from_string(template).render(**kwargs)
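A quick sketch of what the sandbox allows, independent of langchain-core: plain variable lookups render normally, while underscore-prefixed attributes (the usual template-injection vector) are treated as unsafe by `SandboxedEnvironment`:

```python
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()

# Simple variable substitution works as in a regular Environment.
print(env.from_string("Hello {{ name }}!").render(name="world"))  # Hello world!

# The sandbox classifies underscore-prefixed attributes as unsafe.
print(env.is_safe_attribute(object(), "__class__", object().__class__))  # False
```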
@@ -103,7 +99,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
  "Please install it with `pip install jinja2`."
  )
  raise ImportError(msg)
- env = Environment()  # noqa: S701
+ env = SandboxedEnvironment()
  ast = env.parse(template)
  return meta.find_undeclared_variables(ast)

@@ -273,6 +269,30 @@ def get_template_variables(template: str, template_format: str) -> list[str]:
  msg = f"Unsupported template format: {template_format}"
  raise ValueError(msg)

+ # For f-strings, block attribute access and indexing syntax
+ # This prevents template injection attacks via accessing dangerous attributes
+ if template_format == "f-string":
+ for var in input_variables:
+ # Formatter().parse() returns field names with dots/brackets if present
+ # e.g., "obj.attr" or "obj[0]" - we need to block these
+ if "." in var or "[" in var or "]" in var:
+ msg = (
+ f"Invalid variable name {var!r} in f-string template. "
+ f"Variable names cannot contain attribute "
+ f"access (.) or indexing ([])."
+ )
+ raise ValueError(msg)
+
+ # Block variable names that are all digits (e.g., "0", "100")
+ # These are interpreted as positional arguments, not keyword arguments
+ if var.isdigit():
+ msg = (
+ f"Invalid variable name {var!r} in f-string template. "
+ f"Variable names cannot be all digits as they are interpreted "
+ f"as positional arguments."
+ )
+ raise ValueError(msg)
+
  return sorted(input_variables)
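The check relies on how `string.Formatter().parse()` reports field names: dotted access, indexing, and bare positional indices all appear verbatim in the field name, which is exactly what the guard above rejects:

```python
from string import Formatter

# Formatter().parse() yields (literal_text, field_name, format_spec, conversion).
for template in ("{name}", "{obj.attr}", "{items[0]}", "{0}"):
    fields = [f for _, f, _, _ in Formatter().parse(template) if f is not None]
    print(template, "->", fields)

# {name}     -> ['name']
# {obj.attr} -> ['obj.attr']
# {items[0]} -> ['items[0]']
# {0}        -> ['0']
```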
@@ -49,7 +49,13 @@ class StructuredPrompt(ChatPromptTemplate):
  structured_output_kwargs: additional kwargs for structured output.
  template_format: template format for the prompt.
  """
- schema_ = schema_ or kwargs.pop("schema")
+ schema_ = schema_ or kwargs.pop("schema", None)
+ if not schema_:
+ err_msg = (
+ "Must pass in a non-empty structured output schema. Received: "
+ f"{schema_}"
+ )
+ raise ValueError(err_msg)
  structured_output_kwargs = structured_output_kwargs or {}
  for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
  structured_output_kwargs[k] = kwargs.pop(k)