datarobot-genai 0.2.24__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -31,6 +31,124 @@ from .constants import DEFAULT_DATAROBOT_ENDPOINT
31
31
  from .constants import RUNTIME_PARAM_ENV_VAR_NAME_PREFIX
32
32
 
33
33
 
def _oauth_client_credentials_in_env(provider_prefix: str) -> bool:
    """Return True when ``<PREFIX>_CLIENT_ID`` and ``<PREFIX>_CLIENT_SECRET`` are both non-empty.

    Only the process environment is consulted (via ``os.getenv``).
    NOTE(review): values that exist solely in the ``.env`` file read by
    pydantic-settings are presumably not exported to ``os.environ`` and so
    would not be detected here -- confirm if that matters.
    """
    return bool(
        os.getenv(f"{provider_prefix}_CLIENT_ID")
        and os.getenv(f"{provider_prefix}_CLIENT_SECRET")
    )


class MCPToolConfig(BaseSettings):
    """Tool configuration for MCP server.

    Each ``enable_*`` flag switches a tool family on or off. Each
    ``is_*_oauth_provider_configured`` flag records whether an OAuth provider
    has been set up for that integration. The matching
    ``is_*_oauth_configured`` properties also accept a client id/secret pair
    taken from the process environment.

    Every field can be set through either the runtime-parameter-prefixed
    variable name or the bare name (see each ``validation_alias``). Raw values
    are passed through ``extract_datarobot_runtime_param_payload`` before
    pydantic validates them.
    """

    # --- Tool family switches -------------------------------------------------
    enable_predictive_tools: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_PREDICTIVE_TOOLS",
            "ENABLE_PREDICTIVE_TOOLS",
        ),
        description="Enable/disable predictive tools",
    )

    enable_jira_tools: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_JIRA_TOOLS",
            "ENABLE_JIRA_TOOLS",
        ),
        description="Enable/disable Jira tools",
    )

    enable_confluence_tools: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_CONFLUENCE_TOOLS",
            "ENABLE_CONFLUENCE_TOOLS",
        ),
        description="Enable/disable Confluence tools",
    )

    enable_gdrive_tools: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_GDRIVE_TOOLS",
            "ENABLE_GDRIVE_TOOLS",
        ),
        description="Enable/disable GDrive tools",
    )

    enable_microsoft_graph_tools: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_MICROSOFT_GRAPH_TOOLS",
            "ENABLE_MICROSOFT_GRAPH_TOOLS",
        ),
        # Field and tool type are "microsoft_graph"; SharePoint is accessed through it.
        description="Enable/disable Microsoft Graph (SharePoint) tools",
    )

    # --- OAuth provider flags -------------------------------------------------
    is_atlassian_oauth_provider_configured: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_ATLASSIAN_OAUTH_PROVIDER_CONFIGURED",
            "IS_ATLASSIAN_OAUTH_PROVIDER_CONFIGURED",
        ),
        description="Whether Atlassian OAuth provider is configured for Atlassian integration",
    )

    is_google_oauth_provider_configured: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_GOOGLE_OAUTH_PROVIDER_CONFIGURED",
            "IS_GOOGLE_OAUTH_PROVIDER_CONFIGURED",
        ),
        description="Whether Google OAuth provider is configured for Google integration",
    )

    is_microsoft_oauth_provider_configured: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_MICROSOFT_OAUTH_PROVIDER_CONFIGURED",
            "IS_MICROSOFT_OAUTH_PROVIDER_CONFIGURED",
        ),
        description="Whether Microsoft OAuth provider is configured for Microsoft integration",
    )

    @field_validator(
        "enable_predictive_tools",
        "enable_jira_tools",
        "enable_confluence_tools",
        "enable_gdrive_tools",
        "enable_microsoft_graph_tools",
        "is_atlassian_oauth_provider_configured",
        "is_google_oauth_provider_configured",
        "is_microsoft_oauth_provider_configured",
        mode="before",
    )
    @classmethod
    def validate_runtime_params(cls, v: Any) -> Any:
        """Validate runtime parameters.

        Runs before pydantic's own coercion (``mode="before"``). Each raw value
        is passed through ``extract_datarobot_runtime_param_payload``, which
        presumably unwraps a DataRobot runtime-parameter envelope.
        """
        return extract_datarobot_runtime_param_payload(v)

    # --- Derived OAuth availability ---------------------------------------------
    @property
    def is_atlassian_oauth_configured(self) -> bool:
        """Check if Atlassian OAuth is configured via provider flag or environment variables."""
        return self.is_atlassian_oauth_provider_configured or _oauth_client_credentials_in_env(
            "ATLASSIAN"
        )

    @property
    def is_google_oauth_configured(self) -> bool:
        """Check if Google OAuth is configured via provider flag or environment variables."""
        return self.is_google_oauth_provider_configured or _oauth_client_credentials_in_env(
            "GOOGLE"
        )

    @property
    def is_microsoft_oauth_configured(self) -> bool:
        """Check if Microsoft OAuth is configured via provider flag or environment variables."""
        return self.is_microsoft_oauth_provider_configured or _oauth_client_credentials_in_env(
            "MICROSOFT"
        )

    model_config = SettingsConfigDict(
        env_file=".env",
        case_sensitive=False,
        env_file_encoding="utf-8",
        extra="ignore",
    )
