xfmr-zem 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/__init__.py +35 -0
- xfmr_zem/cli.py +295 -0
- xfmr_zem/client.py +208 -0
- xfmr_zem/orchestrators/parallel_local.py +92 -0
- xfmr_zem/schemas.py +15 -0
- xfmr_zem/server.py +188 -0
- xfmr_zem/servers/data_juicer/parameter.yaml +17 -0
- xfmr_zem/servers/data_juicer/server.py +113 -0
- xfmr_zem/servers/instruction_gen/parameter.yaml +12 -0
- xfmr_zem/servers/instruction_gen/server.py +90 -0
- xfmr_zem/servers/io/parameter.yaml +10 -0
- xfmr_zem/servers/io/server.py +95 -0
- xfmr_zem/servers/llm/server.py +47 -0
- xfmr_zem/servers/nemo_curator/parameter.yaml +17 -0
- xfmr_zem/servers/nemo_curator/server.py +118 -0
- xfmr_zem/servers/profiler/server.py +76 -0
- xfmr_zem/servers/sinks/server.py +48 -0
- xfmr_zem/zenml_wrapper.py +203 -0
- xfmr_zem-0.2.0.dist-info/METADATA +152 -0
- xfmr_zem-0.2.0.dist-info/RECORD +23 -0
- xfmr_zem-0.2.0.dist-info/WHEEL +4 -0
- xfmr_zem-0.2.0.dist-info/entry_points.txt +3 -0
- xfmr_zem-0.2.0.dist-info/licenses/LICENSE +201 -0
xfmr_zem/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
xfmr-zem (Zem)
|
|
3
|
+
==============
|
|
4
|
+
|
|
5
|
+
A unified data pipeline framework combining:
|
|
6
|
+
- Model Context Protocol (MCP): For modular, specialized processing servers.
|
|
7
|
+
- ZenML: For production-grade orchestration and pipeline tracking.
|
|
8
|
+
|
|
9
|
+
xfmr-zem allows you to build complex data processing workflows by connecting
|
|
10
|
+
multiple MCP servers as pipeline steps, all orchestrated by ZenML.
|
|
11
|
+
|
|
12
|
+
Key Features:
|
|
13
|
+
* **Config-Driven Architecture**: Define your entire pipeline in a simple YAML configuration.
|
|
14
|
+
* **MCP Server Integration**: Leverage any MCP-compatible server as a processing block.
|
|
15
|
+
* **ZenML Orchestration**: Production-grade tracking, caching, and visualization of data flows.
|
|
16
|
+
* **Multi-Domain Ready**: Designed for modular tasks like curation, extraction, and filtering.
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
from xfmr_zem import PipelineClient
|
|
20
|
+
|
|
21
|
+
# Initialize client with a pipeline configuration
|
|
22
|
+
client = PipelineClient("configs/medical_pipeline.yaml")
|
|
23
|
+
|
|
24
|
+
# Build and execute the ZenML pipeline
|
|
25
|
+
client.run()
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
__version__ = "0.1.0"
|
|
29
|
+
__author__ = "Khai Hoang"
|
|
30
|
+
|
|
31
|
+
from xfmr_zem.client import PipelineClient
|
|
32
|
+
|
|
33
|
+
__all__ = [
|
|
34
|
+
"PipelineClient",
|
|
35
|
+
]
|
xfmr_zem/cli.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI for Zem - Unified Data Pipeline Framework (MCP + ZenML)
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import click
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
from rich.table import Table
|
|
9
|
+
from loguru import logger
|
|
10
|
+
import sys
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
from xfmr_zem.client import PipelineClient
|
|
14
|
+
|
|
15
|
+
console = Console()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@click.group()
# Bug fix: keep the CLI version in sync with the published distribution
# (0.2.0); the previous hardcoded "0.1.0" was stale.
@click.version_option(version="0.2.0")
def main():
    """Zem CLI - ZenML + MCP (NeMo Curator & DataJuicer)"""
    pass
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@main.command()
def info():
    """Show framework information"""
    console.print("[bold blue]Zem: Unified Data Pipeline Framework[/bold blue]")
    # Bug fix: version string kept in sync with the published distribution
    # (was stale "0.1.0" while the wheel is 0.2.0).
    console.print("Version: 0.2.0")
    console.print("\nArchitecture: [green]Model Context Protocol (MCP) + ZenML[/green]")
    console.print("\nIntegrations:")
    console.print(" - [bold]ZenML[/bold]: Orchestration, Visualization & Artifact Tracking")
    console.print(" - [bold]MCP Servers[/bold]: Standalone units for domain-specific logic")
    console.print(" - [bold]NeMo Curator[/bold]: NVIDIA's high-performance curation")
    console.print(" - [bold]DataJuicer[/bold]: Comprehensive data processing operators")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@main.command(name="list-tools")
