wafer-core 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,73 @@
1
+ """Kernel trace tool for discovering SOTA kernels.
2
+
3
+ Given a PyTorch operation, traces what kernel PyTorch dispatches to on target hardware.
4
+
5
+ Usage:
6
+ from wafer_core.tools.dispatch_baseline import quick_trace
7
+
8
+ result = quick_trace(
9
+ "torch.matmul(A, B)",
10
+ {"A": (4096, 4096), "B": (4096, 4096)},
11
+ hardware="H100",
12
+ )
13
+ print(result.summary())
14
+ """
15
+
16
+ from wafer_core.tools.dispatch_baseline.client import (
17
+ lookup_baseline,
18
+ store_baseline,
19
+ )
20
+ from wafer_core.tools.dispatch_baseline.codegen import (
21
+ generate_trace_script,
22
+ parse_op_string,
23
+ update_dtypes,
24
+ update_shapes,
25
+ )
26
+ from wafer_core.tools.dispatch_baseline.dtypes import (
27
+ HardwareSpec,
28
+ KernelInfo,
29
+ KernelTraceConfig,
30
+ KernelTraceResult,
31
+ OpSpec,
32
+ RooflineAnalysis,
33
+ TensorSpec,
34
+ )
35
+ from wafer_core.tools.dispatch_baseline.executor import (
36
+ TraceExecutionResult,
37
+ quick_trace,
38
+ trace_kernel_local,
39
+ trace_kernel_remote,
40
+ )
41
+ from wafer_core.tools.dispatch_baseline.roofline import (
42
+ HARDWARE_SPECS,
43
+ compute_roofline,
44
+ get_hardware_spec,
45
+ )
46
+
47
+ __all__ = [
48
+ # Data types
49
+ "HardwareSpec",
50
+ "KernelInfo",
51
+ "KernelTraceConfig",
52
+ "KernelTraceResult",
53
+ "OpSpec",
54
+ "RooflineAnalysis",
55
+ "TensorSpec",
56
+ "TraceExecutionResult",
57
+ # Codegen
58
+ "generate_trace_script",
59
+ "parse_op_string",
60
+ "update_dtypes",
61
+ "update_shapes",
62
+ # Execution
63
+ "quick_trace",
64
+ "trace_kernel_local",
65
+ "trace_kernel_remote",
66
+ # Database
67
+ "lookup_baseline",
68
+ "store_baseline",
69
+ # Roofline
70
+ "HARDWARE_SPECS",
71
+ "compute_roofline",
72
+ "get_hardware_spec",
73
+ ]
@@ -0,0 +1,174 @@
1
+ """Analyzer for kernel trace output.
2
+
3
+ Parses profiler output to extract kernel information and identify the primary kernel.
4
+ """
5
+
6
+ import json
7
+ import re
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from wafer_core.tools.dispatch_baseline.dtypes import KernelInfo, KernelTraceResult, OpSpec
12
+
13
+
14
@dataclass(frozen=True)
class ParsedTraceResult:
    """Parsed trace output with environment info for caching."""

    # The reconstructed trace result (kernels, primary kernel, raw output).
    result: KernelTraceResult
    # Environment strings reported by the trace script; these form part of
    # the baseline cache key. Each is "unknown" when parsing failed or the
    # script omitted the field.
    pytorch_version: str
    runtime_version: str
    gpu_arch: str
22
+
23
+
24
def parse_trace_output(output: str, op_spec: OpSpec, hardware: str) -> ParsedTraceResult:
    """Parse kernel trace script output and extract kernel information.

    Args:
        output: Raw output from the trace script (stdout)
        op_spec: The operation that was traced
        hardware: Target hardware name

    Returns:
        ParsedTraceResult with extracted kernel information and environment info
    """
    # Look for our JSON marker in the output
    json_match = re.search(r"KERNEL_TRACE_RESULT_JSON:(.+)$", output, re.MULTILINE)

    if not json_match:
        return _error_result(
            output,
            op_spec,
            hardware,
            "Could not find KERNEL_TRACE_RESULT_JSON marker in output",
        )

    try:
        result_json = json.loads(json_match.group(1))
    except json.JSONDecodeError as e:
        return _error_result(output, op_spec, hardware, f"Failed to parse JSON: {e}")

    # Extract environment info (part of the baseline cache key)
    env_info = result_json.get("environment", {})
    pytorch_version = env_info.get("pytorch_version", "unknown")
    runtime_version = env_info.get("runtime_version", "unknown")
    gpu_arch = env_info.get("gpu_arch", "unknown")

    # Extract kernels from JSON, dropping malformed (unnamed) entries
    kernels = []
    for k in result_json.get("kernels", []):
        kernel_info = _parse_kernel_dict(k)
        if kernel_info:
            kernels.append(kernel_info)

    # Identify primary kernel (longest duration). Select explicitly rather
    # than assuming the trace script pre-sorted the list; max() returns the
    # first maximal element, matching the old kernels[0] pick on sorted input.
    primary_kernel = max(kernels, key=lambda k: k.duration_us, default=None)

    return ParsedTraceResult(
        result=KernelTraceResult(
            op_spec=op_spec,
            hardware=hardware,
            kernels=kernels,
            primary_kernel=primary_kernel,
            raw_output=output,
        ),
        pytorch_version=pytorch_version,
        runtime_version=runtime_version,
        gpu_arch=gpu_arch,
    )


