langchain-tool-args-validation-middleware 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ """Validate LLM tool-call arguments against each tool's schema before execution."""
2
+
3
+ from ._strip import DEFAULT_PLACEHOLDER_STRINGS, strip_empty
4
+ from ._validation import ValidationIssue
5
+ from .extras import detect_langchain_internal_ids
6
+ from .middleware import (
7
+ ExtraValidator,
8
+ OnFailure,
9
+ ToolArgsValidationError,
10
+ ToolArgsValidationMiddleware,
11
+ )
12
+
13
+ __all__ = [
14
+ "DEFAULT_PLACEHOLDER_STRINGS",
15
+ "ExtraValidator",
16
+ "OnFailure",
17
+ "ToolArgsValidationError",
18
+ "ToolArgsValidationMiddleware",
19
+ "ValidationIssue",
20
+ "detect_langchain_internal_ids",
21
+ "strip_empty",
22
+ ]
23
+
24
+ __version__ = "0.1.0"
@@ -0,0 +1,82 @@
1
+ """Recursive stripping of "empty" values from LLM-generated tool arguments.
2
+
3
+ LLMs (Gemini especially) routinely emit explicit ``null`` or empty containers
4
+ for optional fields instead of omitting them. Stripping these before validation
5
+ avoids unnecessary retries: an optional field simply becomes absent, and a
6
+ required field surfaces a clear ``'<field>' is a required property`` error.
7
+
8
+ Design note — write-back contract
9
+ ----------------------------------
10
+ When stripping is enabled the *cleaned* arguments replace the originals in the
11
+ tool call, so the cleaned version is what both validation **and tool execution**
12
+ see. This keeps "what we validated" and "what runs" identical (no soundness
13
+ gap), at the cost of mutating the model's output. That trade-off is the whole
14
+ point of stripping, but it means stripping a value that is *semantically
15
+ meaningful* (e.g. ``tags: []`` meaning "clear all tags", or ``null`` meaning
16
+ "explicitly unset") changes behaviour. Container stripping (``None``/``{}``/
17
+ ``[]``) is on by default; the far riskier string-placeholder stripping
18
+ (``"none"``, ``"N/A"``, ...) is **opt-in only**, because tokens like ``"NA"``
19
+ are legitimate values (Namibia's ISO code, "North America", ...).
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import Any
25
+
26
+ # A conservative, opt-in default set of placeholder strings. Deliberately
27
+ # excludes ambiguous real-world tokens like "na"/"nil". Callers may pass their
28
+ # own set instead. Only used when string stripping is explicitly enabled.
29
# Lower-cased tokens treated as "empty" when string stripping is switched on.
# The last two entries are the literal two-character strings `""` and `''`,
# i.e. a model spelling out an empty string rather than emitting one.
DEFAULT_PLACEHOLDER_STRINGS: frozenset[str] = frozenset(
    (
        "none",
        "null",
        "undefined",
        '""',
        "''",
    )
)
32
+
33
+
34
+ def strip_empty(
35
+ value: Any,
36
+ *,
37
+ placeholder_strings: frozenset[str] | None = None,
38
+ ) -> Any:
39
+ """Return a copy of *value* with "empty" entries recursively removed.
40
+
41
+ Parameters
42
+ ----------
43
+ value:
44
+ The value to clean (typically a tool call's ``args`` dict, but works on
45
+ any nested dict/list structure).
46
+ placeholder_strings:
47
+ If provided, string values whose stripped/lower-cased form is in this
48
+ set are also removed. ``None`` (the default) disables string stripping
49
+ entirely — only ``None``/``{}``/``[]`` are removed.
50
+
51
+ Notes
52
+ -----
53
+ Returns a new structure; it never mutates *value* in place. The caller
54
+ decides whether to write the result back onto the tool call.
55
+ """
56
+ if isinstance(value, dict):
57
+ cleaned: dict[Any, Any] = {}
58
+ for key, val in value.items():
59
+ if _is_empty(val, placeholder_strings):
60
+ continue
61
+ cleaned[key] = strip_empty(val, placeholder_strings=placeholder_strings)
62
+ return cleaned
63
+ if isinstance(value, list):
64
+ return [
65
+ strip_empty(item, placeholder_strings=placeholder_strings)
66
+ for item in value
67
+ if not _is_empty(item, placeholder_strings)
68
+ ]
69
+ return value
70
+
71
+
72
+ def _is_empty(value: Any, placeholder_strings: frozenset[str] | None) -> bool:
73
+ """Whether *value* should be dropped during stripping."""
74
+ if value is None:
75
+ return True
76
+ if value == {} or value == []:
77
+ return True
78
+ return (
79
+ placeholder_strings is not None
80
+ and isinstance(value, str)
81
+ and value.strip().lower() in placeholder_strings
82
+ )
@@ -0,0 +1,140 @@
1
+ """Schema resolution and validation, decoupled from any single validator lib.
2
+
3
+ Both validation backends (Pydantic and JSON Schema) are normalised into a
4
+ neutral :class:`ValidationIssue` so the rest of the middleware — error
5
+ formatting, retry — works uniformly and never imports ``jsonschema`` unless a
6
+ dict-schema tool is actually present.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any, Protocol
13
+
14
+ from langchain_core.tools import BaseTool
15
+ from pydantic import BaseModel
16
+
17
+
18
@dataclass(frozen=True)
class ValidationIssue:
    """One validation problem, independent of the backend that reported it."""

    # Location of the offending value as a sequence of keys/indices;
    # an empty list means the problem is on the argument object itself.
    path: list[Any]
    # Human-readable description supplied by the validation backend.
    message: str

    def render(self) -> str:
        """Format the issue as a single bullet line for the error report."""
        location = " → ".join(map(str, self.path))
        if not location:
            location = "(root)"
        return f" • [{location}] {self.message}"