|
|
39
|
+
@click.option("--config", "-c", type=click.Path(exists=True), help="Config file to discover tools from")
|
|
40
|
+
def list_tools(config):
|
|
41
|
+
"""List available MCP tools dynamically from servers"""
|
|
42
|
+
if not config:
|
|
43
|
+
# Fallback to hardcoded list if no config provided (legacy behavior)
|
|
44
|
+
console.print("[yellow]Hint: Provide a config file to see dynamic tool list: zem list-tools -c your_config.yaml[/yellow]")
|
|
45
|
+
_print_static_operators()
|
|
46
|
+
return
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
client = PipelineClient(config)
|
|
50
|
+
all_tools = client.discover_tools()
|
|
51
|
+
|
|
52
|
+
for srv_name, tools in all_tools.items():
|
|
53
|
+
console.print(f"\n[bold magenta]{srv_name} Server Tools:[/bold magenta]")
|
|
54
|
+
table = Table(show_header=True, header_style="bold cyan")
|
|
55
|
+
table.add_column("Tool Name")
|
|
56
|
+
table.add_column("Description")
|
|
57
|
+
|
|
58
|
+
for tool in tools:
|
|
59
|
+
table.add_row(tool.get("name", "N/A"), tool.get("description", "No description"))
|
|
60
|
+
console.print(table)
|
|
61
|
+
|
|
62
|
+
except Exception as e:
|
|
63
|
+
console.print(f"[bold red]Error discovering tools:[/bold red] {e}")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@main.command()
@click.argument("project_name")
def init(project_name: str):
    """Bootstrap a new Zem project structure."""
    base_path = Path(project_name)
    # Refuse to clobber anything that already exists at the target path.
    if base_path.exists():
        console.print(f"[bold red]Error:[/bold red] Path '{project_name}' already exists.")
        sys.exit(1)

    # Create directories
    (base_path / "servers").mkdir(parents=True)
    (base_path / "tests/manual").mkdir(parents=True)
    (base_path / "data").mkdir(parents=True)

    # Create sample server
    # NOTE: the escaped \"\"\" below keeps the inner docstring from
    # terminating this template string.
    sample_server_py = """from xfmr_zem.server import ZemServer
from typing import Any, List

# Initialize the sample server
mcp = ZemServer("SampleAgent")

@mcp.tool()
def hello_world(data: Any) -> List[Any]:
    \"\"\"
    A simple tool that adds a 'greeting' field to each record.
    \"\"\"
    dataset = mcp.get_data(data)
    for item in dataset:
        item["greeting"] = "Hello from your standalone Zem project!"
    return dataset

if __name__ == "__main__":
    mcp.run()
"""
    (base_path / "servers" / "sample_server.py").write_text(sample_server_py)

    # Create sample pipeline
    # Doubled braces ({{ }}) render as literal braces inside this f-string.
    pipeline_yaml = f"""name: {project_name}_pipeline

servers:
  agent: servers/sample_server.py

pipeline:
  - name: my_first_step
    agent.hello_world:
      input:
        data: [{{"text": "Zem is awesome!"}}]
"""
    (base_path / "pipeline.yaml").write_text(pipeline_yaml)

    console.print(f"[bold green]Success![/bold green] Project '{project_name}' initialized.")
    console.print(f"Created standalone sample server: [cyan]{project_name}/servers/sample_server.py[/cyan]")
    console.print(f"Next steps:\n cd {project_name}\n zem list-tools -c pipeline.yaml\n zem run pipeline.yaml")
|
|
119
|
+
|
|
120
|
+
@main.command()
def operators():
    """List available MCP tools (Static legacy list)"""
    # Legacy command; `list-tools -c <config>` is the dynamic replacement.
    _print_static_operators()