150
+
151
+
34
152
  class MCPServerConfig(BaseSettings):
35
153
  """MCP Server configuration using pydantic settings."""
36
154
 
@@ -188,86 +306,11 @@ class MCPServerConfig(BaseSettings):
188
306
  ),
189
307
  description="Enable/disable memory management",
190
308
  )
191
- enable_predictive_tools: bool = Field(
192
- default=True,
193
- validation_alias=AliasChoices(
194
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_PREDICTIVE_TOOLS",
195
- "ENABLE_PREDICTIVE_TOOLS",
196
- ),
197
- description="Enable/disable predictive tools",
198
- )
199
309
 
200
- # Jira tools
201
- enable_jira_tools: bool = Field(
202
- default=False,
203
- validation_alias=AliasChoices(
204
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_JIRA_TOOLS",
205
- "ENABLE_JIRA_TOOLS",
206
- ),
207
- description="Enable/disable Jira tools",
310
+ tool_config: MCPToolConfig = Field(
311
+ default_factory=MCPToolConfig,
312
+ description="Tool configuration",
208
313
  )
209
- is_jira_oauth_provider_configured: bool = Field(
210
- default=False,
211
- validation_alias=AliasChoices(
212
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_JIRA_OAUTH_PROVIDER_CONFIGURED",
213
- "IS_JIRA_OAUTH_PROVIDER_CONFIGURED",
214
- ),
215
- description="Whether Jira OAuth provider is configured for Jira integration",
216
- )
217
-
218
- @property
219
- def is_jira_oauth_configured(self) -> bool:
220
- return self.is_jira_oauth_provider_configured or bool(
221
- os.getenv("JIRA_CLIENT_ID") and os.getenv("JIRA_CLIENT_SECRET")
222
- )
223
-
224
- # Confluence tools
225
- enable_confluence_tools: bool = Field(
226
- default=False,
227
- validation_alias=AliasChoices(
228
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_CONFLUENCE_TOOLS",
229
- "ENABLE_CONFLUENCE_TOOLS",
230
- ),
231
- description="Enable/disable Confluence tools",
232
- )
233
- is_confluence_oauth_provider_configured: bool = Field(
234
- default=False,
235
- validation_alias=AliasChoices(
236
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_CONFLUENCE_OAUTH_PROVIDER_CONFIGURED",
237
- "IS_CONFLUENCE_OAUTH_PROVIDER_CONFIGURED",
238
- ),
239
- description="Whether Confluence OAuth provider is configured for Confluence integration",
240
- )
241
-
242
- @property
243
- def is_confluence_oauth_configured(self) -> bool:
244
- return self.is_confluence_oauth_provider_configured or bool(
245
- os.getenv("CONFLUENCE_CLIENT_ID") and os.getenv("CONFLUENCE_CLIENT_SECRET")
246
- )
247
-
248
- # Gdrive tools
249
- enable_gdrive_tools: bool = Field(
250
- default=False,
251
- validation_alias=AliasChoices(
252
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "ENABLE_GDRIVE_TOOLS",
253
- "ENABLE_GDRIVE_TOOLS",
254
- ),
255
- description="Enable/disable GDrive tools",
256
- )
257
- is_gdrive_oauth_provider_configured: bool = Field(
258
- default=False,
259
- validation_alias=AliasChoices(
260
- RUNTIME_PARAM_ENV_VAR_NAME_PREFIX + "IS_GDRIVE_OAUTH_PROVIDER_CONFIGURED",
261
- "IS_GDRIVE_OAUTH_PROVIDER_CONFIGURED",
262
- ),
263
- description="Whether GDrive OAuth provider is configured for GDrive integration",
264
- )
265
-
266
- @property
267
- def is_gdrive_oauth_configured(self) -> bool:
268
- return self.is_gdrive_oauth_provider_configured or bool(
269
- os.getenv("GDRIVE_CLIENT_ID") and os.getenv("GDRIVE_CLIENT_SECRET")
270
- )
271
314
 
272
315
  @field_validator(
273
316
  "otel_attributes",
@@ -291,11 +334,6 @@ class MCPServerConfig(BaseSettings):
291
334
  "mcp_server_register_dynamic_tools_on_startup",
292
335
  "tool_registration_duplicate_behavior",
293
336
  "mcp_server_register_dynamic_prompts_on_startup",
294
- "enable_predictive_tools",
295
- "enable_jira_tools",
296
- "is_jira_oauth_provider_configured",
297
- "enable_confluence_tools",
298
- "is_confluence_oauth_provider_configured",
299
337
  mode="before",
300
338
  )
301
339
  @classmethod
@@ -31,9 +31,6 @@ from .dynamic_prompts.register import register_prompts_from_datarobot_prompt_man
31
31
  from .dynamic_tools.deployment.register import register_tools_of_datarobot_deployments
32
32
  from .logging import MCPLogging
33
33
  from .mcp_instance import mcp
34
- from .mcp_server_tools import get_all_available_tags # noqa # pylint: disable=unused-import
35
- from .mcp_server_tools import get_tool_info_by_name # noqa # pylint: disable=unused-import
36
- from .mcp_server_tools import list_tools_by_tags # noqa # pylint: disable=unused-import
37
34
  from .memory_management.manager import MemoryManager
38
35
  from .routes import register_routes
39
36
  from .routes_utils import prefix_mount_path
@@ -16,20 +16,18 @@ import logging
16
16
  from collections.abc import Callable
17
17
  from functools import wraps
18
18
  from typing import Any
19
- from typing import overload
19
+ from typing import TypedDict
20
20
 
21
21
  from fastmcp import Context
22
22
  from fastmcp import FastMCP
23
23
  from fastmcp.exceptions import NotFoundError
24
24
  from fastmcp.prompts.prompt import Prompt
25
25
  from fastmcp.server.dependencies import get_context
26
- from fastmcp.tools import FunctionTool
27
26
  from fastmcp.tools import Tool
28
- from fastmcp.utilities.types import NotSet
29
- from fastmcp.utilities.types import NotSetT
30
27
  from mcp.types import AnyFunction
31
28
  from mcp.types import Tool as MCPTool
32
29
  from mcp.types import ToolAnnotations
30
+ from typing_extensions import Unpack
33
31
 
34
32
  from .config import MCPServerConfig
35
33
  from .config import get_config
@@ -120,86 +118,6 @@ class TaggedFastMCP(FastMCP):
120
118
  "In stateless mode, clients will see changes on next request."
121
119
  )