28
+
29
+
30
class _JsonSchemaValidator(Protocol):
    """Structural type for the single validator method this module calls."""

    def iter_errors(self, instance: Any) -> Any:  # pragma: no cover - structural
        """Produce the validation errors found in *instance*."""
33
+
34
+
35
class ToolValidator:
    """Argument validator for a single tool.

    Holds either a Pydantic model or a JSON Schema validator instance and
    dispatches to the matching backend. With neither set, every argument dict
    is accepted.
    """

    __slots__ = ("name", "_pydantic_model", "_json_validator")

    def __init__(
        self,
        name: str,
        *,
        pydantic_model: type[BaseModel] | None = None,
        json_validator: _JsonSchemaValidator | None = None,
    ) -> None:
        """Store the tool *name* and whichever backend was supplied."""
        self.name = name
        self._pydantic_model = pydantic_model
        self._json_validator = json_validator

    def validate(self, args: dict[str, Any]) -> list[ValidationIssue]:
        """Return the issues found in *args* (empty list when valid).

        The Pydantic model takes precedence when both backends are set.
        """
        model = self._pydantic_model
        if model is not None:
            return _validate_pydantic(model, args)

        json_validator = self._json_validator
        if json_validator is None:
            # No backend configured: nothing to check.
            return []
        return _validate_json_schema(json_validator, args)
57
+
58
+
59
def resolve_validators(
    tools: list[BaseTool],
    *,
    json_schema_validator_class: type | None,
) -> dict[str, ToolValidator]:
    """Map each tool name to a :class:`ToolValidator` built from its schema.

    A ``dict`` ``args_schema`` is wrapped in a JSON Schema validator; a
    Pydantic ``BaseModel`` subclass is validated with ``model_validate``. Tools
    whose ``args_schema`` is anything else get no entry and therefore pass
    through unvalidated.

    When *json_schema_validator_class* is ``None`` the default class is looked
    up on first need, so ``jsonschema`` is imported only if some tool actually
    has a dict schema.
    """
    resolved: dict[str, ToolValidator] = {}
    json_cls = json_schema_validator_class

    for tool in tools:
        schema = getattr(tool, "args_schema", None)

        if isinstance(schema, dict):
            if json_cls is None:
                # Looked up once, then reused for every later dict-schema tool.
                json_cls = _default_json_validator_class()
            resolved[tool.name] = ToolValidator(
                tool.name, json_validator=json_cls(schema)
            )
            continue

        if isinstance(schema, type) and issubclass(schema, BaseModel):
            resolved[tool.name] = ToolValidator(tool.name, pydantic_model=schema)
        # Any other schema shape: no validator, the tool passes through.

    return resolved
87
+
88
+
89
def _default_json_validator_class() -> type:
    """Import and return ``jsonschema.Draft7Validator``.

    Raises
    ------
    ImportError
        If ``jsonschema`` cannot be imported; the message explains how to
        install the optional extra or supply a custom validator class.
    """
    try:
        from jsonschema import Draft7Validator
    except ImportError as exc:  # pragma: no cover - import guard
        message = (
            "A tool with a JSON Schema (dict) args_schema was provided, but "
            "'jsonschema' is not installed. Install it with "
            "`pip install langchain-tool-args-validation-middleware[jsonschema]`, or pass a "
            "custom `json_schema_validator_class`."
        )
        raise ImportError(message) from exc
    else:
        return Draft7Validator  # type: ignore[no-any-return]