|
|
124
|
+
|
|
125
|
+
def _print_static_operators():
    """Print the hardcoded fallback catalogue of built-in server tools.

    Refactored from four copy-pasted title/table sections into a single
    data-driven loop; the rendered output is unchanged.
    """
    # (section title, [(tool name, description), ...]) — kept in sync with
    # the bundled servers shipped under xfmr_zem/servers/.
    catalogue = [
        ("NeMo Curator Server Tools:", [
            ("pii_removal", "Remove PII using NeMo Curator"),
            ("text_cleaning", "General text cleaning using NeMo Curator"),
        ]),
        ("DataJuicer Server Tools:", [
            ("clean_html", "Remove HTML tags"),
            ("clean_links", "Remove URLs/Links"),
            ("fix_unicode", "Normalize Unicode (NFKC)"),
            ("whitespace_normalization", "Clean extra spaces/newlines"),
            ("text_length_filter", "Filter by character length"),
            ("language_filter", "Heuristic-based language filtering"),
            ("document_simhash_dedup", "Simple SimHash-based deduplication"),
        ]),
        ("IO Server Tools (File Handling):", [
            ("load_jsonl", "Load data from JSONL file"),
            ("save_jsonl", "Save data to JSONL file"),
            ("load_csv", "Load data from CSV file"),
            ("save_csv", "Save data to CSV file"),
        ]),
        ("Profiler Server Tools:", [
            ("profile_data", "Generate summary & metrics for input data"),
        ]),
    ]

    for title, tools in catalogue:
        console.print(f"\n[bold magenta]{title}[/bold magenta]")
        table = Table(show_header=True, header_style="bold cyan")
        table.add_column("Tool Name")
        table.add_column("Description")
        for tool_name, description in tools:
            table.add_row(tool_name, description)
        console.print(table)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@main.command()
@click.argument("config_file", type=click.Path(exists=True))
@click.option("--params", "-p", type=click.Path(exists=True), help="Path to custom parameters.yml")
def run(config_file, params):
    """Run a pipeline from a YAML configuration file"""
    abs_config = os.path.abspath(config_file)
    console.print(f"[bold green]Starting Pipeline:[/bold green] {abs_config}")
    if params:
        console.print(f"[bold blue]Custom Parameters:[/bold blue] {params}")

    try:
        client = PipelineClient(abs_config, params_path=params)
        run_response = client.run()

        console.print("\n[bold blue]Pipeline Execution Finished![/bold blue]")
        console.print(f"Run Name: [cyan]{run_response.name}[/cyan]")
        console.print(f"Status: [yellow]{run_response.status}[/yellow]")

        console.print("\n[dim]To visualize this run, ensure ZenML dashboard is running:[/dim]")
        console.print("[dim]uv run zenml up --port 8871[/dim]")
        console.print("[dim]Or view runs via: zem dashboard[/dim]")  # Future proofing hint

    except Exception as e:
        console.print(f"\n[bold red]Pipeline Failed:[/bold red] {e}")
        # Best-effort: surface the ZenML error log if one was written.
        if os.path.exists("/tmp/zenml_error.log"):
            console.print("\n[bold yellow]Error log snippet (/tmp/zenml_error.log):[/bold yellow]")
            with open("/tmp/zenml_error.log", "r") as f:
                console.print(f.read())
        # Bug fix: propagate failure to the shell. Previously the command
        # printed the error but still exited 0, which breaks CI/scripting.
        sys.exit(1)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@main.command()
def dashboard():
    """Open the ZenML dashboard in the default web browser."""
    # Fix: dropped the unused local `import subprocess` — nothing in this
    # command shells out.
    import webbrowser
    console.print("[bold blue]Checking ZenML Dashboard...[/bold blue]")
    # Default ZenML port used in this project
    url = "http://127.0.0.1:8871"
    console.print(f"URL: [link={url}]{url}[/link]")
    try:
        webbrowser.open(url)
    except Exception as e:
        # Headless environments may have no browser; the URL is printed anyway.
        console.print(f"[yellow]Could not open browser automatically: {e}[/yellow]")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@main.command()