122
120
 
123
- @overload
124
- def tool(
125
- self,
126
- name_or_fn: AnyFunction,
127
- *,
128
- name: str | None = None,
129
- title: str | None = None,
130
- description: str | None = None,
131
- tags: set[str] | None = None,
132
- output_schema: dict[str, Any] | None | NotSetT = NotSet,
133
- annotations: ToolAnnotations | dict[str, Any] | None = None,
134
- exclude_args: list[str] | None = None,
135
- meta: dict[str, Any] | None = None,
136
- enabled: bool | None = None,
137
- ) -> FunctionTool: ...
138
-
139
- @overload
140
- def tool(
141
- self,
142
- name_or_fn: str | None = None,
143
- *,
144
- name: str | None = None,
145
- title: str | None = None,
146
- description: str | None = None,
147
- tags: set[str] | None = None,
148
- output_schema: dict[str, Any] | None | NotSetT = NotSet,
149
- annotations: ToolAnnotations | dict[str, Any] | None = None,
150
- exclude_args: list[str] | None = None,
151
- meta: dict[str, Any] | None = None,
152
- enabled: bool | None = None,
153
- ) -> Callable[[AnyFunction], FunctionTool]: ...
154
-
155
- def tool(
156
- self,
157
- name_or_fn: str | Callable[..., Any] | None = None,
158
- *,
159
- name: str | None = None,
160
- title: str | None = None,
161
- description: str | None = None,
162
- tags: set[str] | None = None,
163
- output_schema: dict[str, Any] | None | NotSetT = NotSet,
164
- annotations: ToolAnnotations | dict[str, Any] | None = None,
165
- exclude_args: list[str] | None = None,
166
- meta: dict[str, Any] | None = None,
167
- enabled: bool | None = None,
168
- **kwargs: Any,
169
- ) -> Callable[[AnyFunction], FunctionTool] | FunctionTool:
170
- """
171
- Extend tool decorator that supports tags and other annotations, while remaining
172
- signature-compatible with FastMCP.tool to avoid recursion issues with partials.
173
- """
174
- if isinstance(annotations, dict):
175
- annotations = ToolAnnotations(**annotations)
176
-
177
- # Ensure tags are available both via native fastmcp `tags` and inside annotations
178
- if tags is not None:
179
- tags_ = sorted(tags)
180
- if annotations is None:
181
- annotations = ToolAnnotations() # type: ignore[call-arg]
182
- annotations.tags = tags_ # type: ignore[attr-defined, union-attr]
183
- else:
184
- # At this point, annotations is ToolAnnotations (not dict)
185
- assert isinstance(annotations, ToolAnnotations)
186
- annotations.tags = tags_ # type: ignore[attr-defined]
187
-
188
- return super().tool(
189
- name_or_fn,
190
- name=name,
191
- title=title,
192
- description=description,
193
- tags=tags,
194
- output_schema=output_schema
195
- if output_schema is not None
196
- else kwargs.get("output_schema"),
197
- annotations=annotations,
198
- exclude_args=exclude_args,
199
- meta=meta,
200
- enabled=enabled,
201
- )
202
-
203
121
  async def list_tools(
204
122
  self, tags: list[str] | None = None, match_all: bool = False
205
123
  ) -> list[MCPTool]:
@@ -371,16 +289,37 @@ mcp = TaggedFastMCP(
371
289
  )
372
290
 
373
291
 
class ToolKwargs(TypedDict, total=False):
    """Keyword arguments passed through to FastMCP's mcp.tool() decorator.

    Every key is optional (``total=False``). Callers forward this dict with
    ``mcp.tool(**kwargs)``, so only the keys actually supplied reach FastMCP
    tool registration. See the FastMCP documentation for full details on each
    parameter.
    """

    name: str | None
    title: str | None
    description: str | None
    icons: list[Any] | None
    tags: set[str] | None
    # FastMCP's own default for this parameter is a NotSet sentinel. An explicit
    # None is forwarded as-is, so omit the key to keep FastMCP's default.
    output_schema: dict[str, Any] | None
    # Same accepted forms as FastMCP's ``tool(annotations=...)`` parameter.
    annotations: ToolAnnotations | dict[str, Any] | None
    exclude_args: list[str] | None
    meta: dict[str, Any] | None
    enabled: bool | None
310
+
def dr_core_mcp_tool(
    **kwargs: Unpack[ToolKwargs],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that instruments a function and registers it as an MCP tool.

    The decorated function is first wrapped with ``dr_mcp_extras()``. The
    wrapped result is then registered through ``mcp.tool()``, and that same
    wrapped callable is what the decorator hands back.

    All keyword arguments are forwarded unchanged to FastMCP's ``mcp.tool()``
    decorator; see ToolKwargs for the accepted keys.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        tool_fn = dr_mcp_extras()(func)
        register = mcp.tool(**kwargs)
        register(tool_fn)
        # Return the instrumented callable (not FastMCP's Tool object) so the
        # decorated name stays directly callable.
        return tool_fn

    return decorator
@@ -413,27 +352,23 @@ async def memory_aware_wrapper(func: Callable[..., Any], *args: Any, **kwargs: A
413
352
 
414
353
 
def dr_mcp_tool(
    **kwargs: Unpack[ToolKwargs],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that registers a memory-aware MCP tool.

    Every call to the decorated coroutine function is routed through
    ``memory_aware_wrapper``, which is intended to capture memory ids from the
    request headers when they are present. The routing function is then
    instrumented with ``dr_mcp_extras()`` and registered via ``mcp.tool()``.

    All keyword arguments are forwarded unchanged to FastMCP's ``mcp.tool()``
    decorator; see ToolKwargs for the accepted keys.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # Distinct parameter names keep the call-time kwargs separate from the
        # registration kwargs captured by the enclosing scope.
        async def memory_aware(*call_args: Any, **call_kwargs: Any) -> Any:
            return await memory_aware_wrapper(func, *call_args, **call_kwargs)

        # Copy func's metadata (name, docstring, __wrapped__) onto the wrapper.
        memory_aware = wraps(func)(memory_aware)

        tool_fn = dr_mcp_extras()(memory_aware)
        register = mcp.tool(**kwargs)
        register(tool_fn)
        return tool_fn

    return decorator
@@ -488,11 +423,10 @@ async def register_tools(
488
423
  # Apply dr_mcp_extras to the memory-aware function
489
424
  wrapped_fn = dr_mcp_extras()(memory_aware_fn)
490
425
 
491
- # Create annotations with tags, deployment_id if provided
492
- annotations = ToolAnnotations() # type: ignore[call-arg]
493
- if tags is not None:
494
- annotations.tags = tags # type: ignore[attr-defined]
426
+ # Create annotations only when additional metadata is required
427
+ annotations: ToolAnnotations | None = None # type: ignore[assignment]
495
428
  if deployment_id is not None:
429
+ annotations = ToolAnnotations() # type: ignore[call-arg]
496
430
  annotations.deployment_id = deployment_id # type: ignore[attr-defined]
497
431
 
498
432
  tool = Tool.from_function(
@@ -30,6 +30,7 @@ class ToolType(str, Enum):
30
30
  JIRA = "jira"
31
31
  CONFLUENCE = "confluence"
32
32
  GDRIVE = "gdrive"
33
+ MICROSOFT_GRAPH = "microsoft_graph"
33
34
 
34
35
 
35
36
  class ToolConfig(TypedDict):
@@ -39,7 +40,7 @@ class ToolConfig(TypedDict):
39
40
  oauth_check: Callable[["MCPServerConfig"], bool] | None
40
41
  directory: str
41
42
  package_prefix: str
42
- config_field_name: str # Name of the config field (e.g., "enable_predictive_tools")
43
+ config_field_name: str
43
44
 
44
45
 
45
46
  # Tool configuration registry
@@ -53,25 +54,32 @@ TOOL_CONFIGS: dict[ToolType, ToolConfig] = {
53
54
  ),
54
55
  ToolType.JIRA: ToolConfig(
55
56
  name="jira",
56
- oauth_check=lambda config: config.is_jira_oauth_configured,
57
+ oauth_check=lambda config: config.tool_config.is_atlassian_oauth_configured,
57
58
  directory="jira",
58
59
  package_prefix="datarobot_genai.drmcp.tools.jira",
59
60
  config_field_name="enable_jira_tools",
60
61
  ),
61
62
  ToolType.CONFLUENCE: ToolConfig(
62
63
  name="confluence",
63
- oauth_check=lambda config: config.is_confluence_oauth_configured,
64
+ oauth_check=lambda config: config.tool_config.is_atlassian_oauth_configured,
64
65
  directory="confluence",
65
66
  package_prefix="datarobot_genai.drmcp.tools.confluence",
66
67
  config_field_name="enable_confluence_tools",
67
68
  ),
68
69
  ToolType.GDRIVE: ToolConfig(
69
70
  name="gdrive",
70
- oauth_check=lambda config: config.is_gdrive_oauth_configured,
71
+ oauth_check=lambda config: config.tool_config.is_google_oauth_configured,
71
72
  directory="gdrive",
72
73
  package_prefix="datarobot_genai.drmcp.tools.gdrive",
73
74
  config_field_name="enable_gdrive_tools",
74
75
  ),
76
+ ToolType.MICROSOFT_GRAPH: ToolConfig(
77
+ name="microsoft_graph",
78
+ oauth_check=lambda config: config.tool_config.is_microsoft_oauth_configured,
79
+ directory="microsoft_graph",
80
+ package_prefix="datarobot_genai.drmcp.tools.microsoft_graph",
81
+ config_field_name="enable_microsoft_graph_tools",
82
+ ),
75
83
  }
76
84
 
77
85
 
@@ -92,12 +100,12 @@ def is_tool_enabled(tool_type: ToolType, config: "MCPServerConfig") -> bool:
92
100
  -------
93
101
  True if the tool is enabled, False otherwise
94
102
  """
95
- tool_config = TOOL_CONFIGS[tool_type]
96
- enable_config_name = tool_config["config_field_name"]
97
- is_enabled = getattr(config, enable_config_name)
103
+ tool_config_registry = TOOL_CONFIGS[tool_type]
104
+ enable_config_name = tool_config_registry["config_field_name"]
105
+ is_enabled = getattr(config.tool_config, enable_config_name)
98
106
 
99
107
  # If tool is enabled, check OAuth requirements if needed
100
- if is_enabled and tool_config["oauth_check"] is not None:
101
- return tool_config["oauth_check"](config)
108
+ if is_enabled and tool_config_registry["oauth_check"] is not None:
109
+ return tool_config_registry["oauth_check"](config)
102
110
 
103
111
  return is_enabled
@@ -41,7 +41,7 @@ def filter_tools_by_tags(
41
41
  filtered_tools = []
42
42
 
43
43
  for tool in tools:
44
- tool_tags = getattr(tool.annotations, "tags", []) if tool.annotations else []
44
+ tool_tags = get_tool_tags(tool)
45
45
 
46
46
  if not tool_tags:
47
47
  continue
@@ -68,9 +68,18 @@ def get_tool_tags(tool: Tool | MCPTool) -> list[str]:
68
68
  -------
69
69
  List of tags for the tool
70
70
  """
71
+ # Primary: native FastMCP meta location
72
+ if hasattr(tool, "meta") and getattr(tool, "meta"):
73
+ fastmcp_meta = tool.meta.get("_fastmcp", {})
74
+ meta_tags = fastmcp_meta.get("tags", [])
75
+ if isinstance(meta_tags, list):
76
+ return meta_tags
77
+
78
+ # Fallback: annotations.tags (for compatibility during transition)
71
79
  if tool.annotations and hasattr(tool.annotations, "tags"):
72
80
  tags = getattr(tool.annotations, "tags", [])
73
81
  return tags if isinstance(tags, list) else []
82
+
74
83
  return []
75
84
 
76
85
 
@@ -39,6 +39,54 @@ class ETETestExpectations(BaseModel):
39
39
  SHOULD_NOT_BE_EMPTY = "SHOULD_NOT_BE_EMPTY"
40
40
 
41
41
 
42
+ def _extract_structured_content(tool_result: str) -> Any:
43
+ r"""
44
+ Extract and parse structured content from tool result string.
45
+
46
+ Tool results are formatted as:
47
+ "Content: {content}\nStructured content: {structured_content}"
48
+
49
+ Structured content can be:
50
+ 1. A JSON object with a "result" key: {"result": "..."} or {"result": "{...}"}
51
+ 2. A direct JSON object: {"key": "value", ...}
52
+ 3. Empty or missing
53
+
54
+ Args:
55
+ tool_result: The tool result string
56
+
57
+ Returns
58
+ -------
59
+ Parsed structured content, or None if not available
60
+ """
61
+ # Early returns for invalid inputs
62
+ if not tool_result or "Structured content: " not in tool_result:
63
+ return None
64
+
65
+ structured_part = tool_result.split("Structured content: ", 1)[1].strip()
66
+ # Parse JSON, return None on failure or empty structured_part
67
+ if not structured_part:
68
+ return None
69
+ try:
70
+ structured_data = json.loads(structured_part)
71
+ except json.JSONDecodeError:
72
+ return None
73
+
74
+ # If structured data has a "result" key, extract and parse that
75
+ if isinstance(structured_data, dict) and "result" in structured_data:
76
+ result_value = structured_data["result"]
77
+ # If result is a JSON string (starts with { or [), try to parse it
78
+ if isinstance(result_value, str) and result_value.strip().startswith(("{", "[")):
79
+ try:
80
+ parsed_result = json.loads(result_value)
81
+ except json.JSONDecodeError:
82
+ parsed_result = result_value # Return string as-is if parsing fails
83
+ return parsed_result
84
+ return result_value # Return result value directly
85
+
86
+ # If it's a direct JSON object (not wrapped in {"result": ...}), return it as-is
87
+ return structured_data
88
+
89
+
42
90
  def _check_dict_has_keys(
43
91
  expected: dict[str, Any],
44
92
  actual: dict[str, Any] | list[dict[str, Any]],
@@ -130,7 +178,26 @@ class ToolBaseE2E:
130
178
  f"result, but got: {response.tool_results[i]}"
131
179
  )
132
180
  else:
133
- actual_result = json.loads(response.tool_results[i])
181
+ actual_result = _extract_structured_content(response.tool_results[i])
182
+ if actual_result is None:
183
+ # Fallback: try to parse the entire tool result as JSON
184
+ try:
185
+ actual_result = json.loads(response.tool_results[i])
186
+ except json.JSONDecodeError:
187
+ # If that fails, try to extract content part
188
+ if "Content: " in response.tool_results[i]:
189
+ content_part = response.tool_results[i].split("Content: ", 1)[1]
190
+ if "\nStructured content: " in content_part:
191
+ content_part = content_part.split(
192
+ "\nStructured content: ", 1
193
+ )[0]
194
+ try:
195
+ actual_result = json.loads(content_part.strip())
196
+ except json.JSONDecodeError:
197
+ raise AssertionError(
198
+ f"Could not parse tool result for "
199
+ f"{tool_call.tool_name}: {response.tool_results[i]}"
200
+ )
134
201
  assert _check_dict_has_keys(expected_result, actual_result), (
135
202
  f"Should have called {tool_call.tool_name} tool with the correct "
136
203
  f"result structure, but got: {response.tool_results[i]}"