xfmr-zem 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xfmr_zem/__init__.py ADDED
@@ -0,0 +1,35 @@
+ """
+ xfmr-zem (Zem)
+ ==============
+
+ A unified data pipeline framework combining:
+ - Model Context Protocol (MCP): For modular, specialized processing servers.
+ - ZenML: For production-grade orchestration and pipeline tracking.
+
+ xfmr-zem allows you to build complex data processing workflows by connecting
+ multiple MCP servers as pipeline steps, all orchestrated by ZenML.
+
+ Key Features:
+ * **Config-Driven Architecture**: Define your entire pipeline in a simple YAML configuration.
+ * **MCP Server Integration**: Leverage any MCP-compatible server as a processing block.
+ * **ZenML Orchestration**: Production-grade tracking, caching, and visualization of data flows.
+ * **Multi-Domain Ready**: Designed for modular tasks like curation, extraction, and filtering.
+
+ Example:
+     from xfmr_zem import PipelineClient
+
+     # Initialize client with a pipeline configuration
+     client = PipelineClient("configs/medical_pipeline.yaml")
+
+     # Build and execute the ZenML pipeline
+     client.run()
+ """
+
+ __version__ = "0.2.0"
+ __author__ = "Khai Hoang"
+
+ from xfmr_zem.client import PipelineClient
+
+ __all__ = [
+     "PipelineClient",
+ ]
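
The "config-driven" claim above maps onto a small YAML schema (name, parameters, servers, pipeline; see xfmr_zem/schemas.py below). A minimal sketch, mirroring the sample config that the `zem init` command generates; the server path and tool name are whatever a project defines:

    name: my_pipeline

    servers:
      agent: servers/sample_server.py

    pipeline:
      - name: my_first_step
        agent.hello_world:
          input:
            data: [{"text": "Zem is awesome!"}]
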
xfmr_zem/cli.py ADDED
@@ -0,0 +1,295 @@
+ """
+ CLI for Zem - Unified Data Pipeline Framework (MCP + ZenML)
+ """
+
+ import os
+ import sys
+ from pathlib import Path
+
+ import click
+ from rich.console import Console
+ from rich.table import Table
+
+ from xfmr_zem.client import PipelineClient
+
+ console = Console()
+
+
+ @click.group()
+ @click.version_option(version="0.2.0")
+ def main():
+     """Zem CLI - ZenML + MCP (NeMo Curator & DataJuicer)"""
+     pass
+
+
+ @main.command()
+ def info():
+     """Show framework information"""
+     console.print("[bold blue]Zem: Unified Data Pipeline Framework[/bold blue]")
+     console.print("Version: 0.2.0")
+     console.print("\nArchitecture: [green]Model Context Protocol (MCP) + ZenML[/green]")
+     console.print("\nIntegrations:")
+     console.print(" - [bold]ZenML[/bold]: Orchestration, Visualization & Artifact Tracking")
+     console.print(" - [bold]MCP Servers[/bold]: Standalone units for domain-specific logic")
+     console.print(" - [bold]NeMo Curator[/bold]: NVIDIA's high-performance curation")
+     console.print(" - [bold]DataJuicer[/bold]: Comprehensive data processing operators")
+
+
+ @main.command(name="list-tools")
+ @click.option("--config", "-c", type=click.Path(exists=True), help="Config file to discover tools from")
+ def list_tools(config):
+     """List available MCP tools dynamically from servers"""
+     if not config:
+         # Fall back to the hardcoded list if no config is provided (legacy behavior)
+         console.print("[yellow]Hint: Provide a config file to see the dynamic tool list: zem list-tools -c your_config.yaml[/yellow]")
+         _print_static_operators()
+         return
+
+     try:
+         client = PipelineClient(config)
+         all_tools = client.discover_tools()
+
+         for srv_name, tools in all_tools.items():
+             console.print(f"\n[bold magenta]{srv_name} Server Tools:[/bold magenta]")
+             table = Table(show_header=True, header_style="bold cyan")
+             table.add_column("Tool Name")
+             table.add_column("Description")
+
+             for tool in tools:
+                 table.add_row(tool.get("name", "N/A"), tool.get("description", "No description"))
+             console.print(table)
+
+     except Exception as e:
+         console.print(f"[bold red]Error discovering tools:[/bold red] {e}")
+
+
+ @main.command()
+ @click.argument("project_name")
+ def init(project_name: str):
+     """Bootstrap a new Zem project structure."""
+     base_path = Path(project_name)
+     if base_path.exists():
+         console.print(f"[bold red]Error:[/bold red] Path '{project_name}' already exists.")
+         sys.exit(1)
+
+     # Create directories
+     (base_path / "servers").mkdir(parents=True)
+     (base_path / "tests/manual").mkdir(parents=True)
+     (base_path / "data").mkdir(parents=True)
+
+     # Create sample server
+     sample_server_py = """from xfmr_zem.server import ZemServer
+ from typing import Any, List
+
+ # Initialize the sample server
+ mcp = ZemServer("SampleAgent")
+
+ @mcp.tool()
+ def hello_world(data: Any) -> List[Any]:
+     \"\"\"
+     A simple tool that adds a 'greeting' field to each record.
+     \"\"\"
+     dataset = mcp.get_data(data)
+     for item in dataset:
+         item["greeting"] = "Hello from your standalone Zem project!"
+     return dataset
+
+ if __name__ == "__main__":
+     mcp.run()
+ """
+     (base_path / "servers" / "sample_server.py").write_text(sample_server_py)
+
+     # Create sample pipeline
+     pipeline_yaml = f"""name: {project_name}_pipeline
+
+ servers:
+   agent: servers/sample_server.py
+
+ pipeline:
+   - name: my_first_step
+     agent.hello_world:
+       input:
+         data: [{{"text": "Zem is awesome!"}}]
+ """
+     (base_path / "pipeline.yaml").write_text(pipeline_yaml)
+
+     console.print(f"[bold green]Success![/bold green] Project '{project_name}' initialized.")
+     console.print(f"Created standalone sample server: [cyan]{project_name}/servers/sample_server.py[/cyan]")
+     console.print(f"Next steps:\n cd {project_name}\n zem list-tools -c pipeline.yaml\n zem run pipeline.yaml")
+
+
+ @main.command()
+ def operators():
+     """List available MCP tools (static legacy list)"""
+     _print_static_operators()
+
+
+ def _print_static_operators():
+     # NeMo Curator Tools
+     console.print("\n[bold magenta]NeMo Curator Server Tools:[/bold magenta]")
+     nemo_table = Table(show_header=True, header_style="bold cyan")
+     nemo_table.add_column("Tool Name")
+     nemo_table.add_column("Description")
+
+     nemo_table.add_row("pii_removal", "Remove PII using NeMo Curator")
+     nemo_table.add_row("text_cleaning", "General text cleaning using NeMo Curator")
+     console.print(nemo_table)
+
+     # DataJuicer Tools
+     console.print("\n[bold magenta]DataJuicer Server Tools:[/bold magenta]")
+     dj_table = Table(show_header=True, header_style="bold cyan")
+     dj_table.add_column("Tool Name")
+     dj_table.add_column("Description")
+
+     dj_table.add_row("clean_html", "Remove HTML tags")
+     dj_table.add_row("clean_links", "Remove URLs/Links")
+     dj_table.add_row("fix_unicode", "Normalize Unicode (NFKC)")
+     dj_table.add_row("whitespace_normalization", "Clean extra spaces/newlines")
+     dj_table.add_row("text_length_filter", "Filter by character length")
+     dj_table.add_row("language_filter", "Heuristic-based language filtering")
+     dj_table.add_row("document_simhash_dedup", "Simple SimHash-based deduplication")
+     console.print(dj_table)
+
+     # IO Tools
+     console.print("\n[bold magenta]IO Server Tools (File Handling):[/bold magenta]")
+     io_table = Table(show_header=True, header_style="bold cyan")
+     io_table.add_column("Tool Name")
+     io_table.add_column("Description")
+
+     io_table.add_row("load_jsonl", "Load data from JSONL file")
+     io_table.add_row("save_jsonl", "Save data to JSONL file")
+     io_table.add_row("load_csv", "Load data from CSV file")
+     io_table.add_row("save_csv", "Save data to CSV file")
+     console.print(io_table)
+
+     # Profiler Tools
+     console.print("\n[bold magenta]Profiler Server Tools:[/bold magenta]")
+     prof_table = Table(show_header=True, header_style="bold cyan")
+     prof_table.add_column("Tool Name")
+     prof_table.add_column("Description")
+
+     prof_table.add_row("profile_data", "Generate summary & metrics for input data")
+     console.print(prof_table)
+
+
+ @main.command()
+ @click.argument("config_file", type=click.Path(exists=True))
+ @click.option("--params", "-p", type=click.Path(exists=True), help="Path to custom parameters.yml")
+ def run(config_file, params):
+     """Run a pipeline from a YAML configuration file"""
+     abs_config = os.path.abspath(config_file)
+     console.print(f"[bold green]Starting Pipeline:[/bold green] {abs_config}")
+     if params:
+         console.print(f"[bold blue]Custom Parameters:[/bold blue] {params}")
+
+     try:
+         client = PipelineClient(abs_config, params_path=params)
+         run_response = client.run()
+
+         console.print("\n[bold blue]Pipeline Execution Finished![/bold blue]")
+         console.print(f"Run Name: [cyan]{run_response.name}[/cyan]")
+         console.print(f"Status: [yellow]{run_response.status}[/yellow]")
+
+         console.print("\n[dim]To visualize this run, ensure the ZenML dashboard is running:[/dim]")
+         console.print("[dim]uv run zenml up --port 8871[/dim]")
+         console.print("[dim]Or view runs via: zem dashboard[/dim]")  # Future-proofing hint
+
+     except Exception as e:
+         console.print(f"\n[bold red]Pipeline Failed:[/bold red] {e}")
+         if os.path.exists("/tmp/zenml_error.log"):
+             console.print("\n[bold yellow]Error log snippet (/tmp/zenml_error.log):[/bold yellow]")
+             with open("/tmp/zenml_error.log", "r") as f:
+                 console.print(f.read())
+
+
+ @main.command()
+ def dashboard():
+     """Open the ZenML dashboard."""
+     import webbrowser
+     console.print("[bold blue]Checking ZenML Dashboard...[/bold blue]")
+     # Default ZenML port used in this project
+     url = "http://127.0.0.1:8871"
+     console.print(f"URL: [link={url}]{url}[/link]")
+     try:
+         webbrowser.open(url)
+     except Exception as e:
+         console.print(f"[yellow]Could not open browser automatically: {e}[/yellow]")
+
+
+ @main.command()
+ @click.argument("artifact_id")
+ @click.option("--id2", help="Secondary artifact ID for comparison (diff mode)")
+ @click.option("--limit", "-n", default=10, help="Number of rows to preview")
+ @click.option("--sample", is_flag=True, help="Show a random sample instead of the head")
+ def preview(artifact_id, id2, limit, sample):
+     """Preview a ZenML artifact (supports diff mode and sampling)"""
+     from zenml.client import Client
+     import pandas as pd
+     import json
+
+     def load_art_df(uid):
+         art = Client().get_artifact_version(uid)
+         d = art.load()
+         if isinstance(d, dict) and "path" in d:
+             p = d["path"]
+             ext = os.path.splitext(p)[1].lower()
+             if ext == ".parquet":
+                 return pd.read_parquet(p)
+             elif ext == ".csv":
+                 return pd.read_csv(p)
+             elif ext == ".jsonl":
+                 with open(p, "r") as f:
+                     lines = [json.loads(line) for line in f]
+                 return pd.DataFrame(lines)
+         elif isinstance(d, list):
+             return pd.DataFrame(d)
+         elif isinstance(d, pd.DataFrame):
+             return d
+         return None
+
+     try:
+         df1 = load_art_df(artifact_id)
+         if df1 is None:
+             console.print("[bold red]Error:[/bold red] Could not load artifact as tabular data.")
+             return
+
+         if id2:
+             df2 = load_art_df(id2)
+             if df2 is None:
+                 console.print("[bold red]Error:[/bold red] Could not load second artifact.")
+                 return
+
+             console.print(f"[bold blue]Comparing Artifacts:[/bold blue] {artifact_id} vs {id2}")
+             # Simple column/row diff
+             cols1, cols2 = set(df1.columns), set(df2.columns)
+             console.print(f" Rows: {len(df1)} -> {len(df2)} ({len(df2)-len(df1):+d})")
+             console.print(f" Cols: {len(df1.columns)} -> {len(df2.columns)}")
+             if cols1 != cols2:
+                 added = cols2 - cols1
+                 removed = cols1 - cols2
+                 if added:
+                     console.print(f" [green]+ Added columns:[/green] {added}")
+                 if removed:
+                     console.print(f" [red]- Removed columns:[/red] {removed}")
+
+             # Show a few sample rows from the second artifact
+             console.print("\n[bold magenta]Diff Sample (df2):[/bold magenta]")
+             df_to_show = df2
+         else:
+             df_to_show = df1
+
+         if sample and len(df_to_show) > limit:
+             preview_df = df_to_show.sample(limit)
+         else:
+             preview_df = df_to_show.head(limit)
+
+         table = Table(show_header=True, header_style="bold magenta", title=f"Preview ({'Sample' if sample else 'Head'} {limit} rows)")
+         for col in preview_df.columns:
+             table.add_column(str(col))
+         for _, row in preview_df.iterrows():
+             row_values = []
+             for val in row:
+                 v_str = str(val)
+                 if len(v_str) > 100:
+                     v_str = v_str[:97] + "..."
+                 row_values.append(v_str)
+             table.add_row(*row_values)
+         console.print(table)
+
+     except Exception as e:
+         console.print(f"[bold red]Error previewing artifact:[/bold red] {e}")
+
+
+ if __name__ == "__main__":
+     main()
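
Taken end to end, the CLI mirrors the "next steps" that `init` prints. A typical session might look like the following sketch (assuming the package exposes a `zem` console script, as those printed hints imply; `-p` merges a custom parameters file over the defaults):

    zem init my_project
    cd my_project
    zem list-tools -c pipeline.yaml      # discover tools from the configured MCP servers
    zem run pipeline.yaml -p parameters.yml
    zem dashboard                        # open the local ZenML dashboard (127.0.0.1:8871)
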
xfmr_zem/client.py ADDED
@@ -0,0 +1,208 @@
+
+ import os
+ import sys
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import yaml
+ from zenml import pipeline, step as zenml_step
+
+ from .schemas import ZemConfig
+ from .zenml_wrapper import mcp_generic_step, list_mcp_tools
+
+
+ class PipelineClient:
+     """
+     Client to run Zem pipelines using MCP servers and ZenML orchestration.
+     """
+     def __init__(self, config_path: str, params_path: Optional[str] = None):
+         self.config_path = Path(config_path)
+         self.params_path = params_path
+         self.params = {}
+         self.config_dict = self._load_config_dict(self.config_path)
+
+         # Validate the merged config with Pydantic
+         self.config = ZemConfig(**self.config_dict)
+         self.server_configs = self._load_server_configs()
+
+     def _load_params(self, params_path: Optional[str] = None) -> Dict[str, Any]:
+         """Load parameters from the default or a specified path."""
+         params = {}
+         default_params_path = self.config_path.parent / "parameters.yml"
+         if default_params_path.exists():
+             with open(default_params_path, "r") as f:
+                 params.update(yaml.safe_load(f) or {})
+
+         if params_path:
+             p_path = Path(params_path)
+             if p_path.exists():
+                 with open(p_path, "r") as f:
+                     params.update(yaml.safe_load(f) or {})
+         return params
+
+     def _load_config_dict(self, path: Path) -> Dict[str, Any]:
+         """Load the YAML config and perform parameter substitution."""
+         with open(path, "r") as f:
+             raw_content = f.read()
+
+         # Merge order: defaults from parameters.yml, then the config's own
+         # 'parameters' block, then any custom params file passed via --params.
+         self.params = self._load_params(None)
+         preliminary_dict = yaml.safe_load(raw_content) or {}
+         internal_params = preliminary_dict.get("parameters", {})
+         if internal_params:
+             self.params.update(internal_params)
+
+         if self.params_path:
+             custom_params = self._load_params(self.params_path)
+             self.params.update(custom_params)
+
+         # Substitute {{ key }} / {{key}} placeholders with parameter values
+         content = raw_content
+         for key, value in self.params.items():
+             content = content.replace(f"{{{{ {key} }}}}", str(value))
+             content = content.replace(f"{{{{{key}}}}}", str(value))
+
+         return yaml.safe_load(content)
+
+     def _load_server_configs(self) -> Dict[str, Any]:
+         servers = self.config.servers
+         configs = {}
+         for name, path_str in servers.items():
+             # 1. Try relative to the config file (user's project)
+             abs_path = (self.config_path.parent / path_str).resolve()
+
+             # 2. If it doesn't exist AND starts with "servers/", check the internal package
+             if not abs_path.exists() and path_str.startswith("servers/"):
+                 package_root = Path(__file__).parent.resolve()
+                 abs_path = (package_root / path_str / "server.py").resolve()
+
+             # 3. If it still doesn't exist, try relative to the project root
+             if not abs_path.exists():
+                 project_root = Path(__file__).parent.parent.parent.resolve()
+                 abs_path = (project_root / path_str).resolve()
+
+             # 4. If it's a directory, append the default filename
+             if abs_path.exists() and abs_path.is_dir():
+                 abs_path = (abs_path / "server.py").resolve()
+
+             # Make the package importable from the spawned server process
+             env = os.environ.copy()
+             src_path = str(Path(__file__).parent.parent.resolve())
+             current_pythonpath = env.get("PYTHONPATH", "")
+             env["PYTHONPATH"] = f"{src_path}:{current_pythonpath}" if current_pythonpath else src_path
+
+             # Strip the "<server>." prefix from server-scoped parameters
+             server_specific_params = {}
+             prefix = f"{name}."
+             for key, value in self.params.items():
+                 if key.startswith(prefix):
+                     server_specific_params[key[len(prefix):]] = value
+                 else:
+                     server_specific_params[key] = value
+
+             env["ZEM_PARAMETERS"] = yaml.dump(server_specific_params)
+             configs[name] = {
+                 "command": sys.executable,
+                 "args": [str(abs_path)],
+                 "env": env,
+             }
+         return configs
+
+     def discover_tools(self) -> Dict[str, List[Dict[str, Any]]]:
+         """Fetch tools from all registered servers."""
+         all_tools = {}
+         for name, cfg in self.server_configs.items():
+             all_tools[name] = list_mcp_tools(cfg["command"], cfg["args"], cfg["env"])
+         return all_tools
+
+     def run(self):
+         """Build and run the ZenML pipeline."""
+         pipeline_steps = self.config.pipeline
+         server_configs = self.server_configs
+         pipeline_name = self.config.name
+
+         @pipeline(name=pipeline_name, enable_cache=False)
+         def dynamic_generated_pipeline(pipeline_params: Dict[str, Any]):
+             # pipeline_params is not referenced below; passing it makes the
+             # parameter set part of the pipeline's run configuration.
+             step_outputs = {}
+             last_output = {}
+
+             for i, p_step in enumerate(pipeline_steps):
+                 step_def = p_step.root
+                 srv = ""
+                 tool = ""
+                 tool_args = {}
+                 step_alias = f"step_{i}"
+
+                 if isinstance(step_def, str):
+                     srv, tool = step_def.split(".", 1)
+                 elif isinstance(step_def, dict):
+                     # Check for a name at the top level or inside the tool dict
+                     step_alias = step_def.get("name")
+
+                     # Exclude control keywords
+                     control_keys = ["name", "cache"]
+                     keys = [k for k in step_def.keys() if k not in control_keys]
+                     if not keys:
+                         continue
+                     key = keys[0]
+
+                     if "." not in key:
+                         # Might be another control key or a misconfiguration
+                         continue
+
+                     srv, tool = key.split(".", 1)
+
+                     if not step_alias:
+                         step_alias = step_def[key].get("name")
+
+                     step_alias = step_alias or f"{srv}.{tool}.{i}"
+                     tool_args = step_def[key].get("input", {}) or {}
+
+                 # Smart parallelization & DAG logic:
+                 # 1. A step with no 'data' input inherits from the previous step
+                 #    (linear chain).
+                 # 2. If 'data' is a reference ($step), it depends on that specific step.
+                 # 3. A step with raw 'data' is a root and has no upstream dependency.
+
+                 current_prev_output = None
+                 has_explicit_data = "data" in tool_args
+
+                 if not has_explicit_data:
+                     # No data provided? Inherit from the last executed step so
+                     # simple sequences keep working.
+                     current_prev_output = last_output
+                 else:
+                     # Data provided? Check whether it's a reference or raw data.
+                     for k, v in list(tool_args.items()):
+                         if isinstance(v, str) and v.startswith("$"):
+                             target_step = v[1:]
+                             if target_step in step_outputs:
+                                 if k == "data":
+                                     current_prev_output = step_outputs[target_step]
+                                     del tool_args[k]
+                                 else:
+                                     # Limitation: ZenML doesn't materialize artifacts nested in dicts
+                                     print(f"[Warning] Tool argument '{k}' uses a step reference '{v}'. "
+                                           "Currently, only the 'data' field supports cross-step dependencies. "
+                                           "This value will be passed as a raw string.")
+                             else:
+                                 raise ValueError(f"Step reference '{v}' not found in previous steps. Available: {list(step_outputs.keys())}")
+
+                 # Adaptive caching: check for 'cache' at the step's top level
+                 enable_cache = step_def.get("cache", True) if isinstance(step_def, dict) else True
+
+                 unique_step_name = f"{srv}.{tool}.{i}"
+                 dynamic_step = zenml_step(
+                     mcp_generic_step.entrypoint,
+                     name=unique_step_name,
+                     enable_cache=enable_cache,
+                 )
+
+                 step_output = dynamic_step(
+                     server_name=srv,
+                     tool_name=tool,
+                     server_config=server_configs.get(srv, {}),
+                     tool_args=tool_args,
+                     previous_output=current_prev_output,
+                 )
+
+                 step_outputs[step_alias] = step_output
+                 last_output = step_output
+
+             return last_output
+
+         return dynamic_generated_pipeline(pipeline_params=self.params)
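
The chaining rules in `run()` are easiest to see in a config. Below is a hedged sketch (the server paths and tool argument shapes are hypothetical; the tool names come from the static list in cli.py) showing the three dependency cases: a root step with raw data, implicit linear chaining, and an explicit `$step` reference with per-step cache control:

    name: branching_example

    servers:
      io: servers/io_server.py
      dj: servers/dj_server.py

    pipeline:
      - name: load
        io.load_jsonl:
          input:
            data: data/input.jsonl   # raw 'data': this step is a root

      - name: clean                  # no 'data' key: inherits from the previous step
        dj.clean_html:
          input: {}

      - name: dedup
        cache: false                 # top-level 'cache' disables ZenML caching for this step
        dj.document_simhash_dedup:
          input:
            data: $load              # explicit reference: consumes the 'load' step's output
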
@@ -0,0 +1,92 @@
+ import time
+ from typing import TYPE_CHECKING, Dict, Optional, Type
+ from uuid import uuid4
+
+ from zenml.logger import get_logger
+ from zenml.orchestrators import (
+     BaseOrchestrator,
+     BaseOrchestratorConfig,
+     BaseOrchestratorFlavor,
+ )
+ from zenml.orchestrators.dag_runner import ThreadedDagRunner
+ from zenml.utils import string_utils
+
+ if TYPE_CHECKING:
+     from zenml.models import PipelineRunResponse, PipelineSnapshotResponse
+     from zenml.stack import Stack
+
+ logger = get_logger(__name__)
+
+
+ class ParallelLocalOrchestrator(BaseOrchestrator):
+     """Orchestrator responsible for running pipelines locally in parallel."""
+
+     _orchestrator_run_id: Optional[str] = None
+
+     def submit_pipeline(
+         self,
+         snapshot: "PipelineSnapshotResponse",
+         stack: "Stack",
+         base_environment: Dict[str, str],
+         step_environments: Dict[str, Dict[str, str]],
+         placeholder_run: Optional["PipelineRunResponse"] = None,
+     ) -> None:
+         """Submits a pipeline to the orchestrator."""
+         self._orchestrator_run_id = str(uuid4())
+         start_time = time.time()
+
+         # Build the DAG: each step maps to its upstream steps
+         dag = {
+             step_name: step.spec.upstream_steps
+             for step_name, step in snapshot.step_configurations.items()
+         }
+
+         def run_step_wrapper(step_name: str) -> None:
+             step = snapshot.step_configurations[step_name]
+             self.run_step(step=step)
+
+         # Use ThreadedDagRunner for parallel execution
+         dag_runner = ThreadedDagRunner(
+             dag=dag,
+             run_fn=run_step_wrapper,
+         )
+
+         logger.info("Starting parallel local execution...")
+         dag_runner.run()
+
+         run_duration = time.time() - start_time
+         logger.info(
+             "Parallel pipeline run has finished in `%s`.",
+             string_utils.get_human_readable_time(run_duration),
+         )
+         self._orchestrator_run_id = None
+
+     def get_orchestrator_run_id(self) -> str:
+         """Returns the active orchestrator run id."""
+         if not self._orchestrator_run_id:
+             raise RuntimeError("No run id set.")
+         return self._orchestrator_run_id
+
+
+ class ParallelLocalOrchestratorConfig(BaseOrchestratorConfig):
+     """Parallel local orchestrator config."""
+
+     @property
+     def is_local(self) -> bool:
+         return True
+
+
+ class ParallelLocalOrchestratorFlavor(BaseOrchestratorFlavor):
+     """Class for the `ParallelLocalOrchestratorFlavor`."""
+
+     @property
+     def name(self) -> str:
+         return "parallel_local"
+
+     @property
+     def config_class(self) -> Type[ParallelLocalOrchestratorConfig]:
+         return ParallelLocalOrchestratorConfig
+
+     @property
+     def logo_url(self) -> str:
+         """A URL to represent the flavor in the dashboard."""
+         return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/local.png"
+
+     @property
+     def implementation_class(self) -> Type[ParallelLocalOrchestrator]:
+         return ParallelLocalOrchestrator
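
Nothing in this module registers the flavor with ZenML by itself. Under ZenML's documented custom-flavor workflow, registration happens from the CLI; a rough sketch follows, where the module path is a placeholder (the diff does not show this file's name) and exact commands can vary across ZenML versions:

    zenml orchestrator flavor register xfmr_zem.<module>.ParallelLocalOrchestratorFlavor
    zenml orchestrator register parallel_local_orch --flavor=parallel_local
    zenml stack register zem_stack -o parallel_local_orch -a default --set
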
xfmr_zem/schemas.py ADDED
@@ -0,0 +1,15 @@
+ from typing import Dict, Any, List, Optional, Union
+ from pydantic import BaseModel, Field, RootModel
+
+
+ class StepInput(BaseModel):
+     data: Optional[Any] = None
+     model_config = {"extra": "allow"}
+
+
+ class PipelineStep(RootModel):
+     root: Union[str, Dict[str, Any]]
+
+
+ class ZemConfig(BaseModel):
+     name: str = "dynamic_generated_pipeline"
+     parameters: Dict[str, Any] = Field(default_factory=dict)
+     servers: Dict[str, str] = Field(default_factory=dict)
+     pipeline: List[PipelineStep]
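
Because ZemConfig is a plain Pydantic model (v2, given the use of RootModel), a config can be validated in isolation, without spawning any servers. A minimal sketch, assuming only the schema above; the server path and tool name are hypothetical:

    from xfmr_zem.schemas import ZemConfig

    cfg = ZemConfig(**{
        "name": "demo",
        "servers": {"agent": "servers/sample_server.py"},
        "pipeline": [
            "agent.hello_world",  # shorthand string form: "<server>.<tool>"
            {"name": "greet", "agent.hello_world": {"input": {"data": [{"text": "hi"}]}}},
        ],
    })
    print(cfg.pipeline[0].root)  # -> "agent.hello_world"

Each pipeline entry is coerced into a PipelineStep, so client.py can treat `p_step.root` uniformly as either a "server.tool" string or a step dict.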