digmind-sdk 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,61 @@
1
+ Metadata-Version: 2.4
2
+ Name: digmind-sdk
3
+ Version: 0.1.0
4
+ Summary: DigMind Function Tool SDK for Python
5
+ Author: DigMind Team
6
+ License: MIT
7
+ Project-URL: Homepage, https://digmind.ai
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Programming Language :: Python :: 3
11
+ Requires-Python: >=3.11
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: httpx>=0.28.0
14
+ Requires-Dist: pydantic>=2.0.0
15
+
16
+ # DigMind SDK
17
+
18
+ **DigMind SDK** is the official Python library for developing **Function Tools** and orchestrating agent workflows within the [DigMind Platform](https://digmind.ai).
19
+
20
+ ## Installation
21
+
22
+ ```bash
23
+ pip install digmind-sdk
24
+ ```
25
+
26
+ ## Creating a Function Tool
27
+
28
+ Function Tools allow you to construct composite tools locally using Python and connect them securely to DigMind's execution engine via the `ToolBridge`.
29
+
30
+ ### Usage Example
31
+
32
+ ```python
33
+ from pydantic import BaseModel
34
+ from digmind_sdk import FunctionToolContext, tool_model
35
+
36
+ # 1. Define the Pydantic schema and bind it to a platform Tool Reference
37
+ @tool_model("rss.list_feeds")
38
+ class ListFeeds(BaseModel):
39
+ source_ids: list[str]
40
+
41
+ # 2. Implement the entry point function
42
+ def run(ctx: FunctionToolContext) -> dict:
43
+ # Safely call the platform tools via context
44
+ ctx.progress(0.1, "Fetching RSS feeds...")
45
+ feeds = ctx.call(ListFeeds(source_ids=["tech-blog"]))
46
+
47
+ # Process results locally...
48
+ ctx.progress(1.0, "Complete!")
49
+ return {"status": "success", "data": feeds}
50
+ ```
51
+
52
+ ## Features
53
+
54
+ - **Pydantic Validation**: Tool arguments are validated by Pydantic when the model is constructed, so invalid input fails before any network request is made.
55
+ - **Progress Tracking**: Send real-time partial completion updates back to the DigMind web UI.
56
+ - **Parallel Dispatch**: Run multiple tool calls simultaneously (`ctx.parallel_call`) to drastically reduce overall execution latency.
57
+
58
+ ## Requirements
59
+
60
+ - Python 3.11+
61
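+ - A reachable DigMind platform instance. The SDK reads the ToolBridge URL and API key from the `DIGMIND_API_URL` and `DIGMIND_API_KEY` environment variables (defaulting to `http://localhost:8000` and an empty key); these are normally provided automatically when the tool runs inside a DigMind Sandbox environment.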
@@ -0,0 +1,46 @@
1
+ # DigMind SDK
2
+
3
+ **DigMind SDK** is the official Python library for developing **Function Tools** and orchestrating agent workflows within the [DigMind Platform](https://digmind.ai).
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install digmind-sdk
9
+ ```
10
+
11
+ ## Creating a Function Tool
12
+
13
+ Function Tools allow you to construct composite tools locally using Python and connect them securely to DigMind's execution engine via the `ToolBridge`.
14
+
15
+ ### Usage Example
16
+
17
+ ```python
18
+ from pydantic import BaseModel
19
+ from digmind_sdk import FunctionToolContext, tool_model
20
+
21
+ # 1. Define the Pydantic schema and bind it to a platform Tool Reference
22
+ @tool_model("rss.list_feeds")
23
+ class ListFeeds(BaseModel):
24
+ source_ids: list[str]
25
+
26
+ # 2. Implement the entry point function
27
+ def run(ctx: FunctionToolContext) -> dict:
28
+ # Safely call the platform tools via context
29
+ ctx.progress(0.1, "Fetching RSS feeds...")
30
+ feeds = ctx.call(ListFeeds(source_ids=["tech-blog"]))
31
+
32
+ # Process results locally...
33
+ ctx.progress(1.0, "Complete!")
34
+ return {"status": "success", "data": feeds}
35
+ ```
36
+
37
+ ## Features
38
+
39
+ - **Pydantic Validation**: Tool arguments are validated by Pydantic when the model is constructed, so invalid input fails before any network request is made.
40
+ - **Progress Tracking**: Send real-time partial completion updates back to the DigMind web UI.
41
+ - **Parallel Dispatch**: Run multiple tool calls simultaneously (`ctx.parallel_call`) to drastically reduce overall execution latency.
42
+
43
+ ## Requirements
44
+
45
+ - Python 3.11+
46
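+ - A reachable DigMind platform instance. The SDK reads the ToolBridge URL and API key from the `DIGMIND_API_URL` and `DIGMIND_API_KEY` environment variables (defaulting to `http://localhost:8000` and an empty key); these are normally provided automatically when the tool runs inside a DigMind Sandbox environment.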
@@ -0,0 +1,24 @@
1
"""
digmind_sdk — Function Tool SDK for DigMind Sandbox

Provides the runtime context and utilities for building Function Tools.
Function Tools are Python-based composite tools that an Agent can create
to orchestrate platform tool calls with custom logic.

Usage in main.py:
    from pydantic import BaseModel
    from digmind_sdk import FunctionToolContext, tool_model

    @tool_model("rss.list_feeds")
    class ListFeeds(BaseModel):
        source_ids: list[str]

    def run(ctx: FunctionToolContext) -> dict:
        feeds = ctx.call(ListFeeds(source_ids=["s1"]))
        return {"feeds": feeds}
"""
19
+
20
+ from digmind_sdk.context import FunctionToolContext
21
+ from digmind_sdk.decorators import tool_model
22
+
23
# Public API re-exported from the submodules imported above.
__all__ = [
    "FunctionToolContext",
    "tool_model",
]
# Keep in sync with ``version`` in pyproject.toml.
__version__ = "0.1.0"
@@ -0,0 +1,256 @@
1
+ """
2
+ FunctionToolContext — Runtime context for Function Tools.
3
+
4
+ Provides safe, validated tool invocation via ToolBridge HTTP API.
5
+ All tool calls are routed through the existing ``/api/v1/toolbridge/call``
6
+ endpoint, ensuring consistent auth, rate-limiting, and observability.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Optional

import httpx
from pydantic import BaseModel
20
+
21
# Package-level logger. FunctionToolContext instances log through
# per-tool child loggers ("digmind_sdk.fn.<tool_name>") instead.
logger = logging.getLogger("digmind_sdk")


class FunctionToolContext:
    """Runtime context injected into every Function Tool's ``run(ctx)`` entry.

    Attributes:
        inputs: Validated input dict (parsed from tool.yaml's input schema).
        logger: Logger scoped to this execution (``digmind_sdk.fn.<tool_name>``).

    :meth:`parallel_call` dispatches calls from a thread pool, so the call
    counter and call log are guarded by an internal lock.
    """

    def __init__(
        self,
        *,
        inputs: dict[str, Any],
        allowed_tools: list[str] | None = None,
        api_url: str | None = None,
        api_key: str | None = None,
        tool_name: str = "unknown",
        progress_callback: Any | None = None,
        timeout: int = 120,
        max_tool_calls: int = 50,
    ):
        """Create a context.

        Args:
            inputs: Input values for the function tool.
            allowed_tools: Whitelist of tool refs (tool.yaml ``required_tools``).
                ``None`` or an empty list disables the whitelist check.
            api_url: ToolBridge base URL. Falls back to ``DIGMIND_API_URL``,
                then ``http://localhost:8000``. Trailing slashes are removed.
            api_key: Bearer token. Falls back to ``DIGMIND_API_KEY`` (empty
                string if unset).
            tool_name: Name of the running function tool; scopes the logger.
            progress_callback: Optional ``callback(fraction, message)`` invoked
                by :meth:`progress`.
            timeout: Per-request HTTP timeout in seconds.
            max_tool_calls: Maximum number of tool calls this context will
                dispatch.
        """
        self.inputs = inputs
        self.logger = logging.getLogger(f"digmind_sdk.fn.{tool_name}")

        self._allowed_tools = set(allowed_tools) if allowed_tools else None
        base_url = api_url or os.environ.get("DIGMIND_API_URL", "http://localhost:8000")
        # Strip trailing slashes so the endpoint path never becomes "//api/...".
        self._api_url = base_url.rstrip("/")
        self._api_key = api_key or os.environ.get("DIGMIND_API_KEY", "")
        self._tool_name = tool_name
        self._progress_cb = progress_callback
        self._timeout = timeout
        self._max_tool_calls = max_tool_calls

        # Metrics (shared between parallel_call worker threads).
        self._lock = threading.Lock()
        self._call_count = 0
        self._call_log: list[dict[str, Any]] = []

    # ──────────────────────────────────────────────────────────────
    # Primary API: ctx.call(Model)
    # ──────────────────────────────────────────────────────────────

    def call(self, model: BaseModel) -> Any:
        """Invoke a tool via its Pydantic Model (recommended).

        The model class must be decorated with ``@tool_model("...")``
        to bind it to a tool reference. Pydantic validates the
        arguments at construction time; this method serialises and
        dispatches the call.

        Args:
            model: A ``@tool_model``-decorated Pydantic instance.

        Returns:
            Tool result (deserialized JSON).

        Raises:
            ValueError: If the model is not decorated with @tool_model.
            PermissionError: If the tool is not in the whitelist.
            RuntimeError: If ``max_tool_calls`` has been reached.
            ToolCallError: If the remote tool call fails.
        """
        tool_ref = getattr(model.__class__, "__tool_ref__", None)
        if not tool_ref:
            raise ValueError(
                f"{model.__class__.__name__} is not a @tool_model. "
                f"Decorate it with @tool_model('tool.ref') first."
            )
        # mode="json" converts datetimes, UUIDs, sets, bytes, ... into
        # JSON-compatible values so the HTTP payload can be encoded.
        # Arguments go through _invoke as a dict (not **kwargs) so a model
        # field named "tool_name" cannot collide with tool_call's parameter.
        return self._invoke(tool_ref, model.model_dump(mode="json", exclude_none=True))

    # ──────────────────────────────────────────────────────────────
    # Fallback API: ctx.tool_call(name, **kwargs)
    # ──────────────────────────────────────────────────────────────

    def tool_call(self, tool_name: str, **kwargs: Any) -> Any:
        """Invoke a tool by name (fallback for ad-hoc calls).

        Subject to the ``required_tools`` whitelist defined in
        ``tool.yaml``.

        Args:
            tool_name: Tool reference (e.g. ``"rss.list_feeds"``).
            **kwargs: Tool arguments.

        Returns:
            Tool result (deserialized JSON).

        Raises:
            PermissionError: If the tool is not in the whitelist.
            RuntimeError: If ``max_tool_calls`` has been reached.
            ToolCallError: If the remote tool call fails.
        """
        return self._invoke(tool_name, kwargs)

    # ──────────────────────────────────────────────────────────────
    # Parallel API: ctx.parallel_call([Model, Model, ...])
    # ──────────────────────────────────────────────────────────────

    def parallel_call(
        self,
        models: list[BaseModel],
        *,
        max_concurrency: int = 5,
    ) -> list[Any]:
        """Execute multiple tool calls in parallel.

        Each model in the list must be a ``@tool_model``-decorated
        Pydantic instance. Results are returned in the same order
        as the input list.

        Args:
            models: List of tool model instances.
            max_concurrency: Max concurrent HTTP requests.

        Returns:
            Ordered list of results. A call that raises is replaced by
            ``{"error": str(exc)}`` at its position instead of propagating.
        """
        if not models:
            return []

        results: list[Any] = [None] * len(models)
        with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
            future_to_idx = {
                pool.submit(self.call, m): i for i, m in enumerate(models)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    self.logger.warning("Parallel call [%d] failed: %s", idx, e)
                    results[idx] = {"error": str(e)}
        return results

    # ──────────────────────────────────────────────────────────────
    # Progress reporting
    # ──────────────────────────────────────────────────────────────

    def progress(self, fraction: float, message: str = "") -> None:
        """Report execution progress (0.0 ~ 1.0).

        The fraction is clamped to ``[0.0, 1.0]``, logged at INFO level and
        forwarded to the ``progress_callback`` (if any). Callback failures
        are logged at DEBUG level and otherwise ignored: progress reporting
        is best-effort.

        Args:
            fraction: Progress fraction (0.0 to 1.0).
            message: Optional human-readable status message.
        """
        clamped = max(0.0, min(1.0, fraction))
        self.logger.info("[%.0f%%] %s", clamped * 100, message)
        if self._progress_cb:
            try:
                self._progress_cb(clamped, message)
            except Exception:
                self.logger.debug("Progress callback failed", exc_info=True)

    # ──────────────────────────────────────────────────────────────
    # Metrics
    # ──────────────────────────────────────────────────────────────

    def get_metrics(self) -> dict[str, Any]:
        """Return execution metrics for benchmarking / diagnostics.

        ``call_log`` is a snapshot copy; mutating it does not affect the
        context.
        """
        with self._lock:
            call_log = [dict(entry) for entry in self._call_log]
            total_calls = self._call_count
        by_tool: dict[str, int] = {}
        for entry in call_log:
            by_tool[entry["tool"]] = by_tool.get(entry["tool"], 0) + 1
        return {
            "total_calls": total_calls,
            "total_ms": sum(entry["elapsed_ms"] for entry in call_log),
            "errors": sum(1 for entry in call_log if not entry["success"]),
            "by_tool": by_tool,
            "call_log": call_log,
        }

    # ──────────────────────────────────────────────────────────────
    # Internal: dispatch + HTTP transport
    # ──────────────────────────────────────────────────────────────

    def _invoke(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Whitelist-check, rate-limit, dispatch and log one tool call."""
        if self._allowed_tools and tool_name not in self._allowed_tools:
            raise PermissionError(
                f"Tool '{tool_name}' is not in required_tools whitelist. "
                f"Allowed tools: {sorted(self._allowed_tools)}"
            )

        # parallel_call runs this from several worker threads; the
        # check-and-increment must be atomic or the limit can be overshot.
        with self._lock:
            if self._call_count >= self._max_tool_calls:
                raise RuntimeError(
                    f"Max tool calls ({self._max_tool_calls}) exceeded. "
                    f"Increase runtime.max_tool_calls in tool.yaml if needed."
                )
            self._call_count += 1

        start = time.monotonic()
        try:
            result = self._http_call(tool_name, arguments)
        except Exception as exc:
            self._record_call(tool_name, start, error=exc)
            raise
        self._record_call(tool_name, start)
        return result

    def _record_call(
        self,
        tool_name: str,
        start: float,
        error: BaseException | None = None,
    ) -> None:
        """Append one timing entry to the call log under the lock."""
        entry: dict[str, Any] = {
            "tool": tool_name,
            "elapsed_ms": round((time.monotonic() - start) * 1000),
            "success": error is None,
        }
        if error is not None:
            entry["error"] = str(error)
        with self._lock:
            self._call_log.append(entry)

    def _http_call(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Execute a single tool call via ToolBridge HTTP API.

        Returns the ``"result"`` field of a JSON-object response when present,
        otherwise the decoded JSON body itself.

        Raises:
            ToolCallError: On any non-200 HTTP status.
        """
        url = f"{self._api_url}/api/v1/toolbridge/call"
        headers = {"Authorization": f"Bearer {self._api_key}"}
        payload = {
            "tool_name": tool_name,
            "arguments": arguments,
        }

        self.logger.debug("Calling %s → %s", tool_name, url)

        with httpx.Client(timeout=self._timeout) as client:
            resp = client.post(url, json=payload, headers=headers)

        if resp.status_code != 200:
            try:
                body = resp.json()
            except ValueError:  # body is not valid JSON
                body = None
            error_detail = (
                body.get("detail", resp.text) if isinstance(body, dict) else resp.text
            )
            raise ToolCallError(
                f"Tool '{tool_name}' failed (HTTP {resp.status_code}): {error_detail}"
            )

        data = resp.json()
        # Unwrap the {"result": ...} envelope; non-object JSON bodies
        # (lists, scalars) have no envelope and are returned unchanged.
        if isinstance(data, dict):
            return data.get("result", data)
        return data
252
+
253
+
254
class ToolCallError(Exception):
    """Signals that ToolBridge answered a tool call with a non-200 status.

    The message names the tool, the HTTP status code and the error detail
    reported by the server (or the raw response text).
    """
@@ -0,0 +1,32 @@
1
+ """
2
+ tool_model decorator — Binds a Pydantic BaseModel to a tool reference.
3
+
4
+ Usage:
5
+ from pydantic import BaseModel, Field
6
+ from digmind_sdk import tool_model
7
+
8
+ @tool_model("rss.list_feeds")
9
+ class ListFeeds(BaseModel):
10
+ source_ids: list[str] = Field(description="Source IDs to query")
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Any
16
+
17
+
18
def tool_model(tool_ref: str):
    """Bind a Pydantic Model class to a ToolBridge tool reference.

    The decorated class gains a ``__tool_ref__`` attribute that
    ``FunctionToolContext.call()`` uses to route the invocation; the class
    object itself is returned otherwise unchanged.

    Args:
        tool_ref: Fully-qualified tool reference, e.g. ``"rss.list_feeds"``.

    Returns:
        A class decorator.

    Raises:
        TypeError: If ``tool_ref`` is not a string — typically because the
            decorator was applied bare (``@tool_model``) instead of called
            (``@tool_model("tool.ref")``), which would otherwise silently
            replace the class with the inner decorator function.
        ValueError: If ``tool_ref`` is empty or whitespace-only. An empty
            ref would make ``call()`` report the class as not a @tool_model.
    """
    if not isinstance(tool_ref, str):
        raise TypeError(
            "tool_model() expects a tool reference string; "
            "use @tool_model('tool.ref'), not bare @tool_model."
        )
    if not tool_ref.strip():
        raise ValueError("tool_model() requires a non-empty tool reference.")

    def decorator(cls: Any) -> Any:
        cls.__tool_ref__ = tool_ref
        return cls

    return decorator
@@ -0,0 +1,61 @@
1
+ Metadata-Version: 2.4
2
+ Name: digmind-sdk
3
+ Version: 0.1.0
4
+ Summary: DigMind Function Tool SDK for Python
5
+ Author: DigMind Team
6
+ License: MIT
7
+ Project-URL: Homepage, https://digmind.ai
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Programming Language :: Python :: 3
11
+ Requires-Python: >=3.11
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: httpx>=0.28.0
14
+ Requires-Dist: pydantic>=2.0.0
15
+
16
+ # DigMind SDK
17
+
18
+ **DigMind SDK** is the official Python library for developing **Function Tools** and orchestrating agent workflows within the [DigMind Platform](https://digmind.ai).
19
+
20
+ ## Installation
21
+
22
+ ```bash
23
+ pip install digmind-sdk
24
+ ```
25
+
26
+ ## Creating a Function Tool
27
+
28
+ Function Tools allow you to construct composite tools locally using Python and connect them securely to DigMind's execution engine via the `ToolBridge`.
29
+
30
+ ### Usage Example
31
+
32
+ ```python
33
+ from pydantic import BaseModel
34
+ from digmind_sdk import FunctionToolContext, tool_model
35
+
36
+ # 1. Define the Pydantic schema and bind it to a platform Tool Reference
37
+ @tool_model("rss.list_feeds")
38
+ class ListFeeds(BaseModel):
39
+ source_ids: list[str]
40
+
41
+ # 2. Implement the entry point function
42
+ def run(ctx: FunctionToolContext) -> dict:
43
+ # Safely call the platform tools via context
44
+ ctx.progress(0.1, "Fetching RSS feeds...")
45
+ feeds = ctx.call(ListFeeds(source_ids=["tech-blog"]))
46
+
47
+ # Process results locally...
48
+ ctx.progress(1.0, "Complete!")
49
+ return {"status": "success", "data": feeds}
50
+ ```
51
+
52
+ ## Features
53
+
54
+ - **Pydantic Validation**: Tool arguments are validated by Pydantic when the model is constructed, so invalid input fails before any network request is made.
55
+ - **Progress Tracking**: Send real-time partial completion updates back to the DigMind web UI.
56
+ - **Parallel Dispatch**: Run multiple tool calls simultaneously (`ctx.parallel_call`) to drastically reduce overall execution latency.
57
+
58
+ ## Requirements
59
+
60
+ - Python 3.11+
61
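+ - A reachable DigMind platform instance. The SDK reads the ToolBridge URL and API key from the `DIGMIND_API_URL` and `DIGMIND_API_KEY` environment variables (defaulting to `http://localhost:8000` and an empty key); these are normally provided automatically when the tool runs inside a DigMind Sandbox environment.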
@@ -0,0 +1,15 @@
1
+ README.md
2
+ __init__.py
3
+ context.py
4
+ decorators.py
5
+ pyproject.toml
6
+ testing.py
7
+ ./__init__.py
8
+ ./context.py
9
+ ./decorators.py
10
+ ./testing.py
11
+ digmind_sdk.egg-info/PKG-INFO
12
+ digmind_sdk.egg-info/SOURCES.txt
13
+ digmind_sdk.egg-info/dependency_links.txt
14
+ digmind_sdk.egg-info/requires.txt
15
+ digmind_sdk.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ httpx>=0.28.0
2
+ pydantic>=2.0.0
@@ -0,0 +1 @@
1
+ digmind_sdk
@@ -0,0 +1,30 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "digmind-sdk"
7
+ version = "0.1.0"
8
+ description = "DigMind Function Tool SDK for Python"
9
+ readme = "README.md"
10
+ authors = [
11
+ { name = "DigMind Team" }
12
+ ]
13
+ requires-python = ">=3.11"
14
+ dependencies = [
15
+ "httpx>=0.28.0",
16
+ "pydantic>=2.0.0"
17
+ ]
18
+ license = { text = "MIT" }
19
+ classifiers = [
20
+ "Development Status :: 3 - Alpha",
21
+ "Intended Audience :: Developers",
22
+ "Programming Language :: Python :: 3",
23
+ ]
24
+
25
+ [project.urls]
26
+ Homepage = "https://digmind.ai"
27
+
28
+ [tool.setuptools]
29
+ package-dir = {"digmind_sdk" = "."}
30
+ packages = ["digmind_sdk"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,186 @@
1
+ """
2
+ MockContext — Test helper for Function Tools.
3
+
4
+ Provides a drop-in replacement for FunctionToolContext that records
5
+ tool calls and returns mock data instead of hitting the real API.
6
+
7
+ Usage in tests/test_main.py:
8
+ from digmind_sdk.testing import MockContext
9
+ from _tools import ListFeeds, Fetch
10
+
11
+ def test_basic():
12
+ ctx = MockContext(inputs={"source_ids": ["s1"]})
13
+ ctx.mock(ListFeeds, return_value=[{"url": "https://a.com"}])
14
+ ctx.mock(Fetch, side_effect=[{"content": "hello"}])
15
+
16
+ from main import run
17
+ result = run(ctx)
18
+
19
+ assert ctx.call_count(ListFeeds) == 1
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ from typing import Any, Optional
26
+ from unittest.mock import MagicMock
27
+
28
+ from pydantic import BaseModel
29
+
30
+
31
class MockContext:
    """Test double for FunctionToolContext.

    Records all tool calls and returns pre-configured mock data instead of
    hitting the ToolBridge API. It mirrors the real context's observable
    behaviour that a Function Tool under test can depend on:

    * ``allowed_tools`` is enforced: a non-whitelisted tool raises
      ``PermissionError`` and is not recorded.
    * ``parallel_call`` turns a failing call into ``{"error": str(e)}`` at
      its position instead of raising.
    * ``progress`` clamps the fraction to ``[0.0, 1.0]``.
    """

    def __init__(
        self,
        inputs: dict[str, Any] | None = None,
        allowed_tools: list[str] | None = None,
    ):
        self.inputs = inputs or {}
        self.logger = logging.getLogger("digmind_sdk.test")

        # None or an empty list disables the whitelist, as in the real context.
        self._allowed_tools = set(allowed_tools) if allowed_tools else None
        self._mocks: dict[str, _MockEntry] = {}  # tool_ref -> MockEntry
        self._calls: dict[str, list[dict]] = {}  # tool_ref -> list of call records
        self._progress_log: list[tuple[float, str]] = []

    # ──────────────────────────────────────────────────────────────
    # Mock registration
    # ──────────────────────────────────────────────────────────────

    def mock(
        self,
        model_cls: type[BaseModel] | str,
        *,
        return_value: Any = None,
        side_effect: Any = None,
    ) -> None:
        """Register a mock for a @tool_model class or tool name.

        Registering the same tool again replaces the previous mock.

        Args:
            model_cls: A @tool_model-decorated class, or a tool_ref string.
            return_value: Fixed return value for every call.
            side_effect: List of values (one per call) or callable; takes
                precedence over ``return_value`` when not None.
        """
        self._mocks[self._resolve_ref(model_cls)] = _MockEntry(
            return_value=return_value,
            side_effect=side_effect,
        )

    # ──────────────────────────────────────────────────────────────
    # Same API as FunctionToolContext
    # ──────────────────────────────────────────────────────────────

    def call(self, model: BaseModel) -> Any:
        """Mock version of ctx.call(Model)."""
        tool_ref = getattr(model.__class__, "__tool_ref__", None)
        if not tool_ref:
            raise ValueError(
                f"{model.__class__.__name__} is not a @tool_model"
            )
        return self._dispatch(tool_ref, model.model_dump(exclude_none=True))

    def tool_call(self, tool_name: str, **kwargs: Any) -> Any:
        """Mock version of ctx.tool_call(name, **kwargs)."""
        return self._dispatch(tool_name, kwargs)

    def parallel_call(
        self,
        models: list[BaseModel],
        *,
        max_concurrency: int = 5,
    ) -> list[Any]:
        """Mock version of ctx.parallel_call([Model, ...]).

        Calls run sequentially; ``max_concurrency`` is accepted only for API
        compatibility. As in the real context, a call that raises yields
        ``{"error": str(e)}`` at its position instead of propagating.
        """
        results: list[Any] = []
        for idx, model in enumerate(models):
            try:
                results.append(self.call(model))
            except Exception as e:
                self.logger.warning("Parallel call [%d] failed: %s", idx, e)
                results.append({"error": str(e)})
        return results

    def progress(self, fraction: float, message: str = "") -> None:
        """Record a progress call, clamped to [0.0, 1.0] like the real context."""
        self._progress_log.append((max(0.0, min(1.0, fraction)), message))

    # ──────────────────────────────────────────────────────────────
    # Assertions
    # ──────────────────────────────────────────────────────────────

    def call_count(self, model_cls: type[BaseModel] | str) -> int:
        """Return number of times a tool was called."""
        tool_ref = self._resolve_ref(model_cls)
        return len(self._calls.get(tool_ref, []))

    def call_args(self, model_cls: type[BaseModel] | str) -> list[dict]:
        """Return a copy of the argument dicts recorded for each call to a tool."""
        tool_ref = self._resolve_ref(model_cls)
        return list(self._calls.get(tool_ref, []))

    def get_metrics(self) -> dict[str, Any]:
        """Return mock metrics (compatible with FunctionToolContext)."""
        total = sum(len(v) for v in self._calls.values())
        by_tool = {k: len(v) for k, v in self._calls.items()}
        return {
            "total_calls": total,
            "total_ms": 0,
            "errors": 0,
            "by_tool": by_tool,
            "call_log": [],
        }

    # ──────────────────────────────────────────────────────────────
    # Internal
    # ──────────────────────────────────────────────────────────────

    def _resolve_ref(self, model_cls_or_str: type[BaseModel] | str) -> str:
        """Map a tool_ref string or @tool_model class to its tool_ref."""
        if isinstance(model_cls_or_str, str):
            return model_cls_or_str
        ref = getattr(model_cls_or_str, "__tool_ref__", None)
        if not ref:
            name = getattr(model_cls_or_str, "__name__", type(model_cls_or_str).__name__)
            raise ValueError(f"{name} is not a @tool_model")
        return ref

    def _dispatch(self, tool_ref: str, args: dict) -> Any:
        """Enforce the whitelist, record the call and return the mocked value."""
        # Checked before recording: the real context rejects the call
        # before counting it.
        if self._allowed_tools and tool_ref not in self._allowed_tools:
            raise PermissionError(
                f"Tool '{tool_ref}' is not in required_tools whitelist. "
                f"Allowed tools: {sorted(self._allowed_tools)}"
            )

        self._calls.setdefault(tool_ref, []).append(args)

        entry = self._mocks.get(tool_ref)
        if entry is None:
            raise ValueError(
                f"No mock registered for '{tool_ref}'. "
                f"Call ctx.mock({tool_ref!r}, return_value=...) first."
            )
        return entry.next_value()
158
+
159
+
160
+ class _MockEntry:
161
+ """Internal mock state for a single tool."""
162
+
163
+ def __init__(
164
+ self,
165
+ return_value: Any = None,
166
+ side_effect: Any = None,
167
+ ):
168
+ self._return_value = return_value
169
+ self._side_effect = list(side_effect) if isinstance(side_effect, (list, tuple)) else side_effect
170
+ self._call_index = 0
171
+
172
+ def next_value(self) -> Any:
173
+ if self._side_effect is not None:
174
+ if isinstance(self._side_effect, list):
175
+ if self._call_index >= len(self._side_effect):
176
+ raise IndexError(
177
+ f"side_effect exhausted after {self._call_index} calls"
178
+ )
179
+ val = self._side_effect[self._call_index]
180
+ self._call_index += 1
181
+ return val
182
+ elif callable(self._side_effect):
183
+ return self._side_effect()
184
+ else:
185
+ return self._side_effect
186
+ return self._return_value