@click.argument("artifact_id")
@click.option("--id2", help="Secondary artifact ID for comparison (diff mode)")
@click.option("--limit", "-n", default=10, help="Number of rows to preview")
@click.option("--sample", is_flag=True, help="Show a random sample instead of the head")
def preview(artifact_id, id2, limit, sample):
    """Preview a ZenML artifact (supports diff mode and sampling)"""
    # Heavy imports kept local so the CLI stays fast for other commands.
    from zenml.client import Client
    import pandas as pd
    import json

    def load_art_df(uid):
        # Load a ZenML artifact version and coerce it to a DataFrame.
        # Supported payloads: a {"path": ...} dict pointing at a
        # parquet/csv/jsonl file, a list of records, or a DataFrame itself.
        # Anything else yields None.
        art = Client().get_artifact_version(uid)
        d = art.load()
        if isinstance(d, dict) and "path" in d:
            p = d["path"]
            ext = os.path.splitext(p)[1].lower()
            if ext == ".parquet": return pd.read_parquet(p)
            elif ext == ".csv": return pd.read_csv(p)
            elif ext == ".jsonl":
                # One JSON object per line.
                with open(p, "r") as f: lines = [json.loads(l) for l in f]
                return pd.DataFrame(lines)
        elif isinstance(d, list): return pd.DataFrame(d)
        elif isinstance(d, pd.DataFrame): return d
        return None

    try:
        df1 = load_art_df(artifact_id)
        if df1 is None:
            console.print("[bold red]Error:[/bold red] Could not load artifact as tabular data.")
            return

        if id2:
            # Diff mode: compare row/column counts and column sets, then
            # preview the second artifact.
            df2 = load_art_df(id2)
            if df2 is None:
                console.print("[bold red]Error:[/bold red] Could not load second artifact.")
                return

            console.print(f"[bold blue]Comparing Artifacts:[/bold blue] {artifact_id} vs {id2}")
            # Simple column/row diff
            cols1, cols2 = set(df1.columns), set(df2.columns)
            console.print(f" Rows: {len(df1)} -> {len(df2)} ({len(df2)-len(df1):+d})")
            console.print(f" Cols: {len(df1.columns)} -> {len(df2.columns)}")
            if cols1 != cols2:
                added = cols2 - cols1
                removed = cols1 - cols2
                if added: console.print(f" [green]+ Added columns:[/green] {added}")
                if removed: console.print(f" [red]- Removed columns:[/red] {removed}")

            # Show a few sample rows from both for side-by-side or sequential feel
            console.print("\n[bold magenta]Diff Sample (df2):[/bold magenta]")
            df_to_show = df2
        else:
            df_to_show = df1

        # Random sample only makes sense when there are more rows than the limit.
        if sample and len(df_to_show) > limit:
            preview_df = df_to_show.sample(limit)
        else:
            preview_df = df_to_show.head(limit)

        table = Table(show_header=True, header_style="bold magenta", title=f"Preview ({'Sample' if sample else 'Head'} {limit} rows)")
        for col in preview_df.columns: table.add_column(str(col))
        for _, row in preview_df.iterrows():
            row_values = []
            for val in row:
                # Truncate long cell values so the table stays readable.
                v_str = str(val)
                if len(v_str) > 100: v_str = v_str[:97] + "..."
                row_values.append(v_str)
            table.add_row(*row_values)
        console.print(table)

    except Exception as e:
        console.print(f"[bold red]Error previewing artifact:[/bold red] {e}")
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
if __name__ == "__main__":
|
|
294
|
+
main()
|
|
295
|
+
|
xfmr_zem/client.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import Dict, Any, List
|
|
3
|
+
import yaml
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from zenml import pipeline
|
|
6
|
+
from .zenml_wrapper import mcp_generic_step
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
|
|
10
|
+
from .schemas import ZemConfig
|
|
11
|
+
from .zenml_wrapper import mcp_generic_step, list_mcp_tools
|
|
12
|
+
|
|
13
|
+
class PipelineClient:
    """
    Client to run Zem pipelines using MCP servers and ZenML orchestration.

    Loads a YAML pipeline config, substitutes ``{{ param }}`` placeholders,
    validates the result against :class:`ZemConfig`, resolves each declared
    MCP server script to an absolute path, and builds/executes a dynamic
    ZenML pipeline from the configured steps.
    """
    def __init__(self, config_path: str, params_path: str = None):
        # config_path: path to the pipeline YAML.
        # params_path: optional extra parameters.yml whose values override
        # both the default parameters.yml and the config's own `parameters:`.
        self.config_path = Path(config_path)
        self.params_path = params_path
        self.params = {}
        # Side effect: _load_config_dict also populates self.params.
        self.config_dict = self._load_config_dict(self.config_path)

        # Validate with Pydantic (raises on a malformed config).
        self.config = ZemConfig(**self.config_dict)
        self.server_configs = self._load_server_configs()

    def _load_params(self, params_path: str = None) -> Dict[str, Any]:
        """Load parameters from default or specified path.

        Reads `parameters.yml` next to the config file first, then (if given)
        overlays the file at params_path. Missing files are skipped silently.
        """
        params = {}
        default_params_path = self.config_path.parent / "parameters.yml"
        if default_params_path.exists():
            with open(default_params_path, "r") as f:
                params.update(yaml.safe_load(f) or {})

        if params_path:
            p_path = Path(params_path)
            if p_path.exists():
                with open(p_path, "r") as f:
                    params.update(yaml.safe_load(f) or {})
        return params

    def _load_config_dict(self, path: Path) -> Dict[str, Any]:
        """Load YAML config and perform substitution.

        Precedence (lowest to highest): sibling parameters.yml, the config's
        own `parameters:` section, then the explicit params_path override.
        Placeholders `{{ key }}` and `{{key}}` are replaced textually before
        the final YAML parse.
        """
        with open(path, "r") as f:
            raw_content = f.read()

        self.params = self._load_params(None)
        # First parse only to pick up the embedded `parameters:` section.
        preliminary_dict = yaml.safe_load(raw_content) or {}
        internal_params = preliminary_dict.get("parameters", {})
        if internal_params:
            self.params.update(internal_params)

        if self.params_path:
            custom_params = self._load_params(self.params_path)
            self.params.update(custom_params)

        # Plain string substitution of both spaced and unspaced placeholders.
        content = raw_content
        for key, value in self.params.items():
            content = content.replace(f"{{{{ {key} }}}}", str(value))
            content = content.replace(f"{{{{{key}}}}}", str(value))

        return yaml.safe_load(content)

    def _load_server_configs(self) -> Dict[str, Any]:
        """Resolve each server entry to a stdio launch spec (command/args/env)."""
        servers = self.config.servers
        configs = {}
        for name, path_str in servers.items():
            # 1. Try relative to config file (User's project)
            abs_path = (self.config_path.parent / path_str).resolve()

            # 2. If it doesn't exist AND starts with "servers/", check internal package
            if not abs_path.exists() and path_str.startswith("servers/"):
                package_root = Path(__file__).parent.resolve()
                abs_path = (package_root / path_str / "server.py").resolve()

            # 3. If it still doesn't exist, try relative to project root
            if not abs_path.exists():
                project_root = Path(__file__).parent.parent.parent.resolve()
                abs_path = (project_root / path_str).resolve()

            # 4. If it's a directory, append default filename
            if abs_path.exists() and abs_path.is_dir():
                abs_path = (abs_path / "server.py").resolve()

            # Child servers need this package importable; prepend our source
            # tree to PYTHONPATH.
            env = os.environ.copy()
            src_path = str(Path(__file__).parent.parent.resolve())
            current_pythonpath = env.get("PYTHONPATH", "")
            env["PYTHONPATH"] = f"{src_path}:{current_pythonpath}" if current_pythonpath else src_path

            # Split params: "<server>.<key>" entries are namespaced to this
            # server (prefix stripped); everything else is passed through as-is.
            # NOTE(review): keys prefixed for *other* servers are also passed
            # through unstripped — confirm that is intentional.
            server_specific_params = {}
            prefix = f"{name}."
            for key, value in self.params.items():
                if key.startswith(prefix):
                    server_specific_params[key[len(prefix):]] = value
                else:
                    server_specific_params[key] = value

            # Parameters travel to the child process via an env variable.
            env["ZEM_PARAMETERS"] = yaml.dump(server_specific_params)
            configs[name] = {
                "command": sys.executable,
                "args": [str(abs_path)],
                "env": env
            }
        return configs

    def discover_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch tools from all registered servers."""
        all_tools = {}
        for name, cfg in self.server_configs.items():
            all_tools[name] = list_mcp_tools(cfg["command"], cfg["args"], cfg["env"])
        return all_tools

    def run(self):
        """Build and run the ZenML pipeline.

        Constructs a ZenML pipeline on the fly: one generic MCP step per
        configured pipeline entry, chained linearly by default, or wired via
        explicit `$step_name` data references.
        """
        pipeline_steps = self.config.pipeline
        server_configs = self.server_configs
        pipeline_name = self.config.name

        @pipeline(name=pipeline_name, enable_cache=False)
        def dynamic_generated_pipeline(pipeline_params: Dict[str, Any]):
            # step_outputs: alias -> step output, for `$alias` references.
            # last_output: output of the most recent step (linear chaining).
            step_outputs = {}
            last_output = {}

            for i, p_step in enumerate(pipeline_steps):
                step_def = p_step.root
                srv = ""
                tool = ""
                tool_args = {}
                step_alias = f"step_{i}"

                # A step is either a bare "server.tool" string or a mapping
                # with optional name/cache keys plus one "server.tool" key.
                if isinstance(step_def, str):
                    srv, tool = step_def.split(".")
                elif isinstance(step_def, dict):
                    # Check for name at top level or inside the tool dict
                    step_alias = step_def.get("name")

                    # Exclude control keywords
                    control_keys = ["name", "cache"]
                    keys = [k for k in step_def.keys() if k not in control_keys]
                    if not keys: continue
                    key = keys[0]

                    if "." not in key:
                        # Might be another control key or misconfig
                        continue

                    srv, tool = key.split(".")

                    if not step_alias:
                        step_alias = step_def[key].get("name")

                    step_alias = step_alias or f"{srv}.{tool}.{i}"
                    tool_args = step_def[key].get("input", {}) or {}

                # Smart Parallelization & DAG Logic:
                # 1. By default, a step is a root (None) unless it has no 'data' input,
                #    in which case it inherits from the previous step (linear chain).
                # 2. If 'data' is a reference ($step), it depends on that specific step.

                current_prev_output = None
                has_explicit_data = "data" in tool_args

                if not has_explicit_data:
                    # No data provided? Inherit from the last executed step to keep simple sequences working
                    current_prev_output = last_output
                else:
                    # Data provided? Check if it's a reference or raw data
                    for k, v in list(tool_args.items()):
                        if isinstance(v, str) and v.startswith("$"):
                            target_step = v[1:]
                            if target_step in step_outputs:
                                if k == "data":
                                    current_prev_output = step_outputs[target_step]
                                    del tool_args[k]
                                else:
                                    # Limitation: ZenML doesn't materialize artifacts nested in dicts
                                    print(f"[Warning] Tool argument '{k}' uses a step reference '{v}'. "
                                          "Currently, only the 'data' field supports cross-step dependencies. "
                                          "This value will be passed as a raw string.")
                            else:
                                raise ValueError(f"Step reference '{v}' not found in previous steps. Available: {list(step_outputs.keys())}")

                # 3. Adaptive Caching:
                # Check for 'cache' at top level
                enable_cache = step_def.get("cache", True) if isinstance(step_def, dict) else True

                # Re-wrap the generic step under a unique name so ZenML shows
                # one distinct node per configured step.
                from zenml import step as zenml_step
                unique_step_name = f"{srv}.{tool}.{i}"
                dynamic_step = zenml_step(
                    mcp_generic_step.entrypoint,
                    name=unique_step_name,
                    enable_cache=enable_cache
                )

                step_output = dynamic_step(
                    server_name=srv,
                    tool_name=tool,
                    server_config=server_configs.get(srv, {}),
                    tool_args=tool_args,
                    previous_output=current_prev_output
                )

                step_outputs[step_alias] = step_output
                last_output = step_output

            return last_output

        return dynamic_generated_pipeline(pipeline_params=self.params)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Type
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
from zenml.enums import ExecutionMode
|
|
6
|
+
from zenml.logger import get_logger
|
|
7
|
+
from zenml.orchestrators import (
|
|
8
|
+
BaseOrchestrator,
|
|
9
|
+
BaseOrchestratorConfig,
|
|
10
|
+
BaseOrchestratorFlavor,
|
|
11
|
+
)
|
|
12
|
+
from zenml.orchestrators.dag_runner import ThreadedDagRunner
|
|
13
|
+
from zenml.utils import string_utils
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from zenml.models import PipelineRunResponse, PipelineSnapshotResponse
|
|
17
|
+
from zenml.stack import Stack
|
|
18
|
+
|
|
19
|
+
logger = get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
class ParallelLocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally in parallel."""

    # Identifier of the in-flight run; None while idle.
    _orchestrator_run_id: Optional[str] = None

    def submit_pipeline(
        self,
        snapshot: "PipelineSnapshotResponse",
        stack: "Stack",
        base_environment: Dict[str, str],
        step_environments: Dict[str, Dict[str, str]],
        placeholder_run: Optional["PipelineRunResponse"] = None,
    ) -> None:
        """Submits a pipeline to the orchestrator."""
        self._orchestrator_run_id = str(uuid4())
        started_at = time.time()

        # Map every step name to the names of the steps it must wait for.
        upstream_by_step = {}
        for name, cfg in snapshot.step_configurations.items():
            upstream_by_step[name] = cfg.spec.upstream_steps

        def execute(step_name: str) -> None:
            self.run_step(step=snapshot.step_configurations[step_name])

        # Threads let independent branches of the DAG run concurrently.
        runner = ThreadedDagRunner(dag=upstream_by_step, run_fn=execute)

        logger.info("Starting parallel local execution...")
        runner.run()

        elapsed = time.time() - started_at
        logger.info(
            "Parallel pipeline run has finished in `%s`.",
            string_utils.get_human_readable_time(elapsed),
        )
        self._orchestrator_run_id = None

    def get_orchestrator_run_id(self) -> str:
        """Returns the active orchestrator run id."""
        if self._orchestrator_run_id:
            return self._orchestrator_run_id
        raise RuntimeError("No run id set.")
