data-designer-engine 0.4.0rc3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27):
  1. data_designer/engine/analysis/column_profilers/base.py +1 -2
  2. data_designer/engine/analysis/dataset_profiler.py +1 -2
  3. data_designer/engine/column_generators/generators/base.py +1 -6
  4. data_designer/engine/column_generators/generators/custom.py +195 -0
  5. data_designer/engine/column_generators/generators/llm_completion.py +32 -5
  6. data_designer/engine/column_generators/registry.py +3 -0
  7. data_designer/engine/column_generators/utils/errors.py +3 -0
  8. data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
  9. data_designer/engine/dataset_builders/column_wise_builder.py +23 -5
  10. data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
  11. data_designer/engine/mcp/__init__.py +30 -0
  12. data_designer/engine/mcp/errors.py +22 -0
  13. data_designer/engine/mcp/facade.py +485 -0
  14. data_designer/engine/mcp/factory.py +46 -0
  15. data_designer/engine/mcp/io.py +487 -0
  16. data_designer/engine/mcp/registry.py +203 -0
  17. data_designer/engine/model_provider.py +68 -0
  18. data_designer/engine/models/facade.py +74 -9
  19. data_designer/engine/models/factory.py +18 -1
  20. data_designer/engine/models/utils.py +28 -1
  21. data_designer/engine/resources/resource_provider.py +72 -3
  22. data_designer/engine/testing/fixtures.py +233 -0
  23. data_designer/engine/testing/stubs.py +1 -2
  24. {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
  25. {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +26 -19
  26. data_designer/engine/_version.py +0 -34
  27. {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
@@ -8,6 +8,7 @@ from functools import cached_property
8
8
  from pydantic import BaseModel, field_validator, model_validator
9
9
  from typing_extensions import Self
10
10
 
11
+ from data_designer.config.mcp import MCPProviderT
11
12
  from data_designer.config.models import ModelProvider
12
13
  from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
13
14
 
@@ -75,3 +76,70 @@ def resolve_model_provider_registry(
75
76
  providers=model_providers,
76
77
  default=default_provider_name or model_providers[0].name,
77
78
  )
79
+
80
+
81
class MCPProviderRegistry(BaseModel):
    """Registry of MCP provider configurations, looked up by provider name.

    Unlike ModelProviderRegistry, an MCPProviderRegistry may be empty: MCP providers
    are optional and only need to be registered when MCP tools are used for
    generation.

    Attributes:
        providers: MCP provider configurations (MCPProvider and/or LocalStdioMCPProvider).
            Provider names must be unique within the registry.
    """

    providers: list[MCPProviderT] = []

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_have_unique_names(cls, v: list[MCPProviderT]) -> list[MCPProviderT]:
        """Reject provider lists in which the same name appears more than once.

        Args:
            v: The validated list of providers.

        Returns:
            The unchanged provider list when all names are unique.

        Raises:
            ValueError: If two or more providers share a name.
        """
        names = set()
        dupes = set()
        for provider in v:
            if provider.name in names:
                dupes.add(provider.name)
            names.add(provider.name)

        if dupes:
            # Sort so the message does not depend on set iteration order, which varies
            # between processes because str hashing is randomized.
            raise ValueError(f"MCP providers must have unique names, found duplicates: {sorted(dupes)}")
        return v

    @cached_property
    def _providers_dict(self) -> dict[str, MCPProviderT]:
        # Name -> provider lookup table, computed on first access and then cached on the
        # instance (later mutation of `providers` is not reflected here).
        return {p.name: p for p in self.providers}

    def get_provider(self, name: str) -> MCPProviderT:
        """Get an MCP provider by name.

        Args:
            name: The name of the MCP provider.

        Returns:
            The MCP provider with the given name.

        Raises:
            UnknownProviderError: If no provider with the given name is registered.
        """
        try:
            return self._providers_dict[name]
        except KeyError:
            # The KeyError carries no information beyond this message; suppress it so the
            # traceback shows a single, clear error.
            raise UnknownProviderError(f"No MCP provider named {name!r} registered") from None

    def is_empty(self) -> bool:
        """Return True if the registry has no providers."""
        return len(self.providers) == 0
132
+
133
+
134
def resolve_mcp_provider_registry(
    mcp_providers: list[MCPProviderT] | None = None,
) -> MCPProviderRegistry:
    """Build an MCPProviderRegistry from an optional list of MCP providers.

    Args:
        mcp_providers: MCP providers to register. ``None`` and an empty list both
            produce an empty registry.

    Returns:
        An MCPProviderRegistry holding the given providers.
    """
    providers = mcp_providers if mcp_providers else []
    return MCPProviderRegistry(providers=providers)
@@ -9,6 +9,7 @@ from copy import deepcopy
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
11
  from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
12
+ from data_designer.engine.mcp.errors import MCPConfigurationError
12
13
  from data_designer.engine.model_provider import ModelProviderRegistry
13
14
  from data_designer.engine.models.errors import (
14
15
  GenerationValidationFailureError,
@@ -25,6 +26,15 @@ from data_designer.lazy_heavy_imports import litellm
25
26
  if TYPE_CHECKING:
26
27
  import litellm
27
28
 
29
+ from data_designer.engine.mcp.facade import MCPFacade
30
+ from data_designer.engine.mcp.registry import MCPRegistry
31
+
32
+
33
+ def _identity(x: Any) -> Any:
34
+ """Identity function for default parser. Module-level for pickling compatibility."""
35
+ return x
36
+
37
+
28
38
  logger = logging.getLogger(__name__)
29
39
 
30
40
 
@@ -34,10 +44,13 @@ class ModelFacade:
34
44
  model_config: ModelConfig,
35
45
  secret_resolver: SecretResolver,
36
46
  model_provider_registry: ModelProviderRegistry,
37
- ):
47
+ *,
48
+ mcp_registry: MCPRegistry | None = None,
49
+ ) -> None:
38
50
  self._model_config = model_config
39
51
  self._secret_resolver = secret_resolver
40
52
  self._model_provider_registry = model_provider_registry
53
+ self._mcp_registry = mcp_registry
41
54
  self._litellm_deployment = self._get_litellm_deployment(model_config)
42
55
  self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
43
56
  self._usage_stats = ModelUsageStats()
@@ -104,6 +117,17 @@ class ModelFacade:
104
117
  kwargs["extra_headers"] = self.model_provider.extra_headers
105
118
  return kwargs
106
119
 
120
+ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
121
+ if tool_alias is None:
122
+ return None
123
+ if self._mcp_registry is None:
124
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
125
+
126
+ try:
127
+ return self._mcp_registry.get_mcp(tool_alias=tool_alias)
128
+ except ValueError as exc:
129
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
130
+
107
131
  @catch_llm_exceptions
108
132
  def generate_text_embeddings(
109
133
  self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
@@ -142,9 +166,10 @@ class ModelFacade:
142
166
  self,
143
167
  prompt: str,
144
168
  *,
145
- parser: Callable[[str], Any],
169
+ parser: Callable[[str], Any] = _identity,
146
170
  system_prompt: str | None = None,
147
171
  multi_modal_context: list[dict[str, Any]] | None = None,
172
+ tool_alias: str | None = None,
148
173
  max_correction_steps: int = 0,
149
174
  max_conversation_restarts: int = 0,
150
175
  skip_usage_tracking: bool = False,
@@ -170,7 +195,10 @@ class ModelFacade:
170
195
  no system message is provided and the model should use its default system
171
196
  prompt.
172
197
  parser (func(str) -> Any): A function applied to the LLM response which processes
173
- an LLM response into some output object.
198
+ an LLM response into some output object. Default: identity function.
199
+ tool_alias (str | None): Optional tool configuration alias. When provided,
200
+ the model may call permitted tools from the configured MCP providers.
201
+ The alias must reference a ToolConfig registered in the MCPRegistry.
174
202
  max_correction_steps (int): Maximum number of correction rounds permitted
175
203
  within a single conversation. Note, many rounds can lead to increasing
176
204
  context size without necessarily improving performance -- small language
@@ -186,25 +214,61 @@ class ModelFacade:
186
214
  Returns:
187
215
  A tuple containing:
188
216
  - The parsed output object from the parser.
189
- - The full trace of ChatMessage entries in the conversation, including any
190
- corrections and reasoning traces. Callers can decide whether to store this.
217
+ - The full trace of ChatMessage entries in the conversation, including any tool calls,
218
+ corrections, and reasoning traces. Callers can decide whether to store this.
191
219
 
192
220
  Raises:
193
221
  GenerationValidationFailureError: If the maximum number of retries or
194
222
  correction steps are met and the last response failures on
195
223
  generation validation.
224
+ MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured.
196
225
  """
197
226
  output_obj = None
227
+ tool_schemas = None
228
+ tool_call_turns = 0
198
229
  curr_num_correction_steps = 0
199
230
  curr_num_restarts = 0
200
231
 
201
- starting_messages = prompt_to_messages(
232
+ mcp_facade = self._get_mcp_facade(tool_alias)
233
+
234
+ # Checkpoint for restarts - updated after tool calls so we don't repeat them
235
+ restart_checkpoint = prompt_to_messages(
202
236
  user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
203
237
  )
204
- messages: list[ChatMessage] = deepcopy(starting_messages)
238
+ checkpoint_tool_call_turns = 0
239
+ messages: list[ChatMessage] = deepcopy(restart_checkpoint)
240
+
241
+ if mcp_facade is not None:
242
+ tool_schemas = mcp_facade.get_tool_schemas()
205
243
 
206
244
  while True:
207
- completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
245
+ completion_kwargs = dict(kwargs)
246
+ if tool_schemas is not None:
247
+ completion_kwargs["tools"] = tool_schemas
248
+
249
+ completion_response = self.completion(
250
+ messages,
251
+ skip_usage_tracking=skip_usage_tracking,
252
+ **completion_kwargs,
253
+ )
254
+
255
+ # Process any tool calls in the response (handles parallel tool calling)
256
+ if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
257
+ tool_call_turns += 1
258
+
259
+ if tool_call_turns > mcp_facade.max_tool_call_turns:
260
+ # Gracefully refuse tool calls when budget is exhausted
261
+ messages.extend(mcp_facade.refuse_completion_response(completion_response))
262
+ else:
263
+ messages.extend(mcp_facade.process_completion_response(completion_response))
264
+
265
+ # Update checkpoint so restarts don't repeat tool calls
266
+ restart_checkpoint = deepcopy(messages)
267
+ checkpoint_tool_call_turns = tool_call_turns
268
+
269
+ continue # Back to top
270
+
271
+ # No tool calls remaining to process
208
272
  response = completion_response.choices[0].message.content or ""
209
273
  reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
210
274
  messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
@@ -226,7 +290,8 @@ class ModelFacade:
226
290
  elif curr_num_restarts < max_conversation_restarts:
227
291
  curr_num_correction_steps = 0
228
292
  curr_num_restarts += 1
229
- messages = deepcopy(starting_messages)
293
+ messages = deepcopy(restart_checkpoint)
294
+ tool_call_turns = checkpoint_tool_call_turns
230
295
 
231
296
  else:
232
297
  raise GenerationValidationFailureError(
@@ -10,6 +10,7 @@ from data_designer.engine.model_provider import ModelProviderRegistry
10
10
  from data_designer.engine.secret_resolver import SecretResolver
11
11
 
12
12
  if TYPE_CHECKING:
13
+ from data_designer.engine.mcp.registry import MCPRegistry
13
14
  from data_designer.engine.models.registry import ModelRegistry
14
15
 
15
16
 
@@ -18,12 +19,23 @@ def create_model_registry(
18
19
  model_configs: list[ModelConfig] | None = None,
19
20
  secret_resolver: SecretResolver,
20
21
  model_provider_registry: ModelProviderRegistry,
22
+ mcp_registry: MCPRegistry | None = None,
21
23
  ) -> ModelRegistry:
22
24
  """Factory function for creating a ModelRegistry instance.
23
25
 
24
26
  Heavy dependencies (litellm, httpx) are deferred until this function is called.
25
27
  This is a factory function pattern - imports inside factories are idiomatic Python
26
28
  for lazy initialization.
29
+
30
+ Args:
31
+ model_configs: Optional list of model configurations to register.
32
+ secret_resolver: Resolver for secrets referenced in provider configs.
33
+ model_provider_registry: Registry of model provider configurations.
34
+ mcp_registry: Optional MCP registry for tool operations. When provided,
35
+ ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation.
36
+
37
+ Returns:
38
+ A configured ModelRegistry instance.
27
39
  """
28
40
  from data_designer.engine.models.facade import ModelFacade
29
41
  from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -32,7 +44,12 @@ def create_model_registry(
32
44
  apply_litellm_patches()
33
45
 
34
46
  def model_facade_factory(model_config, secret_resolver, model_provider_registry):
35
- return ModelFacade(model_config, secret_resolver, model_provider_registry)
47
+ return ModelFacade(
48
+ model_config,
49
+ secret_resolver,
50
+ model_provider_registry,
51
+ mcp_registry=mcp_registry,
52
+ )
36
53
 
37
54
  return ModelRegistry(
38
55
  model_configs=model_configs,
@@ -36,11 +36,14 @@ class ChatMessage:
36
36
  def to_dict(self) -> dict[str, Any]:
37
37
  """Convert the message to a dictionary format for API calls.
38
38
 
39
+ Content is normalized to a list of ChatML-style blocks to keep a
40
+ consistent schema across traces and API payloads.
41
+
39
42
  Returns:
40
43
  A dictionary containing the message fields. Only includes non-empty
41
44
  optional fields to keep the output clean.
42
45
  """
43
- result: dict[str, Any] = {"role": self.role, "content": self.content}
46
+ result: dict[str, Any] = {"role": self.role, "content": _normalize_content_blocks(self.content)}
44
47
  if self.reasoning_content:
45
48
  result["reasoning_content"] = self.reasoning_content
46
49
  if self.tool_calls:
@@ -99,3 +102,27 @@ def prompt_to_messages(
99
102
  if system_prompt:
100
103
  return [ChatMessage.as_system(system_prompt), ChatMessage.as_user(user_content)]
101
104
  return [ChatMessage.as_user(user_content)]
105
+
106
+
107
+ def _normalize_content_blocks(content: Any) -> list[dict[str, Any]]:
108
+ if isinstance(content, list):
109
+ return [_normalize_content_block(block) for block in content]
110
+ if content is None:
111
+ return []
112
+ return [_text_block(content)]
113
+
114
+
115
+ def _normalize_content_block(block: Any) -> dict[str, Any]:
116
+ if isinstance(block, dict) and "type" in block:
117
+ return block
118
+ return _text_block(block)
119
+
120
+
121
+ def _text_block(value: Any) -> dict[str, Any]:
122
+ if value is None:
123
+ text_value = ""
124
+ elif isinstance(value, str):
125
+ text_value = value
126
+ else:
127
+ text_value = str(value)
128
+ return {"type": "text", "text": text_value}
@@ -5,15 +5,21 @@ from __future__ import annotations
5
5
 
6
6
  from data_designer.config.base import ConfigBase
7
7
  from data_designer.config.dataset_metadata import DatasetMetadata
8
+ from data_designer.config.mcp import MCPProviderT, ToolConfig
8
9
  from data_designer.config.models import ModelConfig
9
10
  from data_designer.config.run_config import RunConfig
10
11
  from data_designer.config.seed_source import SeedSource
11
12
  from data_designer.config.utils.type_helpers import StrEnum
12
13
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
13
- from data_designer.engine.model_provider import ModelProviderRegistry
14
+ from data_designer.engine.mcp.factory import create_mcp_registry
15
+ from data_designer.engine.mcp.registry import MCPRegistry
16
+ from data_designer.engine.model_provider import (
17
+ ModelProviderRegistry,
18
+ resolve_mcp_provider_registry,
19
+ )
14
20
  from data_designer.engine.models.factory import create_model_registry
15
21
  from data_designer.engine.models.registry import ModelRegistry
16
- from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
22
+ from data_designer.engine.resources.managed_storage import ManagedBlobStorage
17
23
  from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
18
24
  from data_designer.engine.secret_resolver import SecretResolver
19
25
 
@@ -28,6 +34,7 @@ class ResourceProvider(ConfigBase):
28
34
  artifact_storage: ArtifactStorage
29
35
  blob_storage: ManagedBlobStorage | None = None
30
36
  model_registry: ModelRegistry | None = None
37
+ mcp_registry: MCPRegistry | None = None
31
38
  run_config: RunConfig = RunConfig()
32
39
  seed_reader: SeedReader | None = None
33
40
 
@@ -43,6 +50,31 @@ class ResourceProvider(ConfigBase):
43
50
  return DatasetMetadata(seed_column_names=seed_column_names)
44
51
 
45
52
 
53
+ def _validate_tool_configs_against_providers(
54
+ tool_configs: list[ToolConfig],
55
+ mcp_providers: list[MCPProviderT],
56
+ ) -> None:
57
+ """Validate that all providers referenced in tool configs exist.
58
+
59
+ Args:
60
+ tool_configs: List of tool configurations to validate.
61
+ mcp_providers: List of available MCP provider configurations.
62
+
63
+ Raises:
64
+ ValueError: If a tool config references a provider that doesn't exist.
65
+ """
66
+ available_providers = {p.name for p in mcp_providers}
67
+
68
+ for tc in tool_configs:
69
+ missing_providers = [p for p in tc.providers if p not in available_providers]
70
+ if missing_providers:
71
+ available_list = sorted(available_providers) if available_providers else ["(none configured)"]
72
+ raise ValueError(
73
+ f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} "
74
+ f"which are not registered. Available providers: {available_list}"
75
+ )
76
+
77
+
46
78
  def create_resource_provider(
47
79
  *,
48
80
  artifact_storage: ArtifactStorage,
@@ -53,9 +85,31 @@ def create_resource_provider(
53
85
  blob_storage: ManagedBlobStorage | None = None,
54
86
  seed_dataset_source: SeedSource | None = None,
55
87
  run_config: RunConfig | None = None,
88
+ mcp_providers: list[MCPProviderT] | None = None,
89
+ tool_configs: list[ToolConfig] | None = None,
56
90
  ) -> ResourceProvider:
57
91
  """Factory function for creating a ResourceProvider instance.
92
+
58
93
  This function triggers lazy loading of heavy dependencies like litellm.
94
+ The creation order is:
95
+ 1. MCPProviderRegistry (can be empty)
96
+ 2. MCPRegistry with tool_configs
97
+ 3. ModelRegistry with mcp_registry
98
+
99
+ Args:
100
+ artifact_storage: Storage for build artifacts.
101
+ model_configs: List of model configurations.
102
+ secret_resolver: Resolver for secrets.
103
+ model_provider_registry: Registry of model providers.
104
+ seed_reader_registry: Registry of seed readers.
105
+ blob_storage: Optional blob storage for large files.
106
+ seed_dataset_source: Optional source for seed datasets.
107
+ run_config: Optional runtime configuration.
108
+ mcp_providers: Optional list of MCP provider configurations.
109
+ tool_configs: Optional list of tool configurations.
110
+
111
+ Returns:
112
+ A configured ResourceProvider instance.
59
113
  """
60
114
  seed_reader = None
61
115
  if seed_dataset_source:
@@ -64,14 +118,29 @@ def create_resource_provider(
64
118
  secret_resolver,
65
119
  )
66
120
 
121
+ # Create MCPProviderRegistry first (can be empty)
122
+ mcp_provider_registry = resolve_mcp_provider_registry(mcp_providers)
123
+
124
+ # Create MCPRegistry with tool configs (only if tool_configs provided)
125
+ # Tool validation is performed during dataset builder health checks.
126
+ mcp_registry = None
127
+ if tool_configs:
128
+ mcp_registry = create_mcp_registry(
129
+ tool_configs=tool_configs,
130
+ secret_resolver=secret_resolver,
131
+ mcp_provider_registry=mcp_provider_registry,
132
+ )
133
+
67
134
  return ResourceProvider(
68
135
  artifact_storage=artifact_storage,
69
136
  model_registry=create_model_registry(
70
137
  model_configs=model_configs,
71
138
  secret_resolver=secret_resolver,
72
139
  model_provider_registry=model_provider_registry,
140
+ mcp_registry=mcp_registry,
73
141
  ),
74
- blob_storage=blob_storage or init_managed_blob_storage(),
142
+ blob_storage=blob_storage,
143
+ mcp_registry=mcp_registry,
75
144
  seed_reader=seed_reader,
76
145
  run_config=run_config or RunConfig(),
77
146
  )
@@ -0,0 +1,233 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Pytest fixtures for engine testing.
5
+
6
+ Located in src/ so it can be packaged and shared across subpackages via pytest_plugins.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any
12
+ from unittest.mock import MagicMock
13
+
14
+ import pytest
15
+
16
+ from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig
17
+ from data_designer.engine.mcp.facade import MCPFacade
18
+ from data_designer.engine.model_provider import MCPProviderRegistry
19
+ from data_designer.engine.secret_resolver import SecretResolver
20
+ from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader
21
+
22
+ # =============================================================================
23
+ # Fake LLM response classes (used by completion response fixtures)
24
+ # =============================================================================
25
+
26
+
27
+ class _FakeMessage:
28
+ """Fake message class for mocking LLM completion responses."""
29
+
30
+ def __init__(
31
+ self,
32
+ content: str | None,
33
+ tool_calls: list[dict] | None = None,
34
+ reasoning_content: str | None = None,
35
+ ) -> None:
36
+ self.content = content
37
+ self.tool_calls = tool_calls
38
+ self.reasoning_content = reasoning_content
39
+
40
+
41
+ class _FakeChoice:
42
+ """Fake choice class for mocking LLM completion responses."""
43
+
44
+ def __init__(self, message: _FakeMessage) -> None:
45
+ self.message = message
46
+
47
+
48
+ class _FakeResponse:
49
+ """Fake response class for mocking LLM completion responses."""
50
+
51
+ def __init__(self, message: _FakeMessage) -> None:
52
+ self.choices = [_FakeChoice(message)]
53
+
54
+
55
+ # =============================================================================
56
+ # Seed reader fixtures
57
+ # =============================================================================
58
+
59
+
60
@pytest.fixture
def stub_seed_reader() -> StubHuggingFaceSeedReader:
    """Provide a fresh StubHuggingFaceSeedReader for seed-dataset tests."""
    reader = StubHuggingFaceSeedReader()
    return reader
64
+
65
+
66
+ # =============================================================================
67
+ # MCP Provider fixtures
68
+ # =============================================================================
69
+
70
+
71
@pytest.fixture
def stub_mcp_provider_registry() -> MCPProviderRegistry:
    """MCP provider registry with two stdio providers, ``tools`` and ``secondary`` (command ``python``)."""
    provider_names = ("tools", "secondary")
    stdio_providers = [LocalStdioMCPProvider(name=provider_name, command="python") for provider_name in provider_names]
    return MCPProviderRegistry(providers=stdio_providers)
80
+
81
+
82
@pytest.fixture
def stub_mcp_provider_registry_single() -> MCPProviderRegistry:
    """MCP provider registry holding a single stdio provider named ``tools``."""
    only_provider = LocalStdioMCPProvider(name="tools", command="python")
    return MCPProviderRegistry(providers=[only_provider])
86
+
87
+
88
@pytest.fixture
def stub_secret_resolver() -> MagicMock:
    """Mock SecretResolver whose ``resolve`` echoes back whatever it is given."""
    mock_resolver = MagicMock(spec=SecretResolver)
    # Pass-through: a "secret" resolves to its own reference string.
    mock_resolver.resolve.side_effect = lambda value: value
    return mock_resolver
94
+
95
+
96
@pytest.fixture
def stub_stdio_provider() -> LocalStdioMCPProvider:
    """Stdio MCP provider ``test-stdio`` configured with command ``python``, args ``-m test_server``
    and env ``TEST_VAR=value``."""
    launch_args = ["-m", "test_server"]
    extra_env = {"TEST_VAR": "value"}
    return LocalStdioMCPProvider(name="test-stdio", command="python", args=launch_args, env=extra_env)
105
+
106
+
107
@pytest.fixture
def stub_sse_provider() -> MCPProvider:
    """Remote MCP provider ``test-sse`` pointing at a local SSE endpoint with a dummy API key."""
    endpoint_url = "http://localhost:8080/sse"
    return MCPProvider(name="test-sse", endpoint=endpoint_url, api_key="test-key")
115
+
116
+
117
+ # =============================================================================
118
+ # Tool config fixtures
119
+ # =============================================================================
120
+
121
+
122
@pytest.fixture
def stub_tool_config() -> ToolConfig:
    """ToolConfig ``test-tools`` using provider ``tools`` with ``max_tool_call_turns=3`` and ``timeout_sec=30.0``."""
    provider_names = ["tools"]
    return ToolConfig(tool_alias="test-tools", providers=provider_names, max_tool_call_turns=3, timeout_sec=30.0)
131
+
132
+
133
@pytest.fixture
def stub_tool_config_with_allow_list() -> ToolConfig:
    """ToolConfig ``test-tools`` on provider ``tools`` restricted to the ``lookup`` and ``search`` tools."""
    permitted_tools = ["lookup", "search"]
    return ToolConfig(
        tool_alias="test-tools",
        providers=["tools"],
        allow_tools=permitted_tools,
        max_tool_call_turns=3,
    )
142
+
143
+
144
+ # =============================================================================
145
+ # Facade fixtures
146
+ # =============================================================================
147
+
148
+
149
@pytest.fixture
def stub_mcp_facade(
    stub_tool_config: ToolConfig, stub_secret_resolver: MagicMock, stub_mcp_provider_registry: MCPProviderRegistry
) -> MCPFacade:
    """MCPFacade wired to the stub tool config, the pass-through secret resolver and the two-provider registry."""
    facade_kwargs = dict(
        tool_config=stub_tool_config,
        secret_resolver=stub_secret_resolver,
        mcp_provider_registry=stub_mcp_provider_registry,
    )
    return MCPFacade(**facade_kwargs)
159
+
160
+
161
@pytest.fixture
def stub_mcp_facade_factory() -> Any:
    """Provide a callable that builds an MCPFacade from explicit dependencies.

    The callable takes ``(tool_config, secret_resolver, provider_registry)`` and
    forwards them to ``MCPFacade``. ``provider_registry`` is passed as
    ``mcp_provider_registry``.
    """

    def factory(
        tool_config: ToolConfig,
        secret_resolver: SecretResolver,
        provider_registry: MCPProviderRegistry,
    ) -> MCPFacade:
        facade = MCPFacade(
            tool_config=tool_config,
            secret_resolver=secret_resolver,
            mcp_provider_registry=provider_registry,
        )
        return facade

    return factory
173
+
174
+
175
+ # =============================================================================
176
+ # Completion response fixtures
177
+ # =============================================================================
178
+
179
+
180
@pytest.fixture
def mock_completion_response_no_tools() -> _FakeResponse:
    """Fake completion whose single message has plain text content and no tool calls."""
    message = _FakeMessage(content="Hello, I can help with that.")
    return _FakeResponse(message)
184
+
185
+
186
@pytest.fixture
def mock_completion_response_single_tool() -> _FakeResponse:
    """Fake completion containing one ``lookup`` function call (id ``call-1``) plus text content."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "test"}'}}
    message = _FakeMessage(content="Let me look that up.", tool_calls=[lookup_call])
    return _FakeResponse(message)
195
+
196
+
197
@pytest.fixture
def mock_completion_response_parallel_tools() -> _FakeResponse:
    """Fake completion issuing three function calls in one turn: ``lookup``, ``search`` and ``fetch``."""
    call_specs = [
        ("call-1", "lookup", '{"query": "first"}'),
        ("call-2", "search", '{"term": "second"}'),
        ("call-3", "fetch", '{"url": "example.com"}'),
    ]
    tool_calls = [
        {"id": call_id, "type": "function", "function": {"name": fn_name, "arguments": fn_args}}
        for call_id, fn_name, fn_args in call_specs
    ]
    message = _FakeMessage(content="Executing multiple tools.", tool_calls=tool_calls)
    return _FakeResponse(message)
206
+
207
+
208
@pytest.fixture
def mock_completion_response_with_reasoning() -> _FakeResponse:
    """Fake completion whose content and reasoning_content both carry leading and trailing whitespace."""
    answer = " Final answer with extra spaces. "
    reasoning = " Thinking about the problem... "
    return _FakeResponse(_FakeMessage(content=answer, reasoning_content=reasoning))
217
+
218
+
219
@pytest.fixture
def mock_completion_response_tool_with_reasoning() -> _FakeResponse:
    """Fake completion with one ``lookup`` tool call plus whitespace-padded content and reasoning_content."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "test"}'}}
    message = _FakeMessage(
        content=" Looking it up... ",
        tool_calls=[lookup_call],
        reasoning_content=" I should use the lookup tool. ",
    )
    return _FakeResponse(message)