def _error_result(
    output: str, op_spec: OpSpec, hardware: str, message: str
) -> ParsedTraceResult:
    """Build a ParsedTraceResult describing a parse failure (no kernels)."""
    return ParsedTraceResult(
        result=KernelTraceResult(
            op_spec=op_spec,
            hardware=hardware,
            kernels=[],
            primary_kernel=None,
            raw_output=output,
            error=message,
        ),
        pytorch_version="unknown",
        runtime_version="unknown",
        gpu_arch="unknown",
    )
98
+
99
+
100
+ def _parse_kernel_dict(kernel_dict: dict[str, Any]) -> KernelInfo | None:
101
+ """Parse a kernel dictionary into KernelInfo."""
102
+ name = kernel_dict.get("name", "")
103
+ if not name:
104
+ return None
105
+
106
+ return KernelInfo(
107
+ name=name,
108
+ duration_us=kernel_dict.get("duration_us", 0.0),
109
+ )
110
+
111
+
112
def extract_kernels_from_nsys(nsys_output: str) -> list[dict[str, Any]]:
    """Extract kernel information from nsys output.

    This is a fallback for when torch.profiler doesn't capture all kernels.

    Args:
        nsys_output: Output from `nsys stats` or similar

    Returns:
        List of kernel dictionaries
    """
    found: list[dict[str, Any]] = []

    for raw_line in nsys_output.split("\n"):
        # Skip blank lines, separator rows, and header rows.
        if not raw_line.strip() or "---" in raw_line or "Name" in raw_line:
            continue

        # Simplified whitespace-column parse; real nsys output formats vary.
        columns = raw_line.split()
        if len(columns) < 2:
            continue

        # Heuristic: the last column is usually the (mangled) kernel name.
        candidate = columns[-1]
        looks_like_symbol = any(ch.isalpha() for ch in candidate) and "_" in candidate
        if looks_like_symbol:
            found.append({"name": candidate, "duration_us": 0.0})

    return found
144
+
145
+
146
def merge_kernel_infos(
    profiler_kernels: list[KernelInfo], nsys_kernels: list[dict[str, Any]]
) -> list[KernelInfo]:
    """Merge kernel info from multiple sources.

    Prefers profiler data when available, but adds any missing kernels from nsys.

    Args:
        profiler_kernels: Kernels from torch.profiler
        nsys_kernels: Kernels from nsys

    Returns:
        Merged list of KernelInfo objects
    """
    merged = list(profiler_kernels)
    known_names = {k.name for k in profiler_kernels}

    for candidate in nsys_kernels:
        candidate_name = candidate.get("name", "")
        # Only add nsys entries that are named and not already covered
        # by the (richer) profiler data.
        if not candidate_name or candidate_name in known_names:
            continue
        known_names.add(candidate_name)
        merged.append(
            KernelInfo(
                name=candidate_name,
                duration_us=candidate.get("duration_us", 0.0),
            )
        )

    return merged
