data-designer-engine 0.4.0rc2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/analysis/column_profilers/base.py +1 -2
- data_designer/engine/analysis/dataset_profiler.py +1 -2
- data_designer/engine/column_generators/generators/base.py +1 -6
- data_designer/engine/column_generators/generators/custom.py +195 -0
- data_designer/engine/column_generators/generators/llm_completion.py +34 -4
- data_designer/engine/column_generators/registry.py +3 -0
- data_designer/engine/column_generators/utils/errors.py +3 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +47 -10
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
- data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
- data_designer/engine/mcp/__init__.py +30 -0
- data_designer/engine/mcp/errors.py +22 -0
- data_designer/engine/mcp/facade.py +485 -0
- data_designer/engine/mcp/factory.py +46 -0
- data_designer/engine/mcp/io.py +487 -0
- data_designer/engine/mcp/registry.py +203 -0
- data_designer/engine/model_provider.py +68 -0
- data_designer/engine/models/facade.py +92 -30
- data_designer/engine/models/factory.py +18 -1
- data_designer/engine/models/utils.py +111 -21
- data_designer/engine/resources/resource_provider.py +72 -3
- data_designer/engine/testing/fixtures.py +233 -0
- data_designer/engine/testing/stubs.py +1 -2
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +27 -19
- data_designer/engine/_version.py +0 -34
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -8,6 +8,7 @@ from functools import cached_property
|
|
|
8
8
|
from pydantic import BaseModel, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
+
from data_designer.config.mcp import MCPProviderT
|
|
11
12
|
from data_designer.config.models import ModelProvider
|
|
12
13
|
from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
|
|
13
14
|
|
|
@@ -75,3 +76,70 @@ def resolve_model_provider_registry(
|
|
|
75
76
|
providers=model_providers,
|
|
76
77
|
default=default_provider_name or model_providers[0].name,
|
|
77
78
|
)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class MCPProviderRegistry(BaseModel):
    """Collection of MCP provider configurations, addressable by name.

    MCP providers are optional, so unlike ModelProviderRegistry this registry is
    allowed to hold zero providers. Providers only need to be registered when MCP
    tools are used for generation.

    Attributes:
        providers: Registered MCP providers (MCPProvider and LocalStdioMCPProvider
            instances). Every provider name must be unique.
    """

    providers: list[MCPProviderT] = []

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_have_unique_names(cls, v: list[MCPProviderT]) -> list[MCPProviderT]:
        """Reject provider lists in which any name appears more than once.

        Raises:
            ValueError: Listing every name that occurs more than once.
        """
        seen_names: set[str] = set()
        repeated_names: set[str] = set()
        for provider in v:
            name = provider.name
            if name in seen_names:
                repeated_names.add(name)
            else:
                seen_names.add(name)

        if repeated_names:
            raise ValueError(f"MCP providers must have unique names, found duplicates: {repeated_names}")
        return v

    @cached_property
    def _providers_dict(self) -> dict[str, MCPProviderT]:
        # Name -> provider lookup, built on first access and cached on the instance.
        # NOTE(review): nothing invalidates this cache if `providers` is reassigned later.
        lookup: dict[str, MCPProviderT] = {}
        for provider in self.providers:
            lookup[provider.name] = provider
        return lookup

    def get_provider(self, name: str) -> MCPProviderT:
        """Look up a registered MCP provider by its name.

        Args:
            name: Name of the MCP provider to fetch.

        Returns:
            The provider registered under ``name``.

        Raises:
            UnknownProviderError: When no provider with ``name`` is registered.
        """
        try:
            provider = self._providers_dict[name]
        except KeyError:
            raise UnknownProviderError(f"No MCP provider named {name!r} registered")
        return provider

    def is_empty(self) -> bool:
        """Return True when no MCP providers are registered."""
        return not self.providers
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def resolve_mcp_provider_registry(
    mcp_providers: list[MCPProviderT] | None = None,
) -> MCPProviderRegistry:
    """Build an MCPProviderRegistry from an optional list of MCP providers.

    Args:
        mcp_providers: MCP provider configurations to register. ``None`` or an
            empty list produces an empty registry.

    Returns:
        An MCPProviderRegistry wrapping the given providers.
    """
    providers = mcp_providers if mcp_providers else []
    return MCPProviderRegistry(providers=providers)
|
|
@@ -9,6 +9,7 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import TYPE_CHECKING, Any
|
|
10
10
|
|
|
11
11
|
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
12
|
+
from data_designer.engine.mcp.errors import MCPConfigurationError
|
|
12
13
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
13
14
|
from data_designer.engine.models.errors import (
|
|
14
15
|
GenerationValidationFailureError,
|
|
@@ -18,13 +19,22 @@ from data_designer.engine.models.errors import (
|
|
|
18
19
|
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
|
|
19
20
|
from data_designer.engine.models.parsers.errors import ParserException
|
|
20
21
|
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
21
|
-
from data_designer.engine.models.utils import
|
|
22
|
+
from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
|
|
22
23
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
23
24
|
from data_designer.lazy_heavy_imports import litellm
|
|
24
25
|
|
|
25
26
|
if TYPE_CHECKING:
|
|
26
27
|
import litellm
|
|
27
28
|
|
|
29
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
30
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _identity(x: Any) -> Any:
|
|
34
|
+
"""Identity function for default parser. Module-level for pickling compatibility."""
|
|
35
|
+
return x
|
|
36
|
+
|
|
37
|
+
|
|
28
38
|
logger = logging.getLogger(__name__)
|
|
29
39
|
|
|
30
40
|
|
|
@@ -34,10 +44,13 @@ class ModelFacade:
|
|
|
34
44
|
model_config: ModelConfig,
|
|
35
45
|
secret_resolver: SecretResolver,
|
|
36
46
|
model_provider_registry: ModelProviderRegistry,
|
|
37
|
-
|
|
47
|
+
*,
|
|
48
|
+
mcp_registry: MCPRegistry | None = None,
|
|
49
|
+
) -> None:
|
|
38
50
|
self._model_config = model_config
|
|
39
51
|
self._secret_resolver = secret_resolver
|
|
40
52
|
self._model_provider_registry = model_provider_registry
|
|
53
|
+
self._mcp_registry = mcp_registry
|
|
41
54
|
self._litellm_deployment = self._get_litellm_deployment(model_config)
|
|
42
55
|
self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
|
|
43
56
|
self._usage_stats = ModelUsageStats()
|
|
@@ -67,16 +80,17 @@ class ModelFacade:
|
|
|
67
80
|
return self._usage_stats
|
|
68
81
|
|
|
69
82
|
def completion(
|
|
70
|
-
self, messages: list[
|
|
83
|
+
self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
|
|
71
84
|
) -> litellm.ModelResponse:
|
|
85
|
+
message_payloads = [message.to_dict() for message in messages]
|
|
72
86
|
logger.debug(
|
|
73
87
|
f"Prompting model {self.model_name!r}...",
|
|
74
|
-
extra={"model": self.model_name, "messages":
|
|
88
|
+
extra={"model": self.model_name, "messages": message_payloads},
|
|
75
89
|
)
|
|
76
90
|
response = None
|
|
77
91
|
kwargs = self.consolidate_kwargs(**kwargs)
|
|
78
92
|
try:
|
|
79
|
-
response = self._router.completion(model=self.model_name, messages=
|
|
93
|
+
response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
|
|
80
94
|
logger.debug(
|
|
81
95
|
f"Received completion from model {self.model_name!r}",
|
|
82
96
|
extra={
|
|
@@ -103,6 +117,17 @@ class ModelFacade:
|
|
|
103
117
|
kwargs["extra_headers"] = self.model_provider.extra_headers
|
|
104
118
|
return kwargs
|
|
105
119
|
|
|
120
|
+
def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
|
|
121
|
+
if tool_alias is None:
|
|
122
|
+
return None
|
|
123
|
+
if self._mcp_registry is None:
|
|
124
|
+
raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
|
|
125
|
+
|
|
126
|
+
try:
|
|
127
|
+
return self._mcp_registry.get_mcp(tool_alias=tool_alias)
|
|
128
|
+
except ValueError as exc:
|
|
129
|
+
raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
|
|
130
|
+
|
|
106
131
|
@catch_llm_exceptions
|
|
107
132
|
def generate_text_embeddings(
|
|
108
133
|
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
@@ -141,15 +166,16 @@ class ModelFacade:
|
|
|
141
166
|
self,
|
|
142
167
|
prompt: str,
|
|
143
168
|
*,
|
|
144
|
-
parser: Callable[[str], Any],
|
|
169
|
+
parser: Callable[[str], Any] = _identity,
|
|
145
170
|
system_prompt: str | None = None,
|
|
146
171
|
multi_modal_context: list[dict[str, Any]] | None = None,
|
|
172
|
+
tool_alias: str | None = None,
|
|
147
173
|
max_correction_steps: int = 0,
|
|
148
174
|
max_conversation_restarts: int = 0,
|
|
149
175
|
skip_usage_tracking: bool = False,
|
|
150
176
|
purpose: str | None = None,
|
|
151
177
|
**kwargs,
|
|
152
|
-
) -> tuple[Any,
|
|
178
|
+
) -> tuple[Any, list[ChatMessage]]:
|
|
153
179
|
"""Generate a parsed output with correction steps.
|
|
154
180
|
|
|
155
181
|
This generation call will attempt to generate an output which is
|
|
@@ -169,7 +195,10 @@ class ModelFacade:
|
|
|
169
195
|
no system message is provided and the model should use its default system
|
|
170
196
|
prompt.
|
|
171
197
|
parser (func(str) -> Any): A function applied to the LLM response which processes
|
|
172
|
-
an LLM response into some output object.
|
|
198
|
+
an LLM response into some output object. Default: identity function.
|
|
199
|
+
tool_alias (str | None): Optional tool configuration alias. When provided,
|
|
200
|
+
the model may call permitted tools from the configured MCP providers.
|
|
201
|
+
The alias must reference a ToolConfig registered in the MCPRegistry.
|
|
173
202
|
max_correction_steps (int): Maximum number of correction rounds permitted
|
|
174
203
|
within a single conversation. Note, many rounds can lead to increasing
|
|
175
204
|
context size without necessarily improving performance -- small language
|
|
@@ -182,37 +211,67 @@ class ModelFacade:
|
|
|
182
211
|
It is expected to be used by the @catch_llm_exceptions decorator.
|
|
183
212
|
**kwargs: Additional arguments to pass to the model.
|
|
184
213
|
|
|
214
|
+
Returns:
|
|
215
|
+
A tuple containing:
|
|
216
|
+
- The parsed output object from the parser.
|
|
217
|
+
- The full trace of ChatMessage entries in the conversation, including any tool calls,
|
|
218
|
+
corrections, and reasoning traces. Callers can decide whether to store this.
|
|
219
|
+
|
|
185
220
|
Raises:
|
|
186
221
|
GenerationValidationFailureError: If the maximum number of retries or
|
|
187
222
|
correction steps are met and the last response failures on
|
|
188
223
|
generation validation.
|
|
224
|
+
MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured.
|
|
189
225
|
"""
|
|
190
226
|
output_obj = None
|
|
227
|
+
tool_schemas = None
|
|
228
|
+
tool_call_turns = 0
|
|
191
229
|
curr_num_correction_steps = 0
|
|
192
230
|
curr_num_restarts = 0
|
|
193
|
-
curr_generation_attempt = 0
|
|
194
|
-
max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
|
|
195
231
|
|
|
196
|
-
|
|
232
|
+
mcp_facade = self._get_mcp_facade(tool_alias)
|
|
233
|
+
|
|
234
|
+
# Checkpoint for restarts - updated after tool calls so we don't repeat them
|
|
235
|
+
restart_checkpoint = prompt_to_messages(
|
|
197
236
|
user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
|
|
198
237
|
)
|
|
199
|
-
|
|
238
|
+
checkpoint_tool_call_turns = 0
|
|
239
|
+
messages: list[ChatMessage] = deepcopy(restart_checkpoint)
|
|
240
|
+
|
|
241
|
+
if mcp_facade is not None:
|
|
242
|
+
tool_schemas = mcp_facade.get_tool_schemas()
|
|
200
243
|
|
|
201
244
|
while True:
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
245
|
+
completion_kwargs = dict(kwargs)
|
|
246
|
+
if tool_schemas is not None:
|
|
247
|
+
completion_kwargs["tools"] = tool_schemas
|
|
248
|
+
|
|
249
|
+
completion_response = self.completion(
|
|
250
|
+
messages,
|
|
251
|
+
skip_usage_tracking=skip_usage_tracking,
|
|
252
|
+
**completion_kwargs,
|
|
205
253
|
)
|
|
206
254
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
255
|
+
# Process any tool calls in the response (handles parallel tool calling)
|
|
256
|
+
if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
|
|
257
|
+
tool_call_turns += 1
|
|
258
|
+
|
|
259
|
+
if tool_call_turns > mcp_facade.max_tool_call_turns:
|
|
260
|
+
# Gracefully refuse tool calls when budget is exhausted
|
|
261
|
+
messages.extend(mcp_facade.refuse_completion_response(completion_response))
|
|
262
|
+
else:
|
|
263
|
+
messages.extend(mcp_facade.process_completion_response(completion_response))
|
|
264
|
+
|
|
265
|
+
# Update checkpoint so restarts don't repeat tool calls
|
|
266
|
+
restart_checkpoint = deepcopy(messages)
|
|
267
|
+
checkpoint_tool_call_turns = tool_call_turns
|
|
210
268
|
|
|
211
|
-
|
|
212
|
-
## There are generally some extra newlines with how these get parsed.
|
|
213
|
-
response = response.strip()
|
|
214
|
-
reasoning_trace = reasoning_trace.strip()
|
|
269
|
+
continue # Back to top
|
|
215
270
|
|
|
271
|
+
# No tool calls remaining to process
|
|
272
|
+
response = completion_response.choices[0].message.content or ""
|
|
273
|
+
reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
|
|
274
|
+
messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
|
|
216
275
|
curr_num_correction_steps += 1
|
|
217
276
|
|
|
218
277
|
try:
|
|
@@ -223,21 +282,24 @@ class ModelFacade:
|
|
|
223
282
|
raise GenerationValidationFailureError(
|
|
224
283
|
"Unsuccessful generation attempt. No retries were attempted."
|
|
225
284
|
) from exc
|
|
285
|
+
|
|
226
286
|
if curr_num_correction_steps <= max_correction_steps:
|
|
227
|
-
|
|
228
|
-
messages
|
|
229
|
-
|
|
230
|
-
str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
|
|
231
|
-
]
|
|
287
|
+
# Add user message with error for correction
|
|
288
|
+
messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
|
|
289
|
+
|
|
232
290
|
elif curr_num_restarts < max_conversation_restarts:
|
|
233
291
|
curr_num_correction_steps = 0
|
|
234
292
|
curr_num_restarts += 1
|
|
235
|
-
messages = deepcopy(
|
|
293
|
+
messages = deepcopy(restart_checkpoint)
|
|
294
|
+
tool_call_turns = checkpoint_tool_call_turns
|
|
295
|
+
|
|
236
296
|
else:
|
|
237
297
|
raise GenerationValidationFailureError(
|
|
238
|
-
f"Unsuccessful generation
|
|
298
|
+
f"Unsuccessful generation despite {max_correction_steps} correction steps "
|
|
299
|
+
f"and {max_conversation_restarts} conversation restarts."
|
|
239
300
|
) from exc
|
|
240
|
-
|
|
301
|
+
|
|
302
|
+
return output_obj, messages
|
|
241
303
|
|
|
242
304
|
def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
|
|
243
305
|
provider = self._model_provider_registry.get_provider(model_config.provider)
|
|
@@ -10,6 +10,7 @@ from data_designer.engine.model_provider import ModelProviderRegistry
|
|
|
10
10
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
13
14
|
from data_designer.engine.models.registry import ModelRegistry
|
|
14
15
|
|
|
15
16
|
|
|
@@ -18,12 +19,23 @@ def create_model_registry(
|
|
|
18
19
|
model_configs: list[ModelConfig] | None = None,
|
|
19
20
|
secret_resolver: SecretResolver,
|
|
20
21
|
model_provider_registry: ModelProviderRegistry,
|
|
22
|
+
mcp_registry: MCPRegistry | None = None,
|
|
21
23
|
) -> ModelRegistry:
|
|
22
24
|
"""Factory function for creating a ModelRegistry instance.
|
|
23
25
|
|
|
24
26
|
Heavy dependencies (litellm, httpx) are deferred until this function is called.
|
|
25
27
|
This is a factory function pattern - imports inside factories are idiomatic Python
|
|
26
28
|
for lazy initialization.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
model_configs: Optional list of model configurations to register.
|
|
32
|
+
secret_resolver: Resolver for secrets referenced in provider configs.
|
|
33
|
+
model_provider_registry: Registry of model provider configurations.
|
|
34
|
+
mcp_registry: Optional MCP registry for tool operations. When provided,
|
|
35
|
+
ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
A configured ModelRegistry instance.
|
|
27
39
|
"""
|
|
28
40
|
from data_designer.engine.models.facade import ModelFacade
|
|
29
41
|
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
@@ -32,7 +44,12 @@ def create_model_registry(
|
|
|
32
44
|
apply_litellm_patches()
|
|
33
45
|
|
|
34
46
|
def model_facade_factory(model_config, secret_resolver, model_provider_registry):
|
|
35
|
-
return ModelFacade(
|
|
47
|
+
return ModelFacade(
|
|
48
|
+
model_config,
|
|
49
|
+
secret_resolver,
|
|
50
|
+
model_provider_registry,
|
|
51
|
+
mcp_registry=mcp_registry,
|
|
52
|
+
)
|
|
36
53
|
|
|
37
54
|
return ModelRegistry(
|
|
38
55
|
model_configs=model_configs,
|
|
@@ -3,7 +3,84 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class ChatMessage:
    """Single message exchanged with an LLM during a conversation.

    Covers system instructions, user prompts, assistant replies (optionally
    carrying reasoning text and tool-call requests), and tool results.

    Attributes:
        role: Sender of the message: 'user', 'assistant', 'system', or 'tool'.
        content: Plain text, or a list of content-block dicts for multimodal
            messages (e.g. text plus images).
        reasoning_content: Optional reasoning/thinking text attached to an
            assistant message.
        tool_calls: Tool calls requested by the assistant. Each entry contains
            'id', 'type', and 'function' keys.
        tool_call_id: ID of the tool call that a role='tool' message answers.
    """

    role: Literal["user", "assistant", "system", "tool"]
    content: str | list[dict[str, Any]] = ""
    reasoning_content: str | None = None
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    tool_call_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize this message into the dict payload used for API calls.

        ``content`` is always emitted as a list of ChatML-style content blocks,
        so traces and API payloads share one schema. Optional fields are only
        included when they hold a truthy value.

        Returns:
            A dict with 'role' and 'content', plus any non-empty optional fields.
        """
        payload: dict[str, Any] = {
            "role": self.role,
            "content": _normalize_content_blocks(self.content),
        }
        # Insertion order matches the field order: reasoning, tool calls, tool call id.
        optional_fields = (
            ("reasoning_content", self.reasoning_content),
            ("tool_calls", self.tool_calls),
            ("tool_call_id", self.tool_call_id),
        )
        for key, value in optional_fields:
            if value:
                payload[key] = value
        return payload

    @classmethod
    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
        """Build a message with role 'user'."""
        return cls(content=content, role="user")

    @classmethod
    def as_assistant(
        cls,
        content: str = "",
        reasoning_content: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> ChatMessage:
        """Build a message with role 'assistant'; a missing or empty tool_calls becomes a fresh list."""
        calls = tool_calls if tool_calls else []
        return cls(
            role="assistant",
            content=content,
            reasoning_content=reasoning_content,
            tool_calls=calls,
        )

    @classmethod
    def as_system(cls, content: str) -> ChatMessage:
        """Build a message with role 'system'."""
        return cls(content=content, role="system")

    @classmethod
    def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
        """Build a role 'tool' message that answers the tool call ``tool_call_id``."""
        return cls(role="tool", tool_call_id=tool_call_id, content=content)
|
|
7
84
|
|
|
8
85
|
|
|
9
86
|
def prompt_to_messages(
|
|
@@ -11,28 +88,41 @@ def prompt_to_messages(
|
|
|
11
88
|
user_prompt: str,
|
|
12
89
|
system_prompt: str | None = None,
|
|
13
90
|
multi_modal_context: list[dict[str, Any]] | None = None,
|
|
14
|
-
) -> list[
|
|
15
|
-
"""Convert a user and system prompt into
|
|
91
|
+
) -> list[ChatMessage]:
|
|
92
|
+
"""Convert a user and system prompt into ChatMessage list.
|
|
16
93
|
|
|
17
94
|
Args:
|
|
18
95
|
user_prompt (str): A user prompt.
|
|
19
96
|
system_prompt (str, optional): An optional system prompt.
|
|
20
97
|
"""
|
|
21
|
-
user_content = user_prompt
|
|
22
|
-
if multi_modal_context
|
|
23
|
-
user_content = []
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
return (
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
98
|
+
user_content: str | list[dict[str, Any]] = user_prompt
|
|
99
|
+
if multi_modal_context:
|
|
100
|
+
user_content = [*multi_modal_context, {"type": "text", "text": user_prompt}]
|
|
101
|
+
|
|
102
|
+
if system_prompt:
|
|
103
|
+
return [ChatMessage.as_system(system_prompt), ChatMessage.as_user(user_content)]
|
|
104
|
+
return [ChatMessage.as_user(user_content)]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _normalize_content_blocks(content: Any) -> list[dict[str, Any]]:
|
|
108
|
+
if isinstance(content, list):
|
|
109
|
+
return [_normalize_content_block(block) for block in content]
|
|
110
|
+
if content is None:
|
|
111
|
+
return []
|
|
112
|
+
return [_text_block(content)]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _normalize_content_block(block: Any) -> dict[str, Any]:
|
|
116
|
+
if isinstance(block, dict) and "type" in block:
|
|
117
|
+
return block
|
|
118
|
+
return _text_block(block)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _text_block(value: Any) -> dict[str, Any]:
|
|
122
|
+
if value is None:
|
|
123
|
+
text_value = ""
|
|
124
|
+
elif isinstance(value, str):
|
|
125
|
+
text_value = value
|
|
126
|
+
else:
|
|
127
|
+
text_value = str(value)
|
|
128
|
+
return {"type": "text", "text": text_value}
|
|
@@ -5,15 +5,21 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from data_designer.config.base import ConfigBase
|
|
7
7
|
from data_designer.config.dataset_metadata import DatasetMetadata
|
|
8
|
+
from data_designer.config.mcp import MCPProviderT, ToolConfig
|
|
8
9
|
from data_designer.config.models import ModelConfig
|
|
9
10
|
from data_designer.config.run_config import RunConfig
|
|
10
11
|
from data_designer.config.seed_source import SeedSource
|
|
11
12
|
from data_designer.config.utils.type_helpers import StrEnum
|
|
12
13
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
13
|
-
from data_designer.engine.
|
|
14
|
+
from data_designer.engine.mcp.factory import create_mcp_registry
|
|
15
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
16
|
+
from data_designer.engine.model_provider import (
|
|
17
|
+
ModelProviderRegistry,
|
|
18
|
+
resolve_mcp_provider_registry,
|
|
19
|
+
)
|
|
14
20
|
from data_designer.engine.models.factory import create_model_registry
|
|
15
21
|
from data_designer.engine.models.registry import ModelRegistry
|
|
16
|
-
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
|
|
22
|
+
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
|
|
17
23
|
from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
|
|
18
24
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
19
25
|
|
|
@@ -28,6 +34,7 @@ class ResourceProvider(ConfigBase):
|
|
|
28
34
|
artifact_storage: ArtifactStorage
|
|
29
35
|
blob_storage: ManagedBlobStorage | None = None
|
|
30
36
|
model_registry: ModelRegistry | None = None
|
|
37
|
+
mcp_registry: MCPRegistry | None = None
|
|
31
38
|
run_config: RunConfig = RunConfig()
|
|
32
39
|
seed_reader: SeedReader | None = None
|
|
33
40
|
|
|
@@ -43,6 +50,31 @@ class ResourceProvider(ConfigBase):
|
|
|
43
50
|
return DatasetMetadata(seed_column_names=seed_column_names)
|
|
44
51
|
|
|
45
52
|
|
|
53
|
+
def _validate_tool_configs_against_providers(
|
|
54
|
+
tool_configs: list[ToolConfig],
|
|
55
|
+
mcp_providers: list[MCPProviderT],
|
|
56
|
+
) -> None:
|
|
57
|
+
"""Validate that all providers referenced in tool configs exist.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
tool_configs: List of tool configurations to validate.
|
|
61
|
+
mcp_providers: List of available MCP provider configurations.
|
|
62
|
+
|
|
63
|
+
Raises:
|
|
64
|
+
ValueError: If a tool config references a provider that doesn't exist.
|
|
65
|
+
"""
|
|
66
|
+
available_providers = {p.name for p in mcp_providers}
|
|
67
|
+
|
|
68
|
+
for tc in tool_configs:
|
|
69
|
+
missing_providers = [p for p in tc.providers if p not in available_providers]
|
|
70
|
+
if missing_providers:
|
|
71
|
+
available_list = sorted(available_providers) if available_providers else ["(none configured)"]
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} "
|
|
74
|
+
f"which are not registered. Available providers: {available_list}"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
|
|
46
78
|
def create_resource_provider(
|
|
47
79
|
*,
|
|
48
80
|
artifact_storage: ArtifactStorage,
|
|
@@ -53,9 +85,31 @@ def create_resource_provider(
|
|
|
53
85
|
blob_storage: ManagedBlobStorage | None = None,
|
|
54
86
|
seed_dataset_source: SeedSource | None = None,
|
|
55
87
|
run_config: RunConfig | None = None,
|
|
88
|
+
mcp_providers: list[MCPProviderT] | None = None,
|
|
89
|
+
tool_configs: list[ToolConfig] | None = None,
|
|
56
90
|
) -> ResourceProvider:
|
|
57
91
|
"""Factory function for creating a ResourceProvider instance.
|
|
92
|
+
|
|
58
93
|
This function triggers lazy loading of heavy dependencies like litellm.
|
|
94
|
+
The creation order is:
|
|
95
|
+
1. MCPProviderRegistry (can be empty)
|
|
96
|
+
2. MCPRegistry with tool_configs
|
|
97
|
+
3. ModelRegistry with mcp_registry
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
artifact_storage: Storage for build artifacts.
|
|
101
|
+
model_configs: List of model configurations.
|
|
102
|
+
secret_resolver: Resolver for secrets.
|
|
103
|
+
model_provider_registry: Registry of model providers.
|
|
104
|
+
seed_reader_registry: Registry of seed readers.
|
|
105
|
+
blob_storage: Optional blob storage for large files.
|
|
106
|
+
seed_dataset_source: Optional source for seed datasets.
|
|
107
|
+
run_config: Optional runtime configuration.
|
|
108
|
+
mcp_providers: Optional list of MCP provider configurations.
|
|
109
|
+
tool_configs: Optional list of tool configurations.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
A configured ResourceProvider instance.
|
|
59
113
|
"""
|
|
60
114
|
seed_reader = None
|
|
61
115
|
if seed_dataset_source:
|
|
@@ -64,14 +118,29 @@ def create_resource_provider(
|
|
|
64
118
|
secret_resolver,
|
|
65
119
|
)
|
|
66
120
|
|
|
121
|
+
# Create MCPProviderRegistry first (can be empty)
|
|
122
|
+
mcp_provider_registry = resolve_mcp_provider_registry(mcp_providers)
|
|
123
|
+
|
|
124
|
+
# Create MCPRegistry with tool configs (only if tool_configs provided)
|
|
125
|
+
# Tool validation is performed during dataset builder health checks.
|
|
126
|
+
mcp_registry = None
|
|
127
|
+
if tool_configs:
|
|
128
|
+
mcp_registry = create_mcp_registry(
|
|
129
|
+
tool_configs=tool_configs,
|
|
130
|
+
secret_resolver=secret_resolver,
|
|
131
|
+
mcp_provider_registry=mcp_provider_registry,
|
|
132
|
+
)
|
|
133
|
+
|
|
67
134
|
return ResourceProvider(
|
|
68
135
|
artifact_storage=artifact_storage,
|
|
69
136
|
model_registry=create_model_registry(
|
|
70
137
|
model_configs=model_configs,
|
|
71
138
|
secret_resolver=secret_resolver,
|
|
72
139
|
model_provider_registry=model_provider_registry,
|
|
140
|
+
mcp_registry=mcp_registry,
|
|
73
141
|
),
|
|
74
|
-
blob_storage=blob_storage
|
|
142
|
+
blob_storage=blob_storage,
|
|
143
|
+
mcp_registry=mcp_registry,
|
|
75
144
|
seed_reader=seed_reader,
|
|
76
145
|
run_config=run_config or RunConfig(),
|
|
77
146
|
)
|