|
|
68
|
+
|
|
69
|
+
class ParallelLocalOrchestratorConfig(BaseOrchestratorConfig):
    """Parallel local orchestrator config."""
    @property
    def is_local(self) -> bool:
        # Always local: this orchestrator runs on the client machine.
        return True
|
|
74
|
+
|
|
75
|
+
class ParallelLocalOrchestratorFlavor(BaseOrchestratorFlavor):
    """Class for the `ParallelLocalOrchestratorFlavor`."""
    @property
    def name(self) -> str:
        # Flavor name used when registering this orchestrator with ZenML.
        return "parallel_local"

    @property
    def config_class(self) -> Type[ParallelLocalOrchestratorConfig]:
        # Pydantic config model backing this flavor.
        return ParallelLocalOrchestratorConfig

    @property
    def logo_url(self) -> str:
        """A URL to represent the flavor in the dashboard."""
        # Reuses ZenML's stock "local" orchestrator logo.
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/local.png"

    @property
    def implementation_class(self) -> Type[ParallelLocalOrchestrator]:
        # Concrete orchestrator implementation for this flavor.
        return ParallelLocalOrchestrator
|
xfmr_zem/schemas.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional, Union
|
|
2
|
+
from pydantic import BaseModel, Field, RootModel
|
|
3
|
+
|
|
4
|
+
class StepInput(BaseModel):
    """Input payload for a single pipeline step."""
    # Primary data payload; optional because a step may instead inherit the
    # previous step's output.
    data: Optional[Any] = None
    # Allow arbitrary extra tool arguments beyond 'data'.
    model_config = {"extra": "allow"}
|
|
7
|
+
|
|
8
|
+
class PipelineStep(RootModel):
    """One pipeline entry: either a bare "server.tool" string or a mapping
    carrying the tool key plus step configuration (name, cache, input)."""
    root: Union[str, Dict[str, Any]]
|
|
10
|
+
|
|
11
|
+
class ZemConfig(BaseModel):
    """Top-level schema for a Zem pipeline YAML config."""
    # Pipeline name; also used as the ZenML pipeline name.
    name: str = "dynamic_generated_pipeline"
    # Template parameters substituted into the config before parsing.
    parameters: Dict[str, Any] = Field(default_factory=dict)
    # Mapping of server alias -> path to its MCP server script.
    servers: Dict[str, str] = Field(default_factory=dict)
    # Ordered list of steps to execute.
    pipeline: List[PipelineStep]
|