@@ -0,0 +1,196 @@
1
+ """API client for baseline kernel trace results.
2
+
3
+ Interacts with the Supabase baselines table to store and retrieve kernel dispatch info.
4
+ Results are keyed by (op, shape, dtype, pytorch_version, runtime_version, gpu_arch).
5
+ Since kernel dispatch is deterministic for a given environment, this serves as a
6
+ shared lookup table across all users.
7
+ """
8
+
9
+ import logging
10
+
11
+ import httpx
12
+
13
+ from wafer_core.tools.dispatch_baseline.dtypes import (
14
+ KernelInfo,
15
+ KernelTraceResult,
16
+ OpSpec,
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # API timeout for baseline lookups
22
+ API_TIMEOUT = 30.0
23
+
24
+
25
+ def _get_api_config() -> tuple[str, dict[str, str]]:
26
+ """Get API URL and auth headers.
27
+
28
+ Returns:
29
+ Tuple of (api_url, headers)
30
+ """
31
+ # Import here to avoid circular imports and allow use without CLI
32
+ try:
33
+ from wafer.api_client import get_api_url
34
+ from wafer.auth import get_auth_headers
35
+
36
+ return get_api_url(), get_auth_headers()
37
+ except ImportError:
38
+ # If CLI not installed, return defaults (won't work without auth)
39
+ return "https://api.wafer.ai", {}
40
+
41
+
42
def lookup_baseline(
    op_spec: OpSpec,
    hardware: str,
    pytorch_version: str,
    runtime_version: str,
    gpu_arch: str,
) -> KernelTraceResult | None:
    """Look up baseline result from the database.

    Args:
        op_spec: Operation specification
        hardware: Hardware name (for display, not part of cache key)
        pytorch_version: PyTorch version string
        runtime_version: CUDA version string
        gpu_arch: GPU architecture (e.g., "sm_90")

    Returns:
        KernelTraceResult if found, None otherwise
    """
    api_url, headers = _get_api_config()

    # Without auth the API would reject us anyway; best-effort skip.
    if not headers.get("Authorization"):
        logger.debug("No auth headers, skipping baseline lookup")
        return None

    # Cache key: op + inputs + kwargs + environment triple.
    payload = {
        "op": op_spec.op,
        "inputs": [
            {"name": t.name, "shape": list(t.shape), "dtype": t.dtype}
            for t in op_spec.inputs
        ],
        "kwargs": op_spec.kwargs,
        "pytorch_version": pytorch_version,
        "runtime_version": runtime_version,
        "gpu_arch": gpu_arch,
    }

    try:
        with httpx.Client(timeout=API_TIMEOUT, headers=headers) as client:
            response = client.post(f"{api_url}/v1/baselines/lookup", json=payload)
            response.raise_for_status()
            body = response.json()

        if not body.get("found"):
            return None

        # Reconstruct the cached result from the response payload.
        kernels = [
            KernelInfo(name=entry["name"], duration_us=entry["duration_us"])
            for entry in body.get("kernels", [])
        ]

        primary_data = body.get("primary_kernel")
        if primary_data:
            primary_kernel = KernelInfo(
                name=primary_data["name"],
                duration_us=primary_data["duration_us"],
            )
        elif kernels:
            primary_kernel = kernels[0]
        else:
            primary_kernel = None

        return KernelTraceResult(
            op_spec=op_spec,
            hardware=hardware,
            kernels=kernels,
            primary_kernel=primary_kernel,
            # Note: roofline will be recomputed by the caller
        )

    except httpx.HTTPError as e:
        logger.warning(f"Baseline lookup failed: {e}")
        return None
    except Exception as e:
        logger.warning(f"Baseline lookup error: {e}")
        return None
118
+
119
+
120
def store_baseline(
    result: KernelTraceResult,
    pytorch_version: str,
    runtime_version: str,
    gpu_arch: str,
) -> bool:
    """Store baseline result in the database.

    Args:
        result: Trace result to cache
        pytorch_version: PyTorch version string
        runtime_version: CUDA version string
        gpu_arch: GPU architecture

    Returns:
        True if stored successfully, False otherwise
    """
    # Never cache failed or empty traces.
    if result.error or not result.primary_kernel:
        return False

    api_url, headers = _get_api_config()
    if not headers.get("Authorization"):
        logger.debug("No auth headers, skipping baseline store")
        return False

    # Same cache key as lookup, plus the trace payload itself.
    payload = {
        "op": result.op_spec.op,
        "inputs": [
            {"name": t.name, "shape": list(t.shape), "dtype": t.dtype}
            for t in result.op_spec.inputs
        ],
        "kwargs": result.op_spec.kwargs,
        "pytorch_version": pytorch_version,
        "runtime_version": runtime_version,
        "gpu_arch": gpu_arch,
        "hardware_name": result.hardware,
        "primary_kernel": {
            "name": result.primary_kernel.name,
            "duration_us": result.primary_kernel.duration_us,
        },
        "kernels": [
            {"name": k.name, "duration_us": k.duration_us}
            for k in result.kernels
        ],
    }

    try:
        with httpx.Client(timeout=API_TIMEOUT, headers=headers) as client:
            response = client.post(f"{api_url}/v1/baselines/store", json=payload)
            response.raise_for_status()
            body = response.json()

        if body.get("created"):
            logger.info(f"Stored baseline: {result.op_spec.op} ({gpu_arch})")
        else:
            logger.debug(f"Baseline already exists: {result.op_spec.op} ({gpu_arch})")
        return True

    except httpx.HTTPError as e:
        logger.warning(f"Baseline store failed: {e}")
        return False
    except Exception as e:
        logger.warning(f"Baseline store error: {e}")
        return False
195
+
196
+
@@ -0,0 +1,246 @@
1
+ """Code generator for kernel trace profiling scripts.
2
+
3
+ Generates minimal Python scripts that:
4
+ 1. Create tensors with specified shapes/dtypes
5
+ 2. Run the PyTorch op with warmup
6
+ 3. Profile using torch.profiler
7
+ 4. Output structured JSON with kernel info
8
+ """
9
+
10
+ from wafer_core.tools.dispatch_baseline.dtypes import KernelTraceConfig, OpSpec, TensorSpec
11
+
12
+
13
def generate_trace_script(config: KernelTraceConfig) -> str:
    """Generate a profiling script for the given operation.

    Args:
        config: Kernel trace configuration with op spec and settings

    Returns:
        Python script as a string

    The returned script warms up, profiles the op with torch.profiler, and
    prints a single machine-readable line prefixed with
    "KERNEL_TRACE_RESULT_JSON:" that parse_trace_output() looks for.
    """
    op_spec = config.op_spec
    tensor_setup = _generate_tensor_setup(op_spec.inputs)
    op_call = _generate_op_call(op_spec)

    # NOTE: the template below is emitted verbatim as a standalone script;
    # doubled braces ({{ }}) are literal braces in the generated code,
    # single braces interpolate config/op values at generation time.
    script = f'''"""Auto-generated kernel trace script."""
import json
import sys
import torch

# Ensure CUDA is available
assert torch.cuda.is_available(), "CUDA not available"

# Setup tensors
{tensor_setup}

# Warmup
print("Warming up...", file=sys.stderr)
for _ in range({config.num_warmup}):
    _ = {op_call}
torch.cuda.synchronize()

# Profile with torch.profiler
print("Profiling...", file=sys.stderr)
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=False,
) as prof:
    for _ in range({config.num_runs}):
        _ = {op_call}
    torch.cuda.synchronize()

# Extract kernel information
kernels = []
for event in prof.key_averages():
    if event.device_type == torch.profiler.DeviceType.CUDA:
        # Use device_time (new) or cuda_time (old) for average time in us
        duration = getattr(event, 'device_time', None) or getattr(event, 'cuda_time', 0)
        kernels.append({{
            "name": event.key,
            "duration_us": duration,
            "count": event.count,
        }})

# Sort by duration (longest first)
kernels.sort(key=lambda k: k["duration_us"], reverse=True)

# Get environment info for caching
props = torch.cuda.get_device_properties(0)

# Detect runtime version (CUDA or ROCm)
if hasattr(torch.version, 'hip') and torch.version.hip:
    runtime_version = torch.version.hip
    # ROCm uses gcnArchName for architecture (e.g., "gfx942")
    gpu_arch = getattr(props, 'gcnArchName', f"gfx{{props.major}}{{props.minor}}")
else:
    runtime_version = torch.version.cuda or "unknown"
    gpu_arch = f"sm_{{props.major}}{{props.minor}}"

env_info = {{
    "pytorch_version": torch.__version__,
    "runtime_version": runtime_version,
    "gpu_arch": gpu_arch,
    "gpu_name": props.name,
}}

# Build result
result = {{
    "op": "{op_spec.op}",
    "inputs": {_serialize_inputs(op_spec.inputs)},
    "num_runs": {config.num_runs},
    "kernels": kernels,
    "total_cuda_time_us": sum(k["duration_us"] for k in kernels),
    "environment": env_info,
}}

# Output as JSON (marker for parsing)
print("KERNEL_TRACE_RESULT_JSON:" + json.dumps(result))
'''
    return script
105
+
106
+
107
+ def _generate_tensor_setup(inputs: list[TensorSpec]) -> str:
108
+ """Generate tensor creation code."""
109
+ lines = []
110
+ for tensor in inputs:
111
+ shape_str = ", ".join(str(d) for d in tensor.shape)
112
+ dtype_map = {
113
+ "float16": "torch.float16",
114
+ "float32": "torch.float32",
115
+ "bfloat16": "torch.bfloat16",
116
+ "float64": "torch.float64",
117
+ "int8": "torch.int8",
118
+ "int16": "torch.int16",
119
+ "int32": "torch.int32",
120
+ "int64": "torch.int64",
121
+ }
122
+ dtype = dtype_map.get(tensor.dtype, f"torch.{tensor.dtype}")
123
+
124
+ # Use randn for float types, randint for int types
125
+ if "int" in tensor.dtype:
126
+ lines.append(
127
+ f'{tensor.name} = torch.randint(-128, 127, ({shape_str},), '
128
+ f'dtype={dtype}, device="{tensor.device}")'
129
+ )
130
+ else:
131
+ lines.append(
132
+ f'{tensor.name} = torch.randn({shape_str}, dtype={dtype}, device="{tensor.device}")'
133
+ )
134
+
135
+ return "\n".join(lines)
136
+
137
+
138
+ def _generate_op_call(op_spec: OpSpec) -> str:
139
+ """Generate the operation call code."""
140
+ args = [t.name for t in op_spec.inputs]
141
+ kwargs = [f"{k}={v}" for k, v in op_spec.kwargs.items()]
142
+ all_args = ", ".join(args + kwargs)
143
+ return f"{op_spec.op}({all_args})"
144
+
145
+
146
+ def _serialize_inputs(inputs: list[TensorSpec]) -> str:
147
+ """Serialize inputs for JSON output."""
148
+ items = []
149
+ for t in inputs:
150
+ items.append(f'{{"name": "{t.name}", "shape": {list(t.shape)}, "dtype": "{t.dtype}"}}')
151
+ return "[" + ", ".join(items) + "]"
152
+
153
+
154
def parse_op_string(op_string: str) -> OpSpec:
    """Parse a simple operation string into an OpSpec.

    Supports formats like:
    - "torch.matmul(A, B)"
    - "torch.nn.functional.softmax(x, dim=-1)"

    For shapes, use --shape flags separately.

    Args:
        op_string: Operation string like "torch.matmul(A, B)"

    Returns:
        OpSpec with parsed operation and placeholder inputs
    """
    if "(" not in op_string or ")" not in op_string:
        raise ValueError(f"Invalid op format: {op_string}. Expected: op(args)")

    # Op name runs up to the first "("; args run to the last ")".
    open_at = op_string.index("(")
    close_at = op_string.rindex(")")
    op_name = op_string[:open_at].strip()
    args_str = op_string[open_at + 1 : close_at].strip()

    # Simple comma split; does not handle nested calls.
    inputs: list[TensorSpec] = []
    kwargs: dict[str, str] = {}

    for raw_arg in args_str.split(","):
        arg = raw_arg.strip()
        if not arg:
            continue
        key, eq, value = arg.partition("=")
        if eq:
            kwargs[key.strip()] = value.strip()
        else:
            # Bare name: treat as a tensor input with a placeholder shape.
            inputs.append(
                TensorSpec(
                    name=arg,
                    shape=(1024, 1024),  # Default shape, override with --shape
                    dtype="float16",
                )
            )

    return OpSpec(op=op_name, inputs=inputs, kwargs=kwargs)
198
+
199
+
200
def update_shapes(op_spec: OpSpec, shapes: dict[str, tuple[int, ...]]) -> OpSpec:
    """Update tensor shapes in an OpSpec.

    Args:
        op_spec: Original OpSpec
        shapes: Dict mapping tensor name to new shape

    Returns:
        New OpSpec with updated shapes
    """

    def _resized(tensor: TensorSpec) -> TensorSpec:
        # Tensors not mentioned in `shapes` pass through unchanged.
        if tensor.name not in shapes:
            return tensor
        return TensorSpec(
            name=tensor.name,
            shape=shapes[tensor.name],
            dtype=tensor.dtype,
            device=tensor.device,
        )

    return OpSpec(
        op=op_spec.op,
        inputs=[_resized(t) for t in op_spec.inputs],
        kwargs=op_spec.kwargs,
    )
225
+
226
+
227
def update_dtypes(op_spec: OpSpec, dtype: str) -> OpSpec:
    """Update all tensor dtypes in an OpSpec.

    Args:
        op_spec: Original OpSpec
        dtype: New dtype for all tensors

    Returns:
        New OpSpec with updated dtypes
    """
    # Rebuild every input with the requested dtype; everything else is kept.
    return OpSpec(
        op=op_spec.op,
        inputs=[
            TensorSpec(name=t.name, shape=t.shape, dtype=dtype, device=t.device)
            for t in op_spec.inputs
        ],
        kwargs=op_spec.kwargs,
    )