data-designer-engine 0.4.0rc2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. data_designer/engine/analysis/column_profilers/base.py +1 -2
  2. data_designer/engine/analysis/dataset_profiler.py +1 -2
  3. data_designer/engine/column_generators/generators/base.py +1 -6
  4. data_designer/engine/column_generators/generators/custom.py +195 -0
  5. data_designer/engine/column_generators/generators/llm_completion.py +34 -4
  6. data_designer/engine/column_generators/registry.py +3 -0
  7. data_designer/engine/column_generators/utils/errors.py +3 -0
  8. data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
  9. data_designer/engine/dataset_builders/column_wise_builder.py +47 -10
  10. data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
  11. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  12. data_designer/engine/mcp/__init__.py +30 -0
  13. data_designer/engine/mcp/errors.py +22 -0
  14. data_designer/engine/mcp/facade.py +485 -0
  15. data_designer/engine/mcp/factory.py +46 -0
  16. data_designer/engine/mcp/io.py +487 -0
  17. data_designer/engine/mcp/registry.py +203 -0
  18. data_designer/engine/model_provider.py +68 -0
  19. data_designer/engine/models/facade.py +92 -30
  20. data_designer/engine/models/factory.py +18 -1
  21. data_designer/engine/models/utils.py +111 -21
  22. data_designer/engine/resources/resource_provider.py +72 -3
  23. data_designer/engine/testing/fixtures.py +233 -0
  24. data_designer/engine/testing/stubs.py +1 -2
  25. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
  26. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +27 -19
  27. data_designer/engine/_version.py +0 -34
  28. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
@@ -8,6 +8,7 @@ from functools import cached_property
8
8
  from pydantic import BaseModel, field_validator, model_validator
9
9
  from typing_extensions import Self
10
10
 
11
+ from data_designer.config.mcp import MCPProviderT
11
12
  from data_designer.config.models import ModelProvider
12
13
  from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
13
14
 
@@ -75,3 +76,70 @@ def resolve_model_provider_registry(
75
76
  providers=model_providers,
76
77
  default=default_provider_name or model_providers[0].name,
77
78
  )
79
+
80
+
81
class MCPProviderRegistry(BaseModel):
    """Registry of configured MCP providers, keyed by unique provider name.

    Unlike ModelProviderRegistry, this registry may be empty: MCP providers
    are optional, and users only register them when MCP tools are needed
    for generation.

    Attributes:
        providers: Registered MCP providers (both MCPProvider and
            LocalStdioMCPProvider).
    """

    providers: list[MCPProviderT] = []

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_have_unique_names(cls, v: list[MCPProviderT]) -> list[MCPProviderT]:
        # Collect every duplicated name so the error message lists them all.
        seen: set[str] = set()
        dupes: set[str] = set()
        for provider in v:
            if provider.name in seen:
                dupes.add(provider.name)
            else:
                seen.add(provider.name)

        if dupes:
            raise ValueError(f"MCP providers must have unique names, found duplicates: {dupes}")
        return v

    @cached_property
    def _providers_dict(self) -> dict[str, MCPProviderT]:
        # Built once lazily; provider list is fixed after validation.
        return dict((provider.name, provider) for provider in self.providers)

    def get_provider(self, name: str) -> MCPProviderT:
        """Look up a registered MCP provider by its unique name.

        Args:
            name: The name of the MCP provider.

        Returns:
            The MCP provider registered under the given name.

        Raises:
            UnknownProviderError: If no provider with the given name is registered.
        """
        try:
            return self._providers_dict[name]
        except KeyError:
            raise UnknownProviderError(f"No MCP provider named {name!r} registered")

    def is_empty(self) -> bool:
        """Return True when no MCP providers are registered."""
        return not self.providers
132
+
133
+
134
def resolve_mcp_provider_registry(
    mcp_providers: list[MCPProviderT] | None = None,
) -> MCPProviderRegistry:
    """Build an MCPProviderRegistry from an optional list of MCP providers.

    Args:
        mcp_providers: MCP providers to register. When None or empty, an
            empty registry is returned, since MCP support is optional.

    Returns:
        An MCPProviderRegistry wrapping the provided MCP providers.
    """
    providers = mcp_providers or []
    return MCPProviderRegistry(providers=providers)
@@ -9,6 +9,7 @@ from copy import deepcopy
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
11
  from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
12
+ from data_designer.engine.mcp.errors import MCPConfigurationError
12
13
  from data_designer.engine.model_provider import ModelProviderRegistry
13
14
  from data_designer.engine.models.errors import (
14
15
  GenerationValidationFailureError,
@@ -18,13 +19,22 @@ from data_designer.engine.models.errors import (
18
19
  from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
19
20
  from data_designer.engine.models.parsers.errors import ParserException
20
21
  from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
21
- from data_designer.engine.models.utils import prompt_to_messages, str_to_message
22
+ from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
22
23
  from data_designer.engine.secret_resolver import SecretResolver
23
24
  from data_designer.lazy_heavy_imports import litellm
24
25
 
25
26
  if TYPE_CHECKING:
26
27
  import litellm
27
28
 
29
+ from data_designer.engine.mcp.facade import MCPFacade
30
+ from data_designer.engine.mcp.registry import MCPRegistry
31
+
32
+
33
+ def _identity(x: Any) -> Any:
34
+ """Identity function for default parser. Module-level for pickling compatibility."""
35
+ return x
36
+
37
+
28
38
  logger = logging.getLogger(__name__)
29
39
 
30
40
 
@@ -34,10 +44,13 @@ class ModelFacade:
34
44
  model_config: ModelConfig,
35
45
  secret_resolver: SecretResolver,
36
46
  model_provider_registry: ModelProviderRegistry,
37
- ):
47
+ *,
48
+ mcp_registry: MCPRegistry | None = None,
49
+ ) -> None:
38
50
  self._model_config = model_config
39
51
  self._secret_resolver = secret_resolver
40
52
  self._model_provider_registry = model_provider_registry
53
+ self._mcp_registry = mcp_registry
41
54
  self._litellm_deployment = self._get_litellm_deployment(model_config)
42
55
  self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
43
56
  self._usage_stats = ModelUsageStats()
@@ -67,16 +80,17 @@ class ModelFacade:
67
80
  return self._usage_stats
68
81
 
69
82
  def completion(
70
- self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
83
+ self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
71
84
  ) -> litellm.ModelResponse:
85
+ message_payloads = [message.to_dict() for message in messages]
72
86
  logger.debug(
73
87
  f"Prompting model {self.model_name!r}...",
74
- extra={"model": self.model_name, "messages": messages},
88
+ extra={"model": self.model_name, "messages": message_payloads},
75
89
  )
76
90
  response = None
77
91
  kwargs = self.consolidate_kwargs(**kwargs)
78
92
  try:
79
- response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
93
+ response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
80
94
  logger.debug(
81
95
  f"Received completion from model {self.model_name!r}",
82
96
  extra={
@@ -103,6 +117,17 @@ class ModelFacade:
103
117
  kwargs["extra_headers"] = self.model_provider.extra_headers
104
118
  return kwargs
105
119
 
120
+ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
121
+ if tool_alias is None:
122
+ return None
123
+ if self._mcp_registry is None:
124
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
125
+
126
+ try:
127
+ return self._mcp_registry.get_mcp(tool_alias=tool_alias)
128
+ except ValueError as exc:
129
+ raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
130
+
106
131
  @catch_llm_exceptions
107
132
  def generate_text_embeddings(
108
133
  self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
@@ -141,15 +166,16 @@ class ModelFacade:
141
166
  self,
142
167
  prompt: str,
143
168
  *,
144
- parser: Callable[[str], Any],
169
+ parser: Callable[[str], Any] = _identity,
145
170
  system_prompt: str | None = None,
146
171
  multi_modal_context: list[dict[str, Any]] | None = None,
172
+ tool_alias: str | None = None,
147
173
  max_correction_steps: int = 0,
148
174
  max_conversation_restarts: int = 0,
149
175
  skip_usage_tracking: bool = False,
150
176
  purpose: str | None = None,
151
177
  **kwargs,
152
- ) -> tuple[Any, str | None]:
178
+ ) -> tuple[Any, list[ChatMessage]]:
153
179
  """Generate a parsed output with correction steps.
154
180
 
155
181
  This generation call will attempt to generate an output which is
@@ -169,7 +195,10 @@ class ModelFacade:
169
195
  no system message is provided and the model should use its default system
170
196
  prompt.
171
197
  parser (func(str) -> Any): A function applied to the LLM response which processes
172
- an LLM response into some output object.
198
+ an LLM response into some output object. Default: identity function.
199
+ tool_alias (str | None): Optional tool configuration alias. When provided,
200
+ the model may call permitted tools from the configured MCP providers.
201
+ The alias must reference a ToolConfig registered in the MCPRegistry.
173
202
  max_correction_steps (int): Maximum number of correction rounds permitted
174
203
  within a single conversation. Note, many rounds can lead to increasing
175
204
  context size without necessarily improving performance -- small language
@@ -182,37 +211,67 @@ class ModelFacade:
182
211
  It is expected to be used by the @catch_llm_exceptions decorator.
183
212
  **kwargs: Additional arguments to pass to the model.
184
213
 
214
+ Returns:
215
+ A tuple containing:
216
+ - The parsed output object from the parser.
217
+ - The full trace of ChatMessage entries in the conversation, including any tool calls,
218
+ corrections, and reasoning traces. Callers can decide whether to store this.
219
+
185
220
  Raises:
186
221
  GenerationValidationFailureError: If the maximum number of retries or
187
222
  correction steps are met and the last response failures on
188
223
  generation validation.
224
+ MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured.
189
225
  """
190
226
  output_obj = None
227
+ tool_schemas = None
228
+ tool_call_turns = 0
191
229
  curr_num_correction_steps = 0
192
230
  curr_num_restarts = 0
193
- curr_generation_attempt = 0
194
- max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
195
231
 
196
- starting_messages = prompt_to_messages(
232
+ mcp_facade = self._get_mcp_facade(tool_alias)
233
+
234
+ # Checkpoint for restarts - updated after tool calls so we don't repeat them
235
+ restart_checkpoint = prompt_to_messages(
197
236
  user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
198
237
  )
199
- messages = deepcopy(starting_messages)
238
+ checkpoint_tool_call_turns = 0
239
+ messages: list[ChatMessage] = deepcopy(restart_checkpoint)
240
+
241
+ if mcp_facade is not None:
242
+ tool_schemas = mcp_facade.get_tool_schemas()
200
243
 
201
244
  while True:
202
- curr_generation_attempt += 1
203
- logger.debug(
204
- f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
245
+ completion_kwargs = dict(kwargs)
246
+ if tool_schemas is not None:
247
+ completion_kwargs["tools"] = tool_schemas
248
+
249
+ completion_response = self.completion(
250
+ messages,
251
+ skip_usage_tracking=skip_usage_tracking,
252
+ **completion_kwargs,
205
253
  )
206
254
 
207
- completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
208
- response = completion_response.choices[0].message.content or ""
209
- reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
255
+ # Process any tool calls in the response (handles parallel tool calling)
256
+ if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
257
+ tool_call_turns += 1
258
+
259
+ if tool_call_turns > mcp_facade.max_tool_call_turns:
260
+ # Gracefully refuse tool calls when budget is exhausted
261
+ messages.extend(mcp_facade.refuse_completion_response(completion_response))
262
+ else:
263
+ messages.extend(mcp_facade.process_completion_response(completion_response))
264
+
265
+ # Update checkpoint so restarts don't repeat tool calls
266
+ restart_checkpoint = deepcopy(messages)
267
+ checkpoint_tool_call_turns = tool_call_turns
210
268
 
211
- if reasoning_trace:
212
- ## There are generally some extra newlines with how these get parsed.
213
- response = response.strip()
214
- reasoning_trace = reasoning_trace.strip()
269
+ continue # Back to top
215
270
 
271
+ # No tool calls remaining to process
272
+ response = completion_response.choices[0].message.content or ""
273
+ reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
274
+ messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
216
275
  curr_num_correction_steps += 1
217
276
 
218
277
  try:
@@ -223,21 +282,24 @@ class ModelFacade:
223
282
  raise GenerationValidationFailureError(
224
283
  "Unsuccessful generation attempt. No retries were attempted."
225
284
  ) from exc
285
+
226
286
  if curr_num_correction_steps <= max_correction_steps:
227
- ## Add turns to loop-back errors for correction
228
- messages += [
229
- str_to_message(content=response, role="assistant"),
230
- str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
231
- ]
287
+ # Add user message with error for correction
288
+ messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
289
+
232
290
  elif curr_num_restarts < max_conversation_restarts:
233
291
  curr_num_correction_steps = 0
234
292
  curr_num_restarts += 1
235
- messages = deepcopy(starting_messages)
293
+ messages = deepcopy(restart_checkpoint)
294
+ tool_call_turns = checkpoint_tool_call_turns
295
+
236
296
  else:
237
297
  raise GenerationValidationFailureError(
238
- f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
298
+ f"Unsuccessful generation despite {max_correction_steps} correction steps "
299
+ f"and {max_conversation_restarts} conversation restarts."
239
300
  ) from exc
240
- return output_obj, reasoning_trace
301
+
302
+ return output_obj, messages
241
303
 
242
304
  def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
243
305
  provider = self._model_provider_registry.get_provider(model_config.provider)
@@ -10,6 +10,7 @@ from data_designer.engine.model_provider import ModelProviderRegistry
10
10
  from data_designer.engine.secret_resolver import SecretResolver
11
11
 
12
12
  if TYPE_CHECKING:
13
+ from data_designer.engine.mcp.registry import MCPRegistry
13
14
  from data_designer.engine.models.registry import ModelRegistry
14
15
 
15
16
 
@@ -18,12 +19,23 @@ def create_model_registry(
18
19
  model_configs: list[ModelConfig] | None = None,
19
20
  secret_resolver: SecretResolver,
20
21
  model_provider_registry: ModelProviderRegistry,
22
+ mcp_registry: MCPRegistry | None = None,
21
23
  ) -> ModelRegistry:
22
24
  """Factory function for creating a ModelRegistry instance.
23
25
 
24
26
  Heavy dependencies (litellm, httpx) are deferred until this function is called.
25
27
  This is a factory function pattern - imports inside factories are idiomatic Python
26
28
  for lazy initialization.
29
+
30
+ Args:
31
+ model_configs: Optional list of model configurations to register.
32
+ secret_resolver: Resolver for secrets referenced in provider configs.
33
+ model_provider_registry: Registry of model provider configurations.
34
+ mcp_registry: Optional MCP registry for tool operations. When provided,
35
+ ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation.
36
+
37
+ Returns:
38
+ A configured ModelRegistry instance.
27
39
  """
28
40
  from data_designer.engine.models.facade import ModelFacade
29
41
  from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -32,7 +44,12 @@ def create_model_registry(
32
44
  apply_litellm_patches()
33
45
 
34
46
  def model_facade_factory(model_config, secret_resolver, model_provider_registry):
35
- return ModelFacade(model_config, secret_resolver, model_provider_registry)
47
+ return ModelFacade(
48
+ model_config,
49
+ secret_resolver,
50
+ model_provider_registry,
51
+ mcp_registry=mcp_registry,
52
+ )
36
53
 
37
54
  return ModelRegistry(
38
55
  model_configs=model_configs,
@@ -3,7 +3,84 @@
3
3
 
4
4
  from __future__ import annotations
5
5
 
6
- from typing import Any
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Literal
8
+
9
+
10
@dataclass
class ChatMessage:
    """One message in an LLM conversation.

    Covers every message variety exchanged with a model: user prompts,
    assistant replies (optionally carrying reasoning traces and tool calls),
    system instructions, and tool results.

    Attributes:
        role: Sender role -- one of 'user', 'assistant', 'system', or 'tool'.
        content: Message payload; either plain text or a list of content
            blocks for multimodal input (e.g. text plus images).
        reasoning_content: Assistant reasoning/chain-of-thought text, when
            the model exposes it (extended thinking models).
        tool_calls: Tool invocations requested by the assistant; each entry
            carries 'id', 'type', and 'function' keys.
        tool_call_id: ID tying a role='tool' message back to the tool call
            it answers; required for tool messages.
    """

    role: Literal["user", "assistant", "system", "tool"]
    content: str | list[dict[str, Any]] = ""
    reasoning_content: str | None = None
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    tool_call_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the dict shape expected by chat-completion APIs.

        Content is always normalized to a list of ChatML-style blocks so
        traces and API payloads share a single schema; empty optional
        fields are omitted to keep the payload clean.
        """
        payload: dict[str, Any] = {
            "role": self.role,
            "content": _normalize_content_blocks(self.content),
        }
        for key, value in (
            ("reasoning_content", self.reasoning_content),
            ("tool_calls", self.tool_calls),
            ("tool_call_id", self.tool_call_id),
        ):
            if value:
                payload[key] = value
        return payload

    @classmethod
    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
        """Build a user message."""
        return cls(role="user", content=content)

    @classmethod
    def as_assistant(
        cls,
        content: str = "",
        reasoning_content: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> ChatMessage:
        """Build an assistant message."""
        return cls(
            role="assistant",
            content=content,
            reasoning_content=reasoning_content,
            tool_calls=tool_calls or [],
        )

    @classmethod
    def as_system(cls, content: str) -> ChatMessage:
        """Build a system message."""
        return cls(role="system", content=content)

    @classmethod
    def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
        """Build a tool-result message."""
        return cls(role="tool", content=content, tool_call_id=tool_call_id)
7
84
 
8
85
 
9
86
def prompt_to_messages(
    user_prompt: str,
    system_prompt: str | None = None,
    multi_modal_context: list[dict[str, Any]] | None = None,
) -> list[ChatMessage]:
    """Build the opening ChatMessage list for a conversation.

    Args:
        user_prompt: Text of the user's prompt.
        system_prompt: Optional system instructions; when given, a system
            message is prepended.
        multi_modal_context: Optional multimodal content blocks; when given,
            the user content becomes those blocks followed by a text block
            holding the prompt.

    Returns:
        The conversation messages, ordered as [system?, user].
    """
    content: str | list[dict[str, Any]] = user_prompt
    if multi_modal_context:
        content = [*multi_modal_context, {"type": "text", "text": user_prompt}]

    messages = [ChatMessage.as_user(content)]
    if system_prompt:
        messages.insert(0, ChatMessage.as_system(system_prompt))
    return messages
105
+
106
+
107
+ def _normalize_content_blocks(content: Any) -> list[dict[str, Any]]:
108
+ if isinstance(content, list):
109
+ return [_normalize_content_block(block) for block in content]
110
+ if content is None:
111
+ return []
112
+ return [_text_block(content)]
113
+
114
+
115
+ def _normalize_content_block(block: Any) -> dict[str, Any]:
116
+ if isinstance(block, dict) and "type" in block:
117
+ return block
118
+ return _text_block(block)
119
+
120
+
121
+ def _text_block(value: Any) -> dict[str, Any]:
122
+ if value is None:
123
+ text_value = ""
124
+ elif isinstance(value, str):
125
+ text_value = value
126
+ else:
127
+ text_value = str(value)
128
+ return {"type": "text", "text": text_value}
@@ -5,15 +5,21 @@ from __future__ import annotations
5
5
 
6
6
  from data_designer.config.base import ConfigBase
7
7
  from data_designer.config.dataset_metadata import DatasetMetadata
8
+ from data_designer.config.mcp import MCPProviderT, ToolConfig
8
9
  from data_designer.config.models import ModelConfig
9
10
  from data_designer.config.run_config import RunConfig
10
11
  from data_designer.config.seed_source import SeedSource
11
12
  from data_designer.config.utils.type_helpers import StrEnum
12
13
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
13
- from data_designer.engine.model_provider import ModelProviderRegistry
14
+ from data_designer.engine.mcp.factory import create_mcp_registry
15
+ from data_designer.engine.mcp.registry import MCPRegistry
16
+ from data_designer.engine.model_provider import (
17
+ ModelProviderRegistry,
18
+ resolve_mcp_provider_registry,
19
+ )
14
20
  from data_designer.engine.models.factory import create_model_registry
15
21
  from data_designer.engine.models.registry import ModelRegistry
16
- from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
22
+ from data_designer.engine.resources.managed_storage import ManagedBlobStorage
17
23
  from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
18
24
  from data_designer.engine.secret_resolver import SecretResolver
19
25
 
@@ -28,6 +34,7 @@ class ResourceProvider(ConfigBase):
28
34
  artifact_storage: ArtifactStorage
29
35
  blob_storage: ManagedBlobStorage | None = None
30
36
  model_registry: ModelRegistry | None = None
37
+ mcp_registry: MCPRegistry | None = None
31
38
  run_config: RunConfig = RunConfig()
32
39
  seed_reader: SeedReader | None = None
33
40
 
@@ -43,6 +50,31 @@ class ResourceProvider(ConfigBase):
43
50
  return DatasetMetadata(seed_column_names=seed_column_names)
44
51
 
45
52
 
53
+ def _validate_tool_configs_against_providers(
54
+ tool_configs: list[ToolConfig],
55
+ mcp_providers: list[MCPProviderT],
56
+ ) -> None:
57
+ """Validate that all providers referenced in tool configs exist.
58
+
59
+ Args:
60
+ tool_configs: List of tool configurations to validate.
61
+ mcp_providers: List of available MCP provider configurations.
62
+
63
+ Raises:
64
+ ValueError: If a tool config references a provider that doesn't exist.
65
+ """
66
+ available_providers = {p.name for p in mcp_providers}
67
+
68
+ for tc in tool_configs:
69
+ missing_providers = [p for p in tc.providers if p not in available_providers]
70
+ if missing_providers:
71
+ available_list = sorted(available_providers) if available_providers else ["(none configured)"]
72
+ raise ValueError(
73
+ f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} "
74
+ f"which are not registered. Available providers: {available_list}"
75
+ )
76
+
77
+
46
78
  def create_resource_provider(
47
79
  *,
48
80
  artifact_storage: ArtifactStorage,
@@ -53,9 +85,31 @@ def create_resource_provider(
53
85
  blob_storage: ManagedBlobStorage | None = None,
54
86
  seed_dataset_source: SeedSource | None = None,
55
87
  run_config: RunConfig | None = None,
88
+ mcp_providers: list[MCPProviderT] | None = None,
89
+ tool_configs: list[ToolConfig] | None = None,
56
90
  ) -> ResourceProvider:
57
91
  """Factory function for creating a ResourceProvider instance.
92
+
58
93
  This function triggers lazy loading of heavy dependencies like litellm.
94
+ The creation order is:
95
+ 1. MCPProviderRegistry (can be empty)
96
+ 2. MCPRegistry with tool_configs
97
+ 3. ModelRegistry with mcp_registry
98
+
99
+ Args:
100
+ artifact_storage: Storage for build artifacts.
101
+ model_configs: List of model configurations.
102
+ secret_resolver: Resolver for secrets.
103
+ model_provider_registry: Registry of model providers.
104
+ seed_reader_registry: Registry of seed readers.
105
+ blob_storage: Optional blob storage for large files.
106
+ seed_dataset_source: Optional source for seed datasets.
107
+ run_config: Optional runtime configuration.
108
+ mcp_providers: Optional list of MCP provider configurations.
109
+ tool_configs: Optional list of tool configurations.
110
+
111
+ Returns:
112
+ A configured ResourceProvider instance.
59
113
  """
60
114
  seed_reader = None
61
115
  if seed_dataset_source:
@@ -64,14 +118,29 @@ def create_resource_provider(
64
118
  secret_resolver,
65
119
  )
66
120
 
121
+ # Create MCPProviderRegistry first (can be empty)
122
+ mcp_provider_registry = resolve_mcp_provider_registry(mcp_providers)
123
+
124
+ # Create MCPRegistry with tool configs (only if tool_configs provided)
125
+ # Tool validation is performed during dataset builder health checks.
126
+ mcp_registry = None
127
+ if tool_configs:
128
+ mcp_registry = create_mcp_registry(
129
+ tool_configs=tool_configs,
130
+ secret_resolver=secret_resolver,
131
+ mcp_provider_registry=mcp_provider_registry,
132
+ )
133
+
67
134
  return ResourceProvider(
68
135
  artifact_storage=artifact_storage,
69
136
  model_registry=create_model_registry(
70
137
  model_configs=model_configs,
71
138
  secret_resolver=secret_resolver,
72
139
  model_provider_registry=model_provider_registry,
140
+ mcp_registry=mcp_registry,
73
141
  ),
74
- blob_storage=blob_storage or init_managed_blob_storage(),
142
+ blob_storage=blob_storage,
143
+ mcp_registry=mcp_registry,
75
144
  seed_reader=seed_reader,
76
145
  run_config=run_config or RunConfig(),
77
146
  )