100
+
101
+
102
def _validate_pydantic(
    model: type[BaseModel], args: dict[str, Any]
) -> list[ValidationIssue]:
    """Validate *args* with ``model.model_validate`` and normalise the errors.

    Returns an empty list when validation succeeds; otherwise one
    :class:`ValidationIssue` per entry of ``ValidationError.errors()``, using
    the entry's ``loc`` as the path and ``msg`` as the message.
    """
    # Resolved at call time rather than at module import.
    from pydantic import ValidationError

    try:
        model.model_validate(args)
    except ValidationError as exc:
        return [
            ValidationIssue(path=list(detail["loc"]), message=str(detail["msg"]))
            for detail in exc.errors()
        ]
    return []
116
+
117
+
118
def _validate_json_schema(
    validator: _JsonSchemaValidator, args: dict[str, Any]
) -> list[ValidationIssue]:
    """Run *validator* over *args* and normalise every reported error.

    Each error's ``absolute_path`` becomes the issue path and its ``message``
    the issue message. An empty list means *args* is valid.
    """
    return [
        ValidationIssue(path=list(error.absolute_path), message=error.message)
        for error in validator.iter_errors(args)
    ]
127
+
128
+
129
def format_issues(tool_name: str, issues: list[ValidationIssue]) -> str:
    """Assemble the error report sent back to the model for one tool call.

    The report is a header naming the tool, one rendered line per issue, and
    a closing hint about omitting unused optional fields, joined by newlines.
    """
    header = (
        f"Tool '{tool_name}' argument validation failed. "
        "Fix the following errors and retry:"
    )
    # The leading newline leaves a blank line between the issues and the hint.
    hint = (
        "\nHint: if a field is optional and not needed, omit it entirely from "
        "the arguments rather than setting it to null or an empty value."
    )
    rendered = [issue.render() for issue in issues]
    return "\n".join([header, *rendered, hint])
@@ -0,0 +1,44 @@
1
+ """Optional, pluggable extra validators for :data:`ExtraValidator`.
2
+
3
+ These are *not* part of the core middleware behaviour — they encode
4
+ domain-specific heuristics that some users want and others don't. Opt in by
5
+ passing them via ``extra_validators=[...]``.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import re
11
+ from typing import Any
12
+
13
+ # LangChain internal message IDs (``lc_<uuid4>``). LLMs sometimes lift these out
14
+ # of a ToolMessage envelope and pass them as real data identifiers.
15
_LANGCHAIN_ID_RE = re.compile(
    r"^lc_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


def detect_langchain_internal_ids(name: str, args: dict[str, Any]) -> list[str]:
    """Flag any arg value that looks like a leaked ``lc_<uuid>`` internal ID.

    The whole argument structure is searched: strings at the top level, inside
    lists, and inside nested dicts at any depth. A match inside a nested dict
    is reported under a dotted key path (``outer.inner``); a match reached
    through a list is reported with an "in a list" suffix.

    Use via ``extra_validators=[detect_langchain_internal_ids]``.

    Parameters
    ----------
    name:
        Tool name, used only in the error text.
    args:
        The tool call's arguments.

    Returns
    -------
    list[str]
        One error message per offending value, in traversal order; empty when
        nothing matched.
    """
    errors: list[str] = []

    def walk(key: str, value: Any, *, in_list: bool) -> None:
        if isinstance(value, str):
            if _LANGCHAIN_ID_RE.match(value):
                where = " in a list" if in_list else ""
                errors.append(
                    f"Tool '{name}' argument '{key}' contains a LangChain internal "
                    f"ID ('{value}'){where}. This is NOT a valid data identifier — "
                    "use only real resource IDs from the API response data, not IDs "
                    "from the tool-call metadata envelope."
                )
        elif isinstance(value, dict):
            # Extend the reported key so the model can locate the nested field.
            for sub_key, sub_value in value.items():
                walk(f"{key}.{sub_key}", sub_value, in_list=in_list)
        elif isinstance(value, list):
            for item in value:
                walk(key, item, in_list=True)

    for key, value in args.items():
        walk(key, value, in_list=False)
    return errors
@@ -0,0 +1,318 @@
1
+ """``ToolArgsValidationMiddleware`` — validate LLM tool-call args before execution.
2
+
3
+ The middleware wraps the model invocation (``wrap_model_call`` /
4
+ ``awrap_model_call``). After each model response it validates every tool call's
5
+ arguments against the tool's schema. On failure it appends error
6
+ ``ToolMessage``\\s and re-invokes the model so it can self-correct. The retry
7
+ loop runs entirely inside the model node, so only the final ``AIMessage`` enters
8
+ the graph state — and any human-in-the-loop step that runs *after* the model
9
+ node never sees invalid arguments, unless retries run out with ``on_failure="pass"``.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ from collections.abc import Awaitable, Callable, Generator
16
+ from typing import Any, Literal, cast
17
+
18
+ from langchain.agents.middleware import AgentMiddleware
19
+ from langchain.agents.middleware.types import ModelRequest, ModelResponse
20
+ from langchain_core.messages import AIMessage, AnyMessage, ToolMessage
21
+ from langchain_core.tools import BaseTool
22
+
23
+ from ._strip import DEFAULT_PLACEHOLDER_STRINGS, strip_empty
24
+ from ._validation import (
25
+ ToolValidator,
26
+ format_issues,
27
+ resolve_validators,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # A user-supplied extra check: given (tool_name, args) return a list of error
33
+ # strings (empty = no problem). Lets callers plug in domain rules (e.g. catching
34
+ # leaked internal IDs) without bloating the core.
35
+ ExtraValidator = Callable[[str, "dict[str, Any]"], "list[str]"]
36
+
37
+ OnFailure = Literal["pass", "raise"]
38
+
39
+
40
class ToolArgsValidationError(RuntimeError):
    """Tool-call arguments were still invalid once validation retries ran out.

    Only raised when the middleware is configured with ``on_failure='raise'``.
    """
42
+
43
+
44
# Reply attached to every *valid* tool call in a batch that contains at least
# one invalid call, so that each tool call still receives a response.
_BATCH_SIBLING_NOTICE = (
    "This tool call was not executed because another tool call "
    "in the same batch failed argument validation. "
    "Re-issue all tool calls together with corrected arguments."
)
49
+
50
+
51
class ToolArgsValidationMiddleware(AgentMiddleware):
    """Validate tool-call arguments against each tool's schema, with retry.

    Parameters
    ----------
    tools:
        Optional explicit tool list. If omitted (the default), schemas are
        resolved lazily from ``request.tools`` on each call and cached by the
        set of tool names — so dynamic toolsets (tools added/removed by other
        middleware) stay correct rather than going stale.
    max_retries:
        Number of validation-retry cycles per model invocation (default ``2``).
        Up to ``max_retries + 1`` model calls may be made. Must be ``>= 0``.
    strip_empty_values:
        If ``True`` (default), recursively remove keys whose value is ``None``,
        ``{}`` or ``[]`` before validation. The cleaned args are written back
        onto the tool call, so they are also what the tool executes. See
        :mod:`._strip` for the write-back contract and its caveats.
    strip_placeholder_strings:
        If ``True``, also strip string values that look like empty placeholders
        (e.g. ``"null"``, ``"none"``). **Off by default** because tokens like
        ``"NA"`` are legitimate values. Combine with ``placeholder_strings`` to
        control the set. Has no effect unless ``strip_empty_values`` is ``True``.
    placeholder_strings:
        The set used when ``strip_placeholder_strings`` is enabled. Defaults to
        a conservative built-in set.
    json_schema_validator_class:
        Validator class for dict-schema (MCP) tools. ``None`` (default) lazily
        imports ``jsonschema.Draft7Validator``.
    extra_validators:
        Optional extra per-tool-call checks (see :data:`ExtraValidator`). They
        run on every tool call, including calls to tools without a schema.
    on_failure:
        What to do after retries are exhausted with the args still invalid:
        ``"pass"`` (default) returns the last response unchanged (fail open —
        downstream tool error handling takes over); ``"raise"`` raises
        :class:`ToolArgsValidationError`.

    Raises
    ------
    ValueError
        If ``max_retries`` is negative.
    """

    def __init__(
        self,
        *,
        tools: list[BaseTool] | None = None,
        max_retries: int = 2,
        strip_empty_values: bool = True,
        strip_placeholder_strings: bool = False,
        placeholder_strings: frozenset[str] = DEFAULT_PLACEHOLDER_STRINGS,
        json_schema_validator_class: type | None = None,
        extra_validators: list[ExtraValidator] | None = None,
        on_failure: OnFailure = "pass",
    ) -> None:
        # A negative budget would skip the validation loop entirely and treat
        # even a valid first response as "retries exhausted".
        if max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {max_retries!r}")

        super().__init__()
        self._max_retries = max_retries
        self._strip_empty_values = strip_empty_values
        # ``None`` disables string stripping inside ``strip_empty``.
        self._placeholder_strings = (
            placeholder_strings if strip_placeholder_strings else None
        )
        self._json_schema_validator_class = json_schema_validator_class
        self._extra_validators = extra_validators or []
        self._on_failure = on_failure

        # Cache of {frozenset(tool names) -> {tool name -> ToolValidator}}.
        self._cache: dict[frozenset[str], dict[str, ToolValidator]] = {}
        self._explicit: dict[str, ToolValidator] | None = (
            resolve_validators(
                tools, json_schema_validator_class=json_schema_validator_class
            )
            if tools is not None
            else None
        )

    # ------------------------------------------------------------------ #
    # Hooks
    # ------------------------------------------------------------------ #

    def wrap_model_call(
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
    ) -> ModelResponse[Any]:
        """Call the model via *handler*, validating and retrying tool calls."""
        validators = self._resolve(request)
        if self._nothing_to_check(validators):
            return handler(request)

        # Drive the shared validate/retry generator with a synchronous handler.
        loop = self._validate_loop(validators, request, handler(request))
        try:
            retry_request = next(loop)
            while True:
                retry_request = loop.send(handler(retry_request))
        except StopIteration as stop:
            # The generator's return value is the response to surface.
            return cast(ModelResponse[Any], stop.value)

    async def awrap_model_call(
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
    ) -> ModelResponse[Any]:
        """Async counterpart of :meth:`wrap_model_call`."""
        validators = self._resolve(request)
        if self._nothing_to_check(validators):
            return await handler(request)

        # Same generator, driven with an async handler.
        loop = self._validate_loop(validators, request, await handler(request))
        try:
            retry_request = next(loop)
            while True:
                retry_request = loop.send(await handler(retry_request))
        except StopIteration as stop:
            return cast(ModelResponse[Any], stop.value)

    # ------------------------------------------------------------------ #
    # Core logic (single source of truth, shared by sync + async)
    # ------------------------------------------------------------------ #

    def _nothing_to_check(self, validators: dict[str, ToolValidator]) -> bool:
        """Whether the model call can bypass validation entirely.

        Extra validators run even for tools without a schema, so the bypass is
        only taken when there are neither schema validators nor extra ones.
        """
        return not validators and not self._extra_validators

    def _validate_loop(
        self,
        validators: dict[str, ToolValidator],
        request: ModelRequest[Any],
        first_response: ModelResponse[Any],
    ) -> Generator[ModelRequest[Any], ModelResponse[Any], ModelResponse[Any]]:
        """Validate-and-retry as a sans-I/O generator.

        Yields the request to re-run and receives the resulting response back via
        ``send``; returns the response to surface to the model node. Every
        response — including the one from the final retry — is validated before
        deciding whether retries are exhausted.
        """
        convo: list[AnyMessage] = list(request.messages)
        response = first_response
        for attempt in range(self._max_retries + 1):
            ai_msg, errors = self._check(validators, response)
            if not errors:
                return response
            if attempt == self._max_retries:
                return self._exhausted(response)
            # errors is non-empty only after validating tool calls on an AIMessage.
            assert ai_msg is not None
            self._log_retry(attempt + 1)
            # Failed turns accumulate: the rejected AI message followed by the
            # ToolMessages answering each of its tool calls.
            convo = [*convo, ai_msg, *errors]
            response = yield request.override(messages=convo)
        return self._exhausted(response)  # unreachable; keeps types total

    def _resolve(self, request: ModelRequest[Any]) -> dict[str, ToolValidator]:
        """Return the validators for this request (explicit list or cached).

        Entries of ``request.tools`` without a string ``name`` attribute (for
        example tools supplied as plain dicts) are ignored: they carry no
        schema this middleware can validate against.
        """
        if self._explicit is not None:
            return self._explicit
        candidates = getattr(request, "tools", None) or []
        tools: list[BaseTool] = [
            t for t in candidates if isinstance(getattr(t, "name", None), str)
        ]
        # NOTE: keyed by name only — a tool replaced by a same-named tool with
        # a different schema reuses the previously cached validator.
        key = frozenset(t.name for t in tools)
        cached = self._cache.get(key)
        if cached is None:
            cached = resolve_validators(
                tools, json_schema_validator_class=self._json_schema_validator_class
            )
            self._cache[key] = cached
        return cached

    def _check(
        self, validators: dict[str, ToolValidator], response: ModelResponse[Any]
    ) -> tuple[AIMessage | None, list[ToolMessage]]:
        """Validate the response's tool calls.

        Returns ``(ai_msg, error_messages)``. ``error_messages`` is empty when
        there is nothing to retry (no AI message, no tool calls, or all valid).
        """
        ai_msg = _get_ai_message(response)
        if ai_msg is None or not ai_msg.tool_calls:
            return ai_msg, []
        return ai_msg, self._validate_tool_calls(validators, ai_msg)

    def _validate_tool_calls(
        self, validators: dict[str, ToolValidator], ai_msg: AIMessage
    ) -> list[ToolMessage]:
        """Validate every tool call in *ai_msg*; return error ``ToolMessage`` objects.

        Batch contract: if *any* tool call is invalid, *every* tool call in the
        message gets a ``ToolMessage`` (errors for the bad ones, a "not executed"
        notice for the good ones). Providers require each ``tool_call`` to have a
        matching response, and the good calls have not actually run yet (we are
        inside the model node), so they cannot get real results. Returns an empty
        list only when all tool calls are valid.

        Side effect: when stripping is enabled, each tool call's ``args`` is
        replaced in place with the cleaned arguments.
        """
        error_msgs: list[ToolMessage] = []
        valid_ids: list[str] = []

        for tc in ai_msg.tool_calls:
            name = tc.get("name") or ""
            call_id = tc.get("id") or ""
            args = tc.get("args") or {}

            if self._strip_empty_values:
                args = strip_empty(args, placeholder_strings=self._placeholder_strings)
                tc["args"] = args  # write-back: cleaned args are what executes

            errors = self._errors_for_call(validators, name, args)
            if errors:
                error_msgs.append(ToolMessage(content=errors, tool_call_id=call_id))
            else:
                valid_ids.append(call_id)

        if not error_msgs:
            return []

        # Every sibling tool call needs a response too (provider requirement).
        error_msgs.extend(
            ToolMessage(content=_BATCH_SIBLING_NOTICE, tool_call_id=cid)
            for cid in valid_ids
        )
        return error_msgs

    def _errors_for_call(
        self, validators: dict[str, ToolValidator], name: str, args: dict[str, Any]
    ) -> str:
        """Return a formatted error string for one tool call, or ``""`` if valid.

        Tools without a validator skip the schema check but still go through
        ``extra_validators``, which may carry rules that don't depend on a
        registered schema. Schema issues and extra findings are combined into
        one message, and every failure is logged.
        """
        validator = validators.get(name)
        issues = validator.validate(args) if validator is not None else []
        extra = self._run_extra_validators(name, args)

        if not issues and not extra:
            return ""

        parts: list[str] = []
        if issues:
            parts.append(format_issues(name, issues))
        parts.extend(extra)
        logger.warning("Validation failed for tool '%s': %s", name, "; ".join(parts))
        return "\n".join(parts)

    def _run_extra_validators(self, name: str, args: dict[str, Any]) -> list[str]:
        """Run every configured extra validator and concatenate their errors."""
        out: list[str] = []
        for check in self._extra_validators:
            out.extend(check(name, args))
        return out

    def _log_retry(self, attempt: int) -> None:
        """Log that retry number *attempt* (1-based) is about to start."""
        logger.warning(
            "Tool-arg validation failed (attempt %d/%d); re-invoking model",
            attempt,
            self._max_retries,
        )

    def _exhausted(self, response: ModelResponse[Any]) -> ModelResponse[Any]:
        """Apply the ``on_failure`` policy once the retry budget is spent.

        Raises :class:`ToolArgsValidationError` when ``on_failure`` is
        ``"raise"``; otherwise logs a warning and returns *response* unchanged.
        """
        if self._on_failure == "raise":
            raise ToolArgsValidationError(
                f"Tool-call arguments still invalid after {self._max_retries} "
                "validation retries."
            )
        logger.warning(
            "Tool-arg validation retries exhausted (%d); passing response through",
            self._max_retries,
        )
        return response
312
+
313
def _get_ai_message(response: ModelResponse[Any]) -> AIMessage | None:
    """Return the first ``AIMessage`` in ``response.result``, or ``None``.

    A response without a ``result`` attribute, or with an empty/``None``
    result, yields ``None``.
    """
    candidates = getattr(response, "result", None) or []
    return next(
        (message for message in candidates if isinstance(message, AIMessage)),
        None,
    )
File without changes
@@ -0,0 +1,155 @@
1
+ Metadata-Version: 2.4
2
+ Name: langchain-tool-args-validation-middleware
3
+ Version: 0.1.0
4
+ Summary: LangChain agent middleware that validates LLM-generated tool-call arguments against each tool's schema before tool execution / HITL.
5
+ Project-URL: Homepage, https://github.com/Serjbory/langchain-tool-args-validation-middleware
6
+ Project-URL: Repository, https://github.com/Serjbory/langchain-tool-args-validation-middleware
7
+ Author: Serj
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Keywords: agents,langchain,mcp,middleware,tools,validation
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Topic :: Software Development :: Libraries
14
+ Requires-Python: >=3.10
15
+ Requires-Dist: langchain-core>=0.3.0
16
+ Requires-Dist: langchain>=1.0.0
17
+ Requires-Dist: pydantic>=2.0
18
+ Provides-Extra: dev
19
+ Requires-Dist: jsonschema>=4.0; extra == 'dev'
20
+ Requires-Dist: mypy; extra == 'dev'
21
+ Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
22
+ Requires-Dist: pytest-cov>=5.0; extra == 'dev'
23
+ Requires-Dist: pytest>=8.0; extra == 'dev'
24
+ Requires-Dist: ruff; extra == 'dev'
25
+ Provides-Extra: jsonschema
26
+ Requires-Dist: jsonschema>=4.0; extra == 'jsonschema'
27
+ Provides-Extra: test
28
+ Requires-Dist: jsonschema>=4.0; extra == 'test'
29
+ Requires-Dist: pytest-asyncio>=0.23; extra == 'test'
30
+ Requires-Dist: pytest-cov>=5.0; extra == 'test'
31
+ Requires-Dist: pytest>=8.0; extra == 'test'
32
+ Description-Content-Type: text/markdown
33
+
34
+ # langchain-tool-args-validation-middleware
35
+
36
+ A LangChain agent middleware that validates LLM-generated **tool-call arguments**
37
+ against each tool's schema **before** the tool runs (and before any
38
+ human-in-the-loop approval step). When arguments are invalid it appends error
39
+ `ToolMessage`s and re-invokes the model so it can self-correct — all inside the
40
+ model node, so only the final `AIMessage` (valid, unless retries are exhausted) ever enters the graph state.
41
+
42
+ ```bash
43
+ pip install langchain-tool-args-validation-middleware # Pydantic tools only
44
+ pip install "langchain-tool-args-validation-middleware[jsonschema]" # + MCP / dict-schema tools
45
+ ```
46
+
47
+ ## Why
48
+
49
+ LLMs frequently emit malformed tool calls: missing required fields, wrong types,
50
+ hallucinated empty values, or extra keys. Without validation those reach the
51
+ tool node and cause runtime errors or silent corruption — and in
52
+ human-in-the-loop workflows, a human is asked to approve obviously-broken
53
+ arguments. Catching this at the model boundary lets the agent fix itself in one
54
+ extra model call instead of a full agent-loop iteration.
55
+
56
+ It complements, rather than replaces, `ToolRetryMiddleware` (retries on tool
57
+ *exceptions*) and `ModelRetryMiddleware` (retries on model *exceptions*): this
58
+ one retries on *schema violations*, before execution.
59
+
60
+ ![Trace showing the middleware catching an invalid tool call and prompting the model to self-correct](https://raw.githubusercontent.com/Serjbory/langchain-tool-args-validation-middleware/main/docs/images/trace-example.jpg)
61
+
62
+ *A trace of `create_oos_alert`: the model emitted arguments that violate the
63
+ schema, the middleware rejected them with a precise error and a corrective hint,
64
+ and the model retried — all inside the model node, before the tool ran.*
65
+
66
+ ## Usage
67
+
68
+ ```python
69
+ from langchain.agents import create_agent
70
+ from langchain_tool_args_validation_middleware import ToolArgsValidationMiddleware
71
+
72
+ agent = create_agent(
73
+ model,
74
+ tools=tools,
75
+ middleware=[ToolArgsValidationMiddleware()], # resolves schemas from the agent's tools
76
+ )
77
+ ```
78
+
79
+ Both validation paths are supported automatically:
80
+
81
+ - **Pydantic tools** (`@tool`, or any tool with a `BaseModel` `args_schema`) →
82
+ validated with `BaseModel.model_validate`.
83
+ - **MCP / dict-schema tools** (`args_schema` is a raw JSON Schema `dict`) →
84
+ validated with `jsonschema` (soft dependency, `Draft7Validator` by default).
85
+
86
+ Unknown tools (no resolvable schema) pass through unvalidated.
87
+
88
+ ## Configuration
89
+
90
+ | Parameter | Default | Description |
91
+ |---|---|---|
92
+ | `tools` | `None` | Explicit tool list. If omitted, schemas are resolved lazily from `request.tools` and cached by tool-name set (handles dynamic toolsets). |
93
+ | `max_retries` | `2` | Validation-retry cycles per model invocation (up to `max_retries + 1` model calls). |
94
+ | `strip_empty_values` | `True` | Recursively drop `None` / `{}` / `[]` before validation; the cleaned args are written back onto the tool call (see below). |
95
+ | `strip_placeholder_strings` | `False` | Also drop placeholder strings like `"null"`. Off by default — see below. |
96
+ | `placeholder_strings` | conservative set | Set used when string stripping is enabled. |
97
+ | `json_schema_validator_class` | `None` | Override the JSON Schema validator class. `None` → lazy `Draft7Validator`. |
98
+ | `extra_validators` | `None` | Extra `(name, args) -> list[str]` checks for domain rules. |
99
+ | `on_failure` | `"pass"` | After retries are exhausted: `"pass"` (fail open) or `"raise"`. |
100
+
101
+ ## Design decisions for the two thorniest cases
102
+
103
+ ### Batch (partial) failure
104
+
105
+ Providers (Anthropic, Gemini, OpenAI) require that **every** `tool_call` in an
106
+ assistant message receive a matching `ToolMessage` before the next turn. So when
107
+ a multi-call turn has *any* invalid call, the middleware emits:
108
+
109
+ - an **error** `ToolMessage` for each invalid call, and
110
+ - a **"not executed"** notice for each *valid* sibling call (it hasn't run yet —
111
+ we're still inside the model node — so it can't have a real result), asking the
112
+ model to re-issue the whole batch with corrected arguments.
113
+
114
+ The failed `AIMessage` is placed before these `ToolMessage`s, and failed turns
115
+ accumulate across retries so the model sees its repeated mistakes.
116
+
117
+ ### `strip_empty_values` and the write-back contract
118
+
119
+ LLMs (Gemini especially) emit explicit `null`/`{}`/`[]` for optional fields
120
+ instead of omitting them, causing needless validation failures. When stripping
121
+ is on, the **cleaned arguments replace the originals on the tool call**, so what
122
+ we validate is exactly what executes — no soundness gap between validation and
123
+ execution.
124
+
125
+ The trade-off: stripping a value that is *meaningfully empty* (e.g. `tags: []`
126
+ meaning "clear all tags", or `null` meaning "explicitly unset") changes
127
+ behaviour. Container stripping (`None`/`{}`/`[]`) is on by default because it's
128
+ usually safe. **String-placeholder stripping is opt-in only** — tokens like
129
+ `"NA"` (Namibia's ISO code) are legitimate values and must never be dropped
130
+ silently. Enable it deliberately with `strip_placeholder_strings=True` and a set
131
+ you control.
132
+
133
+ ### Fail-open
134
+
135
+ After `max_retries` retries are exhausted, the default `on_failure="pass"` returns the last response
136
+ unchanged — the (still-invalid) args reach the tool node, where normal tool
137
+ error handling takes over. This makes the middleware best-effort
138
+ self-correction, not a hard guarantee. Use `on_failure="raise"` if you'd rather
139
+ surface a `ToolArgsValidationError`.
140
+
141
+ ## Extra validators
142
+
143
+ Plug in domain rules without touching core behaviour. A bundled example flags
144
+ LangChain internal message IDs (`lc_<uuid>`) that LLMs sometimes mistake for
145
+ real data identifiers:
146
+
147
+ ```python
148
+ from langchain_tool_args_validation_middleware import detect_langchain_internal_ids
149
+
150
+ ToolArgsValidationMiddleware(extra_validators=[detect_langchain_internal_ids])
151
+ ```
152
+
153
+ ## License
154
+
155
+ MIT
@@ -0,0 +1,10 @@
1
+ langchain_tool_args_validation_middleware/__init__.py,sha256=9RJcLk_PWeOVdsZg-qXm0b4saxOcfb62JeCetxpj3ws,621
2
+ langchain_tool_args_validation_middleware/_strip.py,sha256=CcIqt0K73x4jYjMeHg7YHwuicek39UcFMFD19uzEGho,3285
3
+ langchain_tool_args_validation_middleware/_validation.py,sha256=ylL2CD-6STAjVC50X2yZ8Q-9TmJhEiHeWF11Rlgzlcg,4829
4
+ langchain_tool_args_validation_middleware/extras.py,sha256=FAM_XyFLbZQ-5WPxid0k6BE57lJSRy-SIZxG1sQHTfg,1629
5
+ langchain_tool_args_validation_middleware/middleware.py,sha256=ezcKaStT3lvChXSHdnR_K3yWUsEgsI5P7s5NDI2cGBI,13129
6
+ langchain_tool_args_validation_middleware/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ langchain_tool_args_validation_middleware-0.1.0.dist-info/METADATA,sha256=6wE1povvfpTFqOLUwMl5rE34nTRvsINAiBV7e9_jv4U,7166
8
+ langchain_tool_args_validation_middleware-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
9
+ langchain_tool_args_validation_middleware-0.1.0.dist-info/licenses/LICENSE,sha256=-qbRwFG05BhSnZR2O8BvzMqyUjiU_lIMPnyp1pUuvms,1061
10
+ langchain_tool_args_validation_middleware-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Serj
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.