data-designer 0.4.0rc2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/cli/commands/list.py +143 -6
- data_designer/cli/commands/mcp.py +13 -0
- data_designer/cli/commands/tools.py +13 -0
- data_designer/cli/controllers/mcp_provider_controller.py +241 -0
- data_designer/cli/controllers/provider_controller.py +6 -2
- data_designer/cli/controllers/tool_controller.py +236 -0
- data_designer/cli/forms/mcp_provider_builder.py +219 -0
- data_designer/cli/forms/tool_builder.py +204 -0
- data_designer/cli/main.py +3 -1
- data_designer/cli/repositories/mcp_provider_repository.py +48 -0
- data_designer/cli/repositories/tool_repository.py +44 -0
- data_designer/cli/services/mcp_provider_service.py +86 -0
- data_designer/cli/services/tool_service.py +83 -0
- data_designer/interface/data_designer.py +27 -0
- {data_designer-0.4.0rc2.dist-info → data_designer-0.5.0rc1.dist-info}/METADATA +3 -3
- {data_designer-0.4.0rc2.dist-info → data_designer-0.5.0rc1.dist-info}/RECORD +18 -9
- data_designer/interface/_version.py +0 -34
- {data_designer-0.4.0rc2.dist-info → data_designer-0.5.0rc1.dist-info}/WHEEL +0 -0
- {data_designer-0.4.0rc2.dist-info → data_designer-0.5.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
from data_designer.cli.forms.tool_builder import ToolFormBuilder
|
|
10
|
+
from data_designer.cli.repositories.mcp_provider_repository import MCPProviderRepository
|
|
11
|
+
from data_designer.cli.repositories.tool_repository import ToolRepository
|
|
12
|
+
from data_designer.cli.services.mcp_provider_service import MCPProviderService
|
|
13
|
+
from data_designer.cli.services.tool_service import ToolService
|
|
14
|
+
from data_designer.cli.ui import (
|
|
15
|
+
confirm_action,
|
|
16
|
+
console,
|
|
17
|
+
display_config_preview,
|
|
18
|
+
print_error,
|
|
19
|
+
print_header,
|
|
20
|
+
print_info,
|
|
21
|
+
print_success,
|
|
22
|
+
select_with_arrows,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from data_designer.config.mcp import ToolConfig
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ToolController:
    """Interactive controller for adding, updating, and deleting tool configs.

    Tool configs are persisted through ``ToolRepository``/``ToolService`` rooted
    at ``config_dir``. They must reference MCP providers, which are read through
    ``MCPProviderService`` from the same directory.
    """

    def __init__(self, config_dir: Path):
        self.config_dir = config_dir
        self.repository = ToolRepository(config_dir)
        self.service = ToolService(self.repository)
        self.mcp_provider_repository = MCPProviderRepository(config_dir)
        self.mcp_provider_service = MCPProviderService(self.mcp_provider_repository)

    def run(self) -> None:
        """Entry point: show current tool configs, ask what to do, then dispatch to a handler."""
        print_header("Configure Tool Configs")

        # A tool config is meaningless without at least one MCP provider to reference.
        provider_names = self._get_available_providers()
        if not provider_names:
            print_error("No MCP providers available!")
            print_info("Please run 'data-designer config mcp' first to configure MCP providers.")
            return

        print_info(f"Configuration directory: {self.config_dir}")
        console.print()

        if self.service.list_all():
            self._show_existing_config()
            chosen_mode = self._select_mode()
        else:
            # Nothing exists to update or delete, so go straight to adding.
            print_info("No tool configs configured yet")
            console.print()
            chosen_mode = "add"

        if chosen_mode is None:
            print_info("No changes made")
            return

        dispatch = {
            "add": self._handle_add,
            "update": self._handle_update,
            "delete": self._handle_delete,
            "delete_all": self._handle_delete_all,
        }
        action = dispatch.get(chosen_mode)
        if action is not None:
            action(provider_names)

    def _get_available_providers(self) -> list[str]:
        """Return the names of every configured MCP provider."""
        providers = self.mcp_provider_service.list_all()
        return [provider.name for provider in providers]

    def _show_existing_config(self) -> None:
        """Print a count and a preview of the persisted tool config registry, if one exists."""
        registry = self.repository.load()
        if registry:
            print_info(f"Found {len(registry.tool_configs)} configured tool config(s)")
            console.print()
            display_config_preview(
                registry.model_dump(mode="json", exclude_none=True),
                "Current Configuration",
            )
            console.print()

    def _select_mode(self) -> str | None:
        """Ask which operation to perform.

        Returns:
            One of "add", "update", "delete", "delete_all", or ``None`` when
            the user chooses to exit (or the picker returns nothing).
        """
        choice = select_with_arrows(
            {
                "add": "Add a new tool config",
                "update": "Update an existing tool config",
                "delete": "Delete a tool config",
                "delete_all": "Delete all tool configs",
                "exit": "Exit without changes",
            },
            "What would you like to do?",
            default_key="add",
            allow_back=False,
        )
        if choice is None or choice == "exit":
            return None
        return choice

    def _handle_add(self, available_providers: list[str]) -> None:
        """Build and persist new tool configs until the user stops, cancels, or an add fails."""
        taken_aliases = {existing.tool_alias for existing in self.service.list_all()}

        while True:
            # Rebuild the form each pass so aliases added earlier in this session are rejected too.
            new_config = ToolFormBuilder(taken_aliases, available_providers).run()
            if new_config is None:
                return

            try:
                self.service.add(new_config)
                print_success(f"Tool config '{new_config.tool_alias}' added successfully")
                taken_aliases.add(new_config.tool_alias)
            except ValueError as error:
                print_error(f"Failed to add tool config: {error}")
                return

            if not self._confirm_add_another():
                return

    def _handle_update(self, available_providers: list[str]) -> None:
        """Let the user pick an existing tool config, edit it in the form, and save the result."""
        configs = self.service.list_all()
        if not configs:
            print_error("No tool configs to update")
            return

        alias = self._select_tool_config(configs, "Select tool config to update")
        if alias is None:
            return

        current = self.service.get_by_alias(alias)
        if not current:
            print_error(f"Tool config '{alias}' not found")
            return

        # The edited config may keep its own alias, so leave it out of the clash check.
        other_aliases = {c.tool_alias for c in configs if c.tool_alias != alias}
        form = ToolFormBuilder(other_aliases, available_providers)
        edited = form.run(current.model_dump(mode="json", exclude_none=True))
        if not edited:
            return

        try:
            self.service.update(alias, edited)
            print_success(f"Tool config '{edited.tool_alias}' updated successfully")
        except ValueError as error:
            print_error(f"Failed to update tool config: {error}")

    def _handle_delete(self, available_providers: list[str]) -> None:
        """Delete one user-selected tool config after an explicit confirmation.

        ``available_providers`` is unused here; it is accepted so that every
        mode handler dispatched from ``run`` shares one signature.
        """
        configs = self.service.list_all()
        if not configs:
            print_error("No tool configs to delete")
            return

        alias = self._select_tool_config(configs, "Select tool config to delete")
        if alias is None:
            return

        console.print()
        if not confirm_action(f"Delete tool config '{alias}'?", default=False):
            return

        try:
            self.service.delete(alias)
            print_success(f"Tool config '{alias}' deleted successfully")
        except ValueError as error:
            print_error(f"Failed to delete tool config: {error}")

    def _handle_delete_all(self, available_providers: list[str]) -> None:
        """Delete every tool config (via the repository) after listing them and confirming.

        ``available_providers`` is unused; see ``_handle_delete``.
        """
        configs = self.service.list_all()
        if not configs:
            print_error("No tool configs to delete")
            return

        console.print()
        count = len(configs)
        quoted_aliases = ", ".join(f"'{c.tool_alias}'" for c in configs)
        prompt = f"Delete ALL ({count}) tool config(s): {quoted_aliases}?\n This action cannot be undone."
        if not confirm_action(prompt, default=False):
            return

        try:
            self.repository.delete()
            print_success(f"All ({count}) tool config(s) deleted successfully")
        except Exception as error:
            print_error(f"Failed to delete all tool configs: {error}")

    def _select_tool_config(
        self, tool_configs: list[ToolConfig], prompt: str, default: str | None = None
    ) -> str | None:
        """Show an arrow-key picker over ``tool_configs`` and return the chosen alias.

        When ``default`` is falsy, the first config's alias is highlighted instead.
        Callers must pass a non-empty list.
        """
        labels: dict[str, str] = {}
        for cfg in tool_configs:
            labels[cfg.tool_alias] = f"{cfg.tool_alias} (providers: {', '.join(cfg.providers)})"
        highlighted = default if default else tool_configs[0].tool_alias
        return select_with_arrows(
            labels,
            prompt,
            default_key=highlighted,
            allow_back=False,
        )

    def _confirm_add_another(self) -> bool:
        """Return True when the user picks "Add another tool config"."""
        answer = select_with_arrows(
            {"yes": "Add another tool config", "no": "Finish"},
            "Add another tool config?",
            default_key="no",
            allow_back=False,
        )
        return answer == "yes"
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from data_designer.cli.forms.field import TextField
|
|
9
|
+
from data_designer.cli.forms.form import Form
|
|
10
|
+
from data_designer.cli.ui import (
|
|
11
|
+
confirm_action,
|
|
12
|
+
console,
|
|
13
|
+
print_error,
|
|
14
|
+
print_header,
|
|
15
|
+
select_with_arrows,
|
|
16
|
+
)
|
|
17
|
+
from data_designer.cli.utils import validate_url
|
|
18
|
+
from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MCPProviderFormBuilder:
    """Builds interactive forms for MCP provider configuration.

    Supports both MCPProvider (remote SSE) and LocalStdioMCPProvider (subprocess).
    """

    def __init__(self, existing_names: set[str] | None = None):
        self.title = "MCP Provider Configuration"
        # Provider names that are already taken; a new name must not collide with these.
        self.existing_names = existing_names or set()

    def run(self, initial_data: dict[str, Any] | None = None) -> MCPProviderT | None:
        """Run the interactive MCP provider configuration and return the configured object.

        Args:
            initial_data: Existing provider data used to pre-fill the form. If it
                contains ``provider_type``, the type picker is skipped.

        Returns:
            The configured provider, or ``None`` if the user backed out.
        """
        print_header(self.title)

        while True:
            # Determine provider type
            if initial_data and initial_data.get("provider_type"):
                provider_type = initial_data["provider_type"]
            else:
                provider_type = self._select_provider_type()
                if provider_type is None:
                    return None

            # Run appropriate form based on provider type
            if provider_type == "sse":
                result = self._run_sse_form(initial_data)
            else:  # stdio
                result = self._run_stdio_form(initial_data)

            if result is not None:
                return result

            # The form was cancelled or the provider failed to build; offer another try.
            if not confirm_action("Try a different provider type?", default=False):
                return None
            initial_data = None  # Reset for new selection

    def _select_provider_type(self) -> str | None:
        """Prompt the user to select "sse" or "stdio"; returns ``None`` on back."""
        options = {
            "sse": "Remote SSE server (connect to existing server)",
            "stdio": "Local stdio subprocess (launch server as subprocess)",
        }

        console.print()
        return select_with_arrows(
            options,
            "What type of MCP provider?",
            default_key="sse",
            allow_back=True,
        )

    def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvider | None:
        """Run the form for a remote SSE provider.

        Returns ``None`` when the form is cancelled or the provider cannot be built.
        """
        fields = [
            TextField(
                "name",
                "MCP provider name",
                default=initial_data.get("name") if initial_data else None,
                required=True,
                validator=self._validate_name,
            ),
            TextField(
                "endpoint",
                "SSE endpoint URL",
                default=initial_data.get("endpoint") if initial_data else None,
                required=True,
                validator=self._validate_endpoint,
            ),
            TextField(
                "api_key",
                "API key or environment variable name (optional)",
                default=initial_data.get("api_key") if initial_data else None,
                required=False,
            ),
        ]

        form = Form("Remote SSE Provider", fields)
        if initial_data:
            form.set_values(initial_data)

        result = form.prompt_all(allow_back=True)
        if result is None:
            return None

        try:
            return MCPProvider(
                name=result["name"],
                endpoint=result["endpoint"],
                # An empty answer means "no API key".
                api_key=result.get("api_key") or None,
            )
        except Exception as e:
            print_error(f"Configuration error: {e}")
            return None

    def _run_stdio_form(self, initial_data: dict[str, Any] | None = None) -> LocalStdioMCPProvider | None:
        """Run the form for a local stdio provider.

        ``args`` and ``env`` are edited as single text fields: args as a
        comma-separated list and env as comma-separated ``KEY=VALUE`` pairs.
        This encoding cannot represent an argument or env value that itself
        contains a comma; malformed env entries are reported by the field
        validator instead of being silently dropped.

        Returns ``None`` when the form is cancelled or the provider cannot be built.
        """
        # Convert args list to comma-separated string for display
        args_default = None
        if initial_data and initial_data.get("args"):
            args_default = ",".join(initial_data["args"])

        # Convert env dict to KEY=VALUE,KEY2=VALUE2 format for display
        env_default = None
        if initial_data and initial_data.get("env"):
            env_default = ",".join(f"{k}={v}" for k, v in initial_data["env"].items())

        fields = [
            TextField(
                "name",
                "MCP provider name",
                default=initial_data.get("name") if initial_data else None,
                required=True,
                validator=self._validate_name,
            ),
            TextField(
                "command",
                "Command to run (e.g., python, node, npx)",
                default=initial_data.get("command") if initial_data else None,
                required=True,
                validator=self._validate_command,
            ),
            TextField(
                "args",
                "Arguments (comma-separated, e.g., -m,my_server,--port,8080)",
                default=args_default,
                required=False,
            ),
            TextField(
                "env",
                "Environment variables (KEY=VALUE, comma-separated)",
                default=env_default,
                required=False,
                validator=self._validate_env,
            ),
        ]

        form = Form("Local Stdio Provider", fields)
        if initial_data:
            # Pre-fill with the display strings, not the raw list/dict values.
            form.set_values(
                {
                    "name": initial_data.get("name"),
                    "command": initial_data.get("command"),
                    "args": args_default,
                    "env": env_default,
                }
            )

        result = form.prompt_all(allow_back=True)
        if result is None:
            return None

        try:
            return LocalStdioMCPProvider(
                name=result["name"],
                # The validator treats surrounding whitespace as meaningless; don't persist it.
                command=result["command"].strip(),
                args=self._parse_args(result.get("args")),
                env=self._parse_env(result.get("env")),
            )
        except Exception as e:
            print_error(f"Configuration error: {e}")
            return None

    @staticmethod
    def _parse_args(raw: str | None) -> list[str]:
        """Split a comma-separated argument string, trimming items and dropping blanks."""
        return [a.strip() for a in (raw or "").split(",") if a.strip()]

    @staticmethod
    def _parse_env(raw: str | None) -> dict[str, str]:
        """Parse comma-separated ``KEY=VALUE`` pairs into a dict.

        Blank entries are ignored; keys and values are trimmed. Only the first
        ``=`` separates key from value, so values may contain ``=``.

        Raises:
            ValueError: If an entry has no ``=`` or an empty key.
        """
        env: dict[str, str] = {}
        for pair in (raw or "").split(","):
            pair = pair.strip()
            if not pair:
                continue
            key, sep, value = pair.partition("=")
            key = key.strip()
            if not sep or not key:
                raise ValueError(f"Invalid environment variable entry '{pair}' (expected KEY=VALUE)")
            env[key] = value.strip()
        return env

    def _validate_name(self, name: str) -> tuple[bool, str | None]:
        """Validate MCP provider name: it must be non-empty and not already taken."""
        if not name:
            return False, "MCP provider name is required"
        if name in self.existing_names:
            return False, f"MCP provider '{name}' already exists"
        return True, None

    def _validate_endpoint(self, endpoint: str) -> tuple[bool, str | None]:
        """Validate endpoint URL."""
        if not endpoint:
            return False, "Endpoint URL is required"
        if not validate_url(endpoint):
            return False, "Invalid URL format (must start with http:// or https://)"
        return True, None

    def _validate_command(self, command: str) -> tuple[bool, str | None]:
        """Validate command: it must contain non-whitespace characters."""
        if not command:
            return False, "Command is required"
        if not command.strip():
            return False, "Command cannot be empty"
        return True, None

    def _validate_env(self, value: str) -> tuple[bool, str | None]:
        """Validate the env field; empty input is allowed because the field is optional."""
        if not value:
            return True, None
        try:
            self._parse_env(value)
        except ValueError as e:
            return False, str(e)
        return True, None
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from data_designer.cli.forms.field import TextField
|
|
9
|
+
from data_designer.cli.forms.form import Form
|
|
10
|
+
from data_designer.cli.ui import (
|
|
11
|
+
confirm_action,
|
|
12
|
+
console,
|
|
13
|
+
print_error,
|
|
14
|
+
print_header,
|
|
15
|
+
print_info,
|
|
16
|
+
select_multiple_with_arrows,
|
|
17
|
+
)
|
|
18
|
+
from data_designer.config.mcp import ToolConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ToolFormBuilder:
    """Builds interactive forms for tool configuration.

    This builder uses a custom flow with multi-select for providers
    rather than the standard FormBuilder pattern:

    1. a text form for the tool alias,
    2. a multi-select over the available MCP providers,
    3. a text form for optional settings (allowed tools, max turns, timeout).
    """

    # Used when the user leaves "max_tool_call_turns" empty.
    DEFAULT_MAX_TOOL_CALL_TURNS = 5

    def __init__(
        self,
        existing_aliases: set[str] | None = None,
        available_providers: list[str] | None = None,
    ):
        self.title = "Tool Configuration"
        # Aliases already taken; the configured alias must not collide with these.
        self.existing_aliases = existing_aliases or set()
        # Names of the MCP providers the user may select from.
        self.available_providers = available_providers or []

    def run(self, initial_data: dict[str, Any] | None = None) -> ToolConfig | None:
        """Run the interactive tool configuration and return the configured object.

        Args:
            initial_data: Existing config values used to pre-fill the forms when
                updating (list/number values are converted to the text shown in
                the fields); ``None`` when adding a new config.

        Returns:
            The built ``ToolConfig``, or ``None`` if the user cancelled or no
            MCP providers are available.
        """
        print_header(self.title)

        while True:
            # Step 1: Get tool alias
            form = self._create_alias_form(initial_data)
            if initial_data:
                # Only pre-fill the field this form owns.
                form.set_values({"tool_alias": initial_data.get("tool_alias")})

            result = form.prompt_all(allow_back=True)
            if result is None:
                if confirm_action("Cancel configuration?", default=False):
                    return None
                continue

            tool_alias = result["tool_alias"]

            # Step 2: Select providers (multi-select with checkboxes)
            if not self.available_providers:
                print_error("No MCP providers available. Please configure MCP providers first.")
                print_info("Run 'data-designer config mcp' to configure MCP providers.")
                return None

            console.print()
            print_info("Select one or more MCP providers for this tool configuration:")
            console.print()

            # Only pre-select providers that still exist; tell the user about the rest.
            default_providers, missing_providers = self._split_default_providers(initial_data)
            if missing_providers:
                print_info(
                    f"Previously selected MCP provider(s) no longer configured: {', '.join(missing_providers)}"
                )

            provider_options = {p: p for p in self.available_providers}
            selected_providers = select_multiple_with_arrows(
                provider_options,
                "Select MCP providers (Space to toggle, Enter to confirm):",
                default_keys=default_providers or None,
                allow_empty=False,
            )

            if selected_providers is None:
                if confirm_action("Cancel configuration?", default=False):
                    return None
                continue

            if not selected_providers:
                print_error("At least one provider must be selected")
                continue

            # Step 3: Get optional settings
            optional_form = self._create_optional_form(initial_data)
            if initial_data:
                # Pre-fill with display strings rather than the raw list/number values,
                # so accepting a default round-trips cleanly through _build_config.
                optional_form.set_values(self._optional_display_values(initial_data))

            optional_result = optional_form.prompt_all(allow_back=True)
            if optional_result is None:
                continue  # Go back to start

            try:
                return self._build_config(tool_alias, selected_providers, optional_result)
            except Exception as e:
                print_error(f"Configuration error: {e}")
                if not confirm_action("Try again?", default=True):
                    return None

    def _split_default_providers(self, initial_data: dict[str, Any] | None) -> tuple[list[str], list[str]]:
        """Split previously selected providers into (still available, no longer configured)."""
        previous = (initial_data or {}).get("providers") or []
        kept = [p for p in previous if p in self.available_providers]
        missing = [p for p in previous if p not in self.available_providers]
        return kept, missing

    @classmethod
    def _optional_display_values(cls, initial_data: dict[str, Any] | None) -> dict[str, str | None]:
        """Convert stored optional settings into the text shown in the optional-settings form.

        ``allow_tools`` becomes a comma-separated string. Absent or ``None``
        values become ``None``, except ``max_tool_call_turns``, which falls back
        to the default.
        """
        data = initial_data or {}

        allow_tools = data.get("allow_tools")
        if isinstance(allow_tools, str):
            allow_display = allow_tools or None
        else:
            allow_display = ", ".join(allow_tools) if allow_tools else None

        max_turns = data.get("max_tool_call_turns")
        timeout = data.get("timeout_sec")
        return {
            "allow_tools": allow_display,
            "max_tool_call_turns": str(max_turns if max_turns is not None else cls.DEFAULT_MAX_TOOL_CALL_TURNS),
            "timeout_sec": str(timeout) if timeout is not None else None,
        }

    def _create_alias_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Create the form for tool alias."""
        fields = [
            TextField(
                "tool_alias",
                "Tool alias (used to reference this config in columns)",
                default=initial_data.get("tool_alias") if initial_data else None,
                required=True,
                validator=self._validate_alias,
            ),
        ]
        return Form("Tool Alias", fields)

    def _create_optional_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Create the form for optional tool settings, with defaults in display form."""
        defaults = self._optional_display_values(initial_data)
        fields = [
            TextField(
                "allow_tools",
                "Allowed tools (comma-separated, leave empty for all)",
                default=defaults["allow_tools"],
                required=False,
            ),
            TextField(
                "max_tool_call_turns",
                "Max tool-calling turns (a turn may execute multiple parallel tools)",
                default=defaults["max_tool_call_turns"],
                required=False,
                validator=self._validate_max_tool_call_turns,
            ),
            TextField(
                "timeout_sec",
                "Timeout in seconds per tool call (leave empty for no timeout)",
                default=defaults["timeout_sec"],
                required=False,
                validator=self._validate_timeout,
            ),
        ]
        return Form("Optional Settings", fields)

    def _validate_alias(self, alias: str) -> tuple[bool, str | None]:
        """Validate tool alias: it must be non-empty and not already taken."""
        if not alias:
            return False, "Tool alias is required"
        if alias in self.existing_aliases:
            return False, f"Tool alias '{alias}' already exists"
        return True, None

    def _validate_max_tool_call_turns(self, value: str) -> tuple[bool, str | None]:
        """Validate max_tool_call_turns: empty (use default) or an integer >= 1."""
        if not value:
            return True, None  # Will use default
        try:
            int_value = int(value)
            if int_value < 1:
                return False, "Must be at least 1"
            return True, None
        except ValueError:
            return False, "Must be a positive integer"

    def _validate_timeout(self, value: str) -> tuple[bool, str | None]:
        """Validate timeout_sec: empty (no timeout) or a number > 0."""
        if not value:
            return True, None  # No timeout is valid
        try:
            float_value = float(value)
            if float_value <= 0:
                return False, "Must be greater than 0"
            return True, None
        except ValueError:
            return False, "Must be a positive number"

    @staticmethod
    def _parse_allow_tools(raw: str | list[str] | None) -> list[str] | None:
        """Parse allowed tool names from a comma-separated string or a list.

        Items are trimmed and blanks dropped. Returns ``None`` (meaning all
        tools) when nothing remains.
        """
        if not raw:
            return None
        items = raw.split(",") if isinstance(raw, str) else raw
        names = [str(t).strip() for t in items if str(t).strip()]
        return names or None

    def _build_config(
        self,
        tool_alias: str,
        providers: list[str],
        optional_data: dict[str, Any],
    ) -> ToolConfig:
        """Build a ToolConfig from the collected data.

        Raises:
            ValueError: If a numeric field cannot be converted. ``ToolConfig``
                itself may also raise on invalid values.
        """
        allow_tools = self._parse_allow_tools(optional_data.get("allow_tools"))

        raw_turns = optional_data.get("max_tool_call_turns")
        max_tool_call_turns = int(raw_turns) if raw_turns else self.DEFAULT_MAX_TOOL_CALL_TURNS

        raw_timeout = optional_data.get("timeout_sec")
        timeout_sec = float(raw_timeout) if raw_timeout else None

        return ToolConfig(
            tool_alias=tool_alias,
            providers=providers,
            allow_tools=allow_tools,
            max_tool_call_turns=max_tool_call_turns,
            timeout_sec=timeout_sec,
        )
|
data_designer/cli/main.py
CHANGED
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import typer
|
|
7
7
|
|
|
8
|
-
from data_designer.cli.commands import download, models, providers, reset
|
|
8
|
+
from data_designer.cli.commands import download, mcp, models, providers, reset, tools
|
|
9
9
|
from data_designer.cli.commands import list as list_cmd
|
|
10
10
|
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
|
|
11
11
|
from data_designer.config.utils.misc import can_run_data_designer_locally
|
|
@@ -31,6 +31,8 @@ config_app = typer.Typer(
|
|
|
31
31
|
)
|
|
32
32
|
# Register each interactive `data-designer config <name>` subcommand: `command()` returns a
# decorator, which is applied directly to the handler function from each commands module.
config_app.command(name="providers", help="Configure model providers interactively")(providers.providers_command)
config_app.command(name="models", help="Configure models interactively")(models.models_command)
# MCP providers must be configured before tool configs, which reference them by name.
config_app.command(name="mcp", help="Configure MCP providers interactively")(mcp.mcp_command)
config_app.command(name="tools", help="Configure tool configs interactively")(tools.tools_command)
config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)
|
|
36
38
|
|