langchain-core 1.0.7__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. langchain_core/callbacks/manager.py +14 -14
  2. langchain_core/callbacks/usage.py +1 -1
  3. langchain_core/indexing/api.py +2 -0
  4. langchain_core/language_models/__init__.py +15 -5
  5. langchain_core/language_models/_utils.py +1 -0
  6. langchain_core/language_models/chat_models.py +74 -94
  7. langchain_core/language_models/llms.py +5 -3
  8. langchain_core/language_models/model_profile.py +84 -0
  9. langchain_core/load/load.py +14 -1
  10. langchain_core/messages/ai.py +12 -4
  11. langchain_core/messages/base.py +6 -6
  12. langchain_core/messages/block_translators/anthropic.py +27 -8
  13. langchain_core/messages/block_translators/bedrock_converse.py +18 -8
  14. langchain_core/messages/block_translators/google_genai.py +25 -10
  15. langchain_core/messages/content.py +1 -1
  16. langchain_core/messages/tool.py +28 -27
  17. langchain_core/messages/utils.py +45 -18
  18. langchain_core/output_parsers/openai_tools.py +9 -7
  19. langchain_core/output_parsers/pydantic.py +1 -1
  20. langchain_core/output_parsers/string.py +27 -1
  21. langchain_core/prompts/chat.py +22 -17
  22. langchain_core/prompts/string.py +2 -59
  23. langchain_core/prompts/structured.py +7 -1
  24. langchain_core/runnables/base.py +174 -160
  25. langchain_core/runnables/branch.py +1 -1
  26. langchain_core/runnables/config.py +25 -20
  27. langchain_core/runnables/fallbacks.py +1 -2
  28. langchain_core/runnables/passthrough.py +2 -2
  29. langchain_core/tools/base.py +23 -4
  30. langchain_core/tools/convert.py +16 -0
  31. langchain_core/tools/retriever.py +29 -58
  32. langchain_core/tracers/event_stream.py +9 -4
  33. langchain_core/utils/aiter.py +3 -1
  34. langchain_core/utils/function_calling.py +7 -2
  35. langchain_core/utils/json_schema.py +29 -21
  36. langchain_core/utils/pydantic.py +7 -7
  37. langchain_core/utils/uuid.py +54 -0
  38. langchain_core/vectorstores/base.py +26 -18
  39. langchain_core/version.py +1 -1
  40. {langchain_core-1.0.7.dist-info → langchain_core-1.2.1.dist-info}/METADATA +2 -1
  41. {langchain_core-1.0.7.dist-info → langchain_core-1.2.1.dist-info}/RECORD +42 -40
  42. {langchain_core-1.0.7.dist-info → langchain_core-1.2.1.dist-info}/WHEEL +1 -1
langchain_core/messages/ai.py
@@ -51,22 +51,22 @@ class InputTokenDetails(TypedDict, total=False):
     May also hold extra provider-specific keys.
 
     !!! version-added "Added in `langchain-core` 0.3.9"
-
     """
 
     audio: int
     """Audio input tokens."""
+
     cache_creation: int
     """Input tokens that were cached and there was a cache miss.
 
     Since there was a cache miss, the cache was created from these tokens.
     """
+
     cache_read: int
     """Input tokens that were cached and there was a cache hit.
 
     Since there was a cache hit, the tokens were read from the cache. More precisely,
     the model state given these tokens was read from the cache.
-
     """
 
 
@@ -91,12 +91,12 @@ class OutputTokenDetails(TypedDict, total=False):
 
     audio: int
     """Audio output tokens."""
+
     reasoning: int
     """Reasoning output tokens.
 
     Tokens generated by the model in a chain of thought process (i.e. by OpenAI's o1
     models) that are not returned as part of model output.
-
     """
 
 
@@ -124,9 +124,11 @@ class UsageMetadata(TypedDict):
     ```
 
     !!! warning "Behavior changed in `langchain-core` 0.3.9"
+
        Added `input_token_details` and `output_token_details`.
 
     !!! note "LangSmith SDK"
+
        The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
        LangSmith's `UsageMetadata` has additional fields to capture cost information
        used by the LangSmith platform.
@@ -134,15 +136,19 @@ class UsageMetadata(TypedDict):
 
     input_tokens: int
     """Count of input (or prompt) tokens. Sum of all input token types."""
+
     output_tokens: int
     """Count of output (or completion) tokens. Sum of all output token types."""
+
     total_tokens: int
     """Total token count. Sum of `input_tokens` + `output_tokens`."""
+
     input_token_details: NotRequired[InputTokenDetails]
     """Breakdown of input token counts.
 
     Does *not* need to sum to full input token count. Does *not* need to have all keys.
     """
+
     output_token_details: NotRequired[OutputTokenDetails]
     """Breakdown of output token counts.
 
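The hunks above settle the final shape of `UsageMetadata` and its detail dicts. For reference, a minimal sketch of a conforming value (illustrative numbers; the detail breakdowns are optional and need not sum to the totals):

```python
# Illustrative values only; shapes follow the TypedDicts in the hunks above.
from langchain_core.messages.ai import UsageMetadata

usage: UsageMetadata = {
    "input_tokens": 350,
    "output_tokens": 240,
    "total_tokens": 590,  # input_tokens + output_tokens
    # Optional breakdowns; neither needs to sum to the full counts.
    "input_token_details": {"audio": 10, "cache_creation": 200, "cache_read": 100},
    "output_token_details": {"audio": 10, "reasoning": 200},
}
```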
@@ -162,8 +168,10 @@ class AIMessage(BaseMessage):
 
     tool_calls: list[ToolCall] = []
     """If present, tool calls associated with the message."""
+
     invalid_tool_calls: list[InvalidToolCall] = []
     """If present, tool calls with parsing errors associated with the message."""
+
     usage_metadata: UsageMetadata | None = None
     """If present, usage metadata for a message, such as token counts.
 
@@ -555,7 +563,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
 
     @model_validator(mode="after")
     def init_server_tool_calls(self) -> Self:
-        """Parse `server_tool_call_chunks`."""
+        """Parse `server_tool_call_chunks` from [`ServerToolCallChunk`][langchain.messages.ServerToolCallChunk] objects."""  # noqa: E501
         if (
             self.chunk_position == "last"
             and self.response_metadata.get("output_version") == "v1"
langchain_core/messages/base.py
@@ -391,12 +391,12 @@ class BaseMessageChunk(BaseMessage):
        Raises:
            TypeError: If the other object is not a message chunk.
 
-        For example,
-
-        `AIMessageChunk(content="Hello") + AIMessageChunk(content=" World")`
-
-        will give `AIMessageChunk(content="Hello World")`
-
+        Example:
+            ```txt
+            AIMessageChunk(content="Hello", ...)
+            + AIMessageChunk(content=" World", ...)
+            = AIMessageChunk(content="Hello World", ...)
+            ```
        """
        if isinstance(other, BaseMessageChunk):
            # If both are (subclasses of) BaseMessageChunk,
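The rewritten docstring documents `+` on chunks; a short usage sketch of the same semantics, which is the standard way to accumulate a streamed response:

```python
from langchain_core.messages import AIMessageChunk

chunks = [AIMessageChunk(content="Hello"), AIMessageChunk(content=" World")]
full = chunks[0]
for chunk in chunks[1:]:
    full = full + chunk  # concatenates content, merges metadata
print(full.content)  # Hello World
```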
langchain_core/messages/block_translators/anthropic.py
@@ -245,11 +245,20 @@ def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock
             and message.chunk_position != "last"
         ):
             # Isolated chunk
-            tool_call_chunk: types.ToolCallChunk = (
-                message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+            chunk = message.tool_call_chunks[0]
+
+            tool_call_chunk = types.ToolCallChunk(
+                name=chunk.get("name"),
+                id=chunk.get("id"),
+                args=chunk.get("args"),
+                type="tool_call_chunk",
             )
-            if "type" not in tool_call_chunk:
-                tool_call_chunk["type"] = "tool_call_chunk"
+            if "caller" in block:
+                tool_call_chunk["extras"] = {"caller": block["caller"]}
+
+            index = chunk.get("index")
+            if index is not None:
+                tool_call_chunk["index"] = index
             yield tool_call_chunk
         else:
             tool_call_block: types.ToolCall | None = None
@@ -282,17 +291,27 @@ def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock
                 }
                 if "index" in block:
                     tool_call_block["index"] = block["index"]
+                if "caller" in block:
+                    if "extras" not in tool_call_block:
+                        tool_call_block["extras"] = {}
+                    tool_call_block["extras"]["caller"] = block["caller"]
+
                 yield tool_call_block
 
         elif block_type == "input_json_delta" and isinstance(
             message, AIMessageChunk
         ):
             if len(message.tool_call_chunks) == 1:
-                tool_call_chunk = (
-                    message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+                chunk = message.tool_call_chunks[0]
+                tool_call_chunk = types.ToolCallChunk(
+                    name=chunk.get("name"),
+                    id=chunk.get("id"),
+                    args=chunk.get("args"),
+                    type="tool_call_chunk",
                 )
-                if "type" not in tool_call_chunk:
-                    tool_call_chunk["type"] = "tool_call_chunk"
+                index = chunk.get("index")
+                if index is not None:
+                    tool_call_chunk["index"] = index
                 yield tool_call_chunk
 
         else:
langchain_core/messages/block_translators/bedrock_converse.py
@@ -209,11 +209,16 @@ def _convert_to_v1_from_converse(message: AIMessage) -> list[types.ContentBlock]
             and message.chunk_position != "last"
         ):
             # Isolated chunk
-            tool_call_chunk: types.ToolCallChunk = (
-                message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+            chunk = message.tool_call_chunks[0]
+            tool_call_chunk = types.ToolCallChunk(
+                name=chunk.get("name"),
+                id=chunk.get("id"),
+                args=chunk.get("args"),
+                type="tool_call_chunk",
             )
-            if "type" not in tool_call_chunk:
-                tool_call_chunk["type"] = "tool_call_chunk"
+            index = chunk.get("index")
+            if index is not None:
+                tool_call_chunk["index"] = index
             yield tool_call_chunk
         else:
             tool_call_block: types.ToolCall | None = None
@@ -253,11 +258,16 @@ def _convert_to_v1_from_converse(message: AIMessage) -> list[types.ContentBlock]
             and isinstance(message, AIMessageChunk)
             and len(message.tool_call_chunks) == 1
         ):
-            tool_call_chunk = (
-                message.tool_call_chunks[0].copy()  # type: ignore[assignment]
+            chunk = message.tool_call_chunks[0]
+            tool_call_chunk = types.ToolCallChunk(
+                name=chunk.get("name"),
+                id=chunk.get("id"),
+                args=chunk.get("args"),
+                type="tool_call_chunk",
            )
-            if "type" not in tool_call_chunk:
-                tool_call_chunk["type"] = "tool_call_chunk"
+            index = chunk.get("index")
+            if index is not None:
+                tool_call_chunk["index"] = index
             yield tool_call_chunk
 
         else:
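The four translator hunks above (Anthropic and Bedrock Converse) all replace `dict.copy()` with explicit construction of a typed `ToolCallChunk`, adding optional keys only when present. A self-contained sketch of the pattern, using a stand-in TypedDict rather than langchain's actual type:

```python
# Python 3.11+ (or import NotRequired from typing_extensions).
from typing import NotRequired, TypedDict


class ToolCallChunk(TypedDict):  # stand-in, not langchain's types.ToolCallChunk
    name: str | None
    id: str | None
    args: str | None
    type: str
    index: NotRequired[int]


raw: dict = {"name": "get_weather", "id": "call_1", "args": '{"city": "Par'}
chunk = ToolCallChunk(
    name=raw.get("name"),
    id=raw.get("id"),
    args=raw.get("args"),
    type="tool_call_chunk",
)
index = raw.get("index")
if index is not None:  # only set the optional key when actually present
    chunk["index"] = index
```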
langchain_core/messages/block_translators/google_genai.py
@@ -76,21 +76,36 @@ def translate_grounding_metadata_to_citations(
     for chunk_index in chunk_indices:
         if chunk_index < len(grounding_chunks):
             chunk = grounding_chunks[chunk_index]
-            web_info = chunk.get("web", {})
+
+            # Handle web and maps grounding
+            web_info = chunk.get("web") or {}
+            maps_info = chunk.get("maps") or {}
+
+            # Extract citation info depending on source
+            url = maps_info.get("uri") or web_info.get("uri")
+            title = maps_info.get("title") or web_info.get("title")
+
+            # Note: confidence_scores is a legacy field from Gemini 2.0 and earlier
+            # that indicated confidence (0.0-1.0) for each grounding chunk.
+            #
+            # In Gemini 2.5+, this field is always None/empty and should be ignored.
+            extras_metadata = {
+                "web_search_queries": web_search_queries,
+                "grounding_chunk_index": chunk_index,
+                "confidence_scores": support.get("confidence_scores") or [],
+            }
+
+            # Add maps-specific metadata if present
+            if maps_info.get("placeId"):
+                extras_metadata["place_id"] = maps_info["placeId"]
 
             citation = create_citation(
-                url=web_info.get("uri"),
-                title=web_info.get("title"),
+                url=url,
+                title=title,
                 start_index=start_index,
                 end_index=end_index,
                 cited_text=cited_text,
-                extras={
-                    "google_ai_metadata": {
-                        "web_search_queries": web_search_queries,
-                        "grounding_chunk_index": chunk_index,
-                        "confidence_scores": support.get("confidence_scores", []),
-                    }
-                },
+                google_ai_metadata=extras_metadata,
             )
             citations.append(citation)
 
langchain_core/messages/content.py
@@ -654,7 +654,7 @@ class PlainTextContentBlock(TypedDict):
 
     !!! note
         Title and context are optional fields that may be passed to the model. See
-        Anthropic [example](https://docs.claude.com/en/docs/build-with-claude/citations#citable-vs-non-citable-content).
+        Anthropic [example](https://platform.claude.com/docs/en/build-with-claude/citations#citable-vs-non-citable-content).
 
     !!! note "Factory function"
         `create_plaintext_block` may also be used as a factory to create a
langchain_core/messages/tool.py
@@ -29,38 +29,39 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
     `ToolMessage` objects contain the result of a tool invocation. Typically, the result
     is encoded inside the `content` field.
 
-    Example: A `ToolMessage` representing a result of `42` from a tool call with id
+    `tool_call_id` is used to associate the tool call request with the tool call
+    response. Useful in situations where a chat model is able to request multiple tool
+    calls in parallel.
 
-    ```python
-    from langchain_core.messages import ToolMessage
-
-    ToolMessage(content="42", tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL")
-    ```
+    Example:
+        A `ToolMessage` representing a result of `42` from a tool call with id
 
-    Example: A `ToolMessage` where only part of the tool output is sent to the model
-    and the full output is passed in to artifact.
+        ```python
+        from langchain_core.messages import ToolMessage
 
-    ```python
-    from langchain_core.messages import ToolMessage
-
-    tool_output = {
-        "stdout": "From the graph we can see that the correlation between "
-        "x and y is ...",
-        "stderr": None,
-        "artifacts": {"type": "image", "base64_data": "/9j/4gIcSU..."},
-    }
-
-    ToolMessage(
-        content=tool_output["stdout"],
-        artifact=tool_output,
-        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
-    )
-    ```
+        ToolMessage(content="42", tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL")
+        ```
 
-    The `tool_call_id` field is used to associate the tool call request with the
-    tool call response. Useful in situations where a chat model is able
-    to request multiple tool calls in parallel.
+    Example:
+        A `ToolMessage` where only part of the tool output is sent to the model
+        and the full output is passed in to artifact.
 
+        ```python
+        from langchain_core.messages import ToolMessage
+
+        tool_output = {
+            "stdout": "From the graph we can see that the correlation between "
+            "x and y is ...",
+            "stderr": None,
+            "artifacts": {"type": "image", "base64_data": "/9j/4gIcSU..."},
+        }
+
+        ToolMessage(
+            content=tool_output["stdout"],
+            artifact=tool_output,
+            tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+        )
+        ```
     """
 
     tool_call_id: str
langchain_core/messages/utils.py
@@ -15,12 +15,16 @@ import json
 import logging
 import math
 from collections.abc import Callable, Iterable, Sequence
-from functools import partial
+from functools import partial, wraps
 from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
+    Concatenate,
     Literal,
+    ParamSpec,
+    Protocol,
+    TypeVar,
     cast,
     overload,
 )
@@ -384,33 +388,54 @@ def convert_to_messages(
     return [_convert_to_message(m) for m in messages]
 
 
-def _runnable_support(func: Callable) -> Callable:
-    @overload
-    def wrapped(
-        messages: None = None, **kwargs: Any
-    ) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
+_P = ParamSpec("_P")
+_R_co = TypeVar("_R_co", covariant=True)
 
+
+class _RunnableSupportCallable(Protocol[_P, _R_co]):
     @overload
-    def wrapped(
-        messages: Sequence[MessageLikeRepresentation], **kwargs: Any
-    ) -> list[BaseMessage]: ...
+    def __call__(
+        self,
+        messages: None = None,
+        *args: _P.args,
+        **kwargs: _P.kwargs,
+    ) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
 
+    @overload
+    def __call__(
+        self,
+        messages: Sequence[MessageLikeRepresentation] | PromptValue,
+        *args: _P.args,
+        **kwargs: _P.kwargs,
+    ) -> _R_co: ...
+
+    def __call__(
+        self,
+        messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
+        *args: _P.args,
+        **kwargs: _P.kwargs,
+    ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
+
+
+def _runnable_support(
+    func: Callable[
+        Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
+    ],
+) -> _RunnableSupportCallable[_P, _R_co]:
+    @wraps(func)
     def wrapped(
-        messages: Sequence[MessageLikeRepresentation] | None = None,
-        **kwargs: Any,
-    ) -> (
-        list[BaseMessage]
-        | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
-    ):
+        messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
+        *args: _P.args,
+        **kwargs: _P.kwargs,
+    ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
         # Import locally to prevent circular import.
         from langchain_core.runnables.base import RunnableLambda  # noqa: PLC0415
 
         if messages is not None:
-            return func(messages, **kwargs)
+            return func(messages, *args, **kwargs)
         return RunnableLambda(partial(func, **kwargs), name=func.__name__)
 
-    wrapped.__doc__ = func.__doc__
-    return wrapped
+    return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)
 
 
 @_runnable_support
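The `_runnable_support` rewrite swaps an untyped `Callable` for a `ParamSpec`/`Concatenate`/`Protocol` combination so decorated functions keep their full signatures for type checkers. A minimal standalone sketch of the same typing pattern, with hypothetical names that are not part of the library:

```python
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")


def strip_first_arg(
    func: Callable[Concatenate[str, _P], _R],
) -> Callable[Concatenate[str, _P], _R]:
    """Decorator whose wrapper preserves the decorated signature for type checkers."""

    @wraps(func)
    def wrapped(text: str, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        return func(text.strip(), *args, **kwargs)

    return wrapped


@strip_first_arg
def shout(text: str, *, repeat: int = 1) -> str:
    return " ".join([text.upper()] * repeat)


print(shout("  hello ", repeat=2))  # HELLO HELLO
```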
@@ -738,8 +763,10 @@ def trim_messages(
             Set to `len` to count the number of **messages** in the chat history.
 
             !!! note
+
                 Use `count_tokens_approximately` to get fast, approximate token
                 counts.
+
                 This is recommended for using `trim_messages` on the hot path, where
                 exact token counting is not necessary.
 
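A hedged usage sketch of the note above; `trim_messages` and `count_tokens_approximately` both live in this module, and the keyword arguments are assumed from their current signatures:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.utils import count_tokens_approximately, trim_messages

history = [
    HumanMessage("What's the capital of France?"),
    AIMessage("Paris."),
    HumanMessage("And of Germany?"),
]
trimmed = trim_messages(
    history,
    max_tokens=50,
    token_counter=count_tokens_approximately,  # fast, approximate counting
    strategy="last",  # keep the most recent messages
)
```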
langchain_core/output_parsers/openai_tools.py
@@ -47,22 +47,24 @@ def parse_tool_call(
     """
     if "function" not in raw_tool_call:
         return None
+
+    arguments = raw_tool_call["function"]["arguments"]
+
     if partial:
         try:
-            function_args = parse_partial_json(
-                raw_tool_call["function"]["arguments"], strict=strict
-            )
+            function_args = parse_partial_json(arguments, strict=strict)
         except (JSONDecodeError, TypeError):  # None args raise TypeError
             return None
+    # Handle None or empty string arguments for parameter-less tools
+    elif not arguments:
+        function_args = {}
     else:
         try:
-            function_args = json.loads(
-                raw_tool_call["function"]["arguments"], strict=strict
-            )
+            function_args = json.loads(arguments, strict=strict)
         except JSONDecodeError as e:
             msg = (
                 f"Function {raw_tool_call['function']['name']} arguments:\n\n"
-                f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
+                f"{arguments}\n\nare not valid JSON. "
                 f"Received JSONDecodeError {e}"
             )
             raise OutputParserException(msg) from e
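The new `elif not arguments` branch means a parameter-less tool call whose `arguments` field is `""` (or `None`) now parses to empty args instead of raising. A hedged sketch of that behavior:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

raw = {
    "id": "call_abc123",
    "function": {"name": "get_time", "arguments": ""},  # parameter-less tool
}
tool_call = parse_tool_call(raw)
# Expected shape (approximate): {"name": "get_time", "args": {}, "id": "call_abc123", ...}
print(tool_call)
```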
langchain_core/output_parsers/pydantic.py
@@ -37,7 +37,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
     def _parser_exception(
         self, e: Exception, json_object: dict
     ) -> OutputParserException:
-        json_string = json.dumps(json_object)
+        json_string = json.dumps(json_object, ensure_ascii=False)
         name = self.pydantic_object.__name__
         msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
         return OutputParserException(msg, llm_output=json_string)
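Effect of the `ensure_ascii=False` change, shown with plain stdlib `json`: error messages keep non-ASCII completion text readable instead of escaping it.

```python
import json

data = {"city": "Zürich", "name": "東京"}
print(json.dumps(data))  # {"city": "Z\u00fcrich", "name": "\u6771\u4eac"}
print(json.dumps(data, ensure_ascii=False))  # {"city": "Zürich", "name": "東京"}
```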
langchain_core/output_parsers/string.py
@@ -6,7 +6,33 @@ from langchain_core.output_parsers.transform import BaseTransformOutputParser
 
 
 class StrOutputParser(BaseTransformOutputParser[str]):
-    """OutputParser that parses `LLMResult` into the top likely string."""
+    """Extract text content from model outputs as a string.
+
+    Converts model outputs (such as `AIMessage` or `AIMessageChunk` objects) into plain
+    text strings. It's the simplest output parser and is useful when you need string
+    responses for downstream processing, display, or storage.
+
+    Supports streaming, yielding text chunks as they're generated by the model.
+
+    Example:
+        ```python
+        from langchain_core.output_parsers import StrOutputParser
+        from langchain_openai import ChatOpenAI
+
+        model = ChatOpenAI(model="gpt-4o")
+        parser = StrOutputParser()
+
+        # Get string output from a model
+        message = model.invoke("Tell me a joke")
+        result = parser.invoke(message)
+        print(result)  # plain string
+
+        # With streaming - use transform() to process a stream
+        stream = model.stream("Tell me a story")
+        for chunk in parser.transform(stream):
+            print(chunk, end="", flush=True)
+        ```
+    """
 
     @classmethod
     def is_lc_serializable(cls) -> bool:
langchain_core/prompts/chat.py
@@ -903,23 +903,28 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
                5. A string which is shorthand for `("human", template)`; e.g.,
                   `"{user_input}"`
            template_format: Format of the template.
-            input_variables: A list of the names of the variables whose values are
-                required as inputs to the prompt.
-            optional_variables: A list of the names of the variables for placeholder
-                or MessagePlaceholder that are optional.
-
-                These variables are auto inferred from the prompt and user need not
-                provide them.
-            partial_variables: A dictionary of the partial variables the prompt
-                template carries.
-
-                Partial variables populate the template so that you don't need to pass
-                them in every time you call the prompt.
-            validate_template: Whether to validate the template.
-            input_types: A dictionary of the types of the variables the prompt template
-                expects.
-
-                If not provided, all variables are assumed to be strings.
+            **kwargs: Additional keyword arguments passed to `BasePromptTemplate`,
+                including (but not limited to):
+
+                - `input_variables`: A list of the names of the variables whose values
+                  are required as inputs to the prompt.
+                - `optional_variables`: A list of the names of the variables for
+                  placeholder or `MessagePlaceholder` that are optional.
+
+                  These variables are auto inferred from the prompt and user need not
+                  provide them.
+
+                - `partial_variables`: A dictionary of the partial variables the prompt
+                  template carries.
+
+                  Partial variables populate the template so that you don't need to
+                  pass them in every time you call the prompt.
+
+                - `validate_template`: Whether to validate the template.
+                - `input_types`: A dictionary of the types of the variables the prompt
+                  template expects.
+
+                  If not provided, all variables are assumed to be strings.
 
        Examples:
            Instantiation from a list of message templates:
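A hedged sketch of passing one of the documented `**kwargs` (here `partial_variables`) through the constructor:

```python
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate(
    [("system", "You are {persona}."), ("human", "{question}")],
    partial_variables={"persona": "a terse assistant"},
)
print(prompt.invoke({"question": "What is 2 + 2?"}))
```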
langchain_core/prompts/string.py
@@ -20,65 +20,8 @@ if TYPE_CHECKING:
 
 try:
     from jinja2 import meta
-    from jinja2.exceptions import SecurityError
     from jinja2.sandbox import SandboxedEnvironment
 
-    class _RestrictedSandboxedEnvironment(SandboxedEnvironment):
-        """A more restrictive Jinja2 sandbox that blocks all attribute/method access.
-
-        This sandbox only allows simple variable lookups, no attribute or method access.
-        This prevents template injection attacks via methods like parse_raw().
-        """
-
-        def is_safe_attribute(self, _obj: Any, _attr: str, _value: Any) -> bool:
-            """Block ALL attribute access for security.
-
-            Only allow accessing variables directly from the context dict,
-            no attribute access on those objects.
-
-            Args:
-                _obj: The object being accessed (unused, always blocked).
-                _attr: The attribute name (unused, always blocked).
-                _value: The attribute value (unused, always blocked).
-
-            Returns:
-                False - all attribute access is blocked.
-            """
-            # Block all attribute access
-            return False
-
-        def is_safe_callable(self, _obj: Any) -> bool:
-            """Block all method calls for security.
-
-            Args:
-                _obj: The object being checked (unused, always blocked).
-
-            Returns:
-                False - all callables are blocked.
-            """
-            return False
-
-        def getattr(self, obj: Any, attribute: str) -> Any:
-            """Override getattr to block all attribute access.
-
-            Args:
-                obj: The object.
-                attribute: The attribute name.
-
-            Returns:
-                Never returns.
-
-            Raises:
-                SecurityError: Always, to block attribute access.
-            """
-            msg = (
-                f"Access to attributes is not allowed in templates. "
-                f"Attempted to access '{attribute}' on {type(obj).__name__}. "
-                f"Use only simple variable names like {{{{variable}}}} "
-                f"without dots or methods."
-            )
-            raise SecurityError(msg)
-
     _HAS_JINJA2 = True
 except ImportError:
     _HAS_JINJA2 = False
@@ -121,7 +64,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
     # Use a restricted sandbox that blocks ALL attribute/method access
     # Only simple variable lookups like {{variable}} are allowed
     # Attribute access like {{variable.attr}} or {{variable.method()}} is blocked
-    return _RestrictedSandboxedEnvironment().from_string(template).render(**kwargs)
+    return SandboxedEnvironment().from_string(template).render(**kwargs)
 
 
 def validate_jinja2(template: str, input_variables: list[str]) -> None:
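For context on the sandbox swap: jinja2's stock `SandboxedEnvironment` still blocks unsafe attribute access at render time, though it is less restrictive than the removed `_RestrictedSandboxedEnvironment`. A quick sketch:

```python
from jinja2.exceptions import SecurityError
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()
print(env.from_string("Hello {{ name }}!").render(name="world"))  # Hello world!

try:
    # Dunder attribute access is rejected by the sandbox.
    env.from_string("{{ obj.__class__ }}").render(obj=object())
except SecurityError as e:
    print(f"blocked: {e}")
```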
@@ -156,7 +99,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
             "Please install it with `pip install jinja2`."
         )
         raise ImportError(msg)
-    env = _RestrictedSandboxedEnvironment()
+    env = SandboxedEnvironment()
     ast = env.parse(template)
     return meta.find_undeclared_variables(ast)
 
langchain_core/prompts/structured.py
@@ -49,7 +49,13 @@ class StructuredPrompt(ChatPromptTemplate):
            structured_output_kwargs: additional kwargs for structured output.
            template_format: template format for the prompt.
        """
-        schema_ = schema_ or kwargs.pop("schema")
+        schema_ = schema_ or kwargs.pop("schema", None)
+        if not schema_:
+            err_msg = (
+                "Must pass in a non-empty structured output schema. Received: "
+                f"{schema_}"
+            )
+            raise ValueError(err_msg)
        structured_output_kwargs = structured_output_kwargs or {}
        for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
            structured_output_kwargs[k] = kwargs.pop(k)
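A hedged sketch of the new validation (constructor arguments assumed; previously a missing schema surfaced as a bare `KeyError` from `kwargs.pop("schema")`):

```python
from langchain_core.prompts.structured import StructuredPrompt

try:
    StructuredPrompt([("human", "Extract info from: {text}")])
except ValueError as e:
    print(e)  # Must pass in a non-empty structured output schema. Received: None
```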