wafer-cli 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/gpu_run.py CHANGED
@@ -19,7 +19,10 @@ CONTAINER_WORKSPACE = "/workspace"
19
19
  class PushResult:
20
20
  """Result of pushing a directory to remote target."""
21
21
 
22
- workspace_path: str # Absolute path on remote (tilde-expanded)
22
+ workspace_name: str # Just the workspace name (e.g., "project")
23
+ workspace_path: (
24
+ str # Full absolute path on remote (e.g., "/home/user/.wafer/workspaces/project")
25
+ )
23
26
  files_uploaded: list[str] # Relative paths of uploaded files
24
27
 
25
28
 
@@ -71,6 +74,7 @@ def push_directory(
71
74
  files_uploaded.append(str(file.relative_to(local_path)))
72
75
 
73
76
  return PushResult(
77
+ workspace_name=workspace_name,
74
78
  workspace_path=expanded_workspace,
75
79
  files_uploaded=files_uploaded,
76
80
  )
wafer/kernel_scope.py ADDED
@@ -0,0 +1,453 @@
1
+ """Kernel Scope - CLI for static ISA analysis of Triton kernels.
2
+
3
+ This module provides the CLI wrapper for the `wafer amd kernel-scope` command.
4
+ It supports analysis of:
5
+ - AMDGCN ISA files (.s, .gcn, .asm)
6
+ - LLVM-IR files (.ll)
7
+ - TTGIR files (.ttgir, .ttir, .mlir)
8
+
9
+ Design: Wafer-436 - AMD Kernel Scope
10
+ """
11
+
12
import csv
import io
import json
import sys
from pathlib import Path
from typing import Optional
16
+
17
+
18
+ def print_usage() -> None:
19
+ """Print CLI usage information."""
20
+ print("Usage: wafer amd kernel-scope <subcommand> [options]", file=sys.stderr)
21
+ print("", file=sys.stderr)
22
+ print("Subcommands:", file=sys.stderr)
23
+ print(" analyze <file|directory> Analyze ISA/LLVM-IR/TTGIR files", file=sys.stderr)
24
+ print(" metrics List available metrics", file=sys.stderr)
25
+ print(" targets List supported GPU targets", file=sys.stderr)
26
+ print("", file=sys.stderr)
27
+ print("Analyze Options:", file=sys.stderr)
28
+ print(" --json Output as JSON", file=sys.stderr)
29
+ print(" --csv Output as CSV", file=sys.stderr)
30
+ print(" --recursive / -r Scan directories recursively", file=sys.stderr)
31
+ print(" --filter EXPR Filter results (e.g., 'spills > 0')", file=sys.stderr)
32
+ print(" --output / -o FILE Write output to file", file=sys.stderr)
33
+ print(" --kernel INDEX Kernel index if multiple in file", file=sys.stderr)
34
+ print("", file=sys.stderr)
35
+ print("Examples:", file=sys.stderr)
36
+ print(" wafer amd kernel-scope analyze kernel.s", file=sys.stderr)
37
+ print(" wafer amd kernel-scope analyze kernel.s --json", file=sys.stderr)
38
+ print(" wafer amd kernel-scope analyze ~/.triton/cache/ --filter 'spills > 0'", file=sys.stderr)
39
+ print(" wafer amd kernel-scope analyze . -r --csv -o metrics.csv", file=sys.stderr)
40
+ print(" wafer amd kernel-scope metrics", file=sys.stderr)
41
+ print(" wafer amd kernel-scope targets", file=sys.stderr)
42
+
43
+
44
def analyze_command(
    path: str,
    json_output: bool = False,
    csv_output: bool = False,
    recursive: bool = True,
    filter_expr: Optional[str] = None,
    output_file: Optional[str] = None,
    kernel_index: int = 0,
) -> str:
    """Analyze an ISA/LLVM-IR/TTGIR file or a directory of such files.

    Args:
        path: Path to a file or directory (``~`` is expanded).
        json_output: Output as JSON.
        csv_output: Output as CSV.
        recursive: Scan directories recursively.
        filter_expr: Filter expression for directory results (e.g., "spills > 0").
        output_file: If given, write the output here instead of returning it.
        kernel_index: Kernel index for multi-kernel ISA files.

    Returns:
        The formatted analysis output, or a confirmation message when the
        output was written to ``output_file``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        RuntimeError: If single-file analysis reports a failure.
    """
    from wafer_core.lib.kernel_scope import (
        analyze_isa_file,
        analyze_directory,
        analyze_file,
    )

    target = Path(path).expanduser()
    if not target.exists():
        raise FileNotFoundError(f"Path not found: {path}")

    if target.is_file():
        # Only ISA sources accept a kernel selector; other formats do not.
        is_isa = target.suffix.lower() in (".s", ".gcn", ".asm")
        result = (
            analyze_isa_file(target, kernel_index=kernel_index)
            if is_isa
            else analyze_file(target)
        )
        if not result.success:
            raise RuntimeError(f"Analysis failed: {result.error}")
        output = _format_single_result(result, json_output, csv_output)
    else:
        batch = analyze_directory(target, recursive=recursive)
        if filter_expr:
            batch = _apply_filter(batch, filter_expr)
        output = _format_batch_result(batch, json_output, csv_output)

    if output_file:
        Path(output_file).write_text(output)
        print(f"Output written to {output_file}", file=sys.stderr)
        return f"Results saved to {output_file}"

    return output
110
+
111
+
112
def metrics_command() -> str:
    """Return a human-readable catalog of the metrics kernel-scope reports.

    Each entry lists the metric name, a one-line description, and how the
    value is derived from the ISA text; a legend of instruction categories
    follows the metric list.
    """
    catalog = (
        ("vgpr_count", "Vector GPR allocation", "From .amdhsa_next_free_vgpr directive"),
        ("sgpr_count", "Scalar GPR allocation", "From .amdhsa_next_free_sgpr directive"),
        ("agpr_count", "Accumulator GPR count", "For MFMA operations (MI100+)"),
        ("lds_size", "LDS allocation (bytes)", "From .amdhsa_group_segment_fixed_size"),
        ("scratch_size", "Scratch memory (bytes)", "From .amdhsa_private_segment_fixed_size"),
        ("spill_count", "Register spill operations", "Count of scratch_store/load instructions"),
        ("mfma_count", "MFMA instructions", "Count of v_mfma_* instructions"),
        ("mfma_density_pct", "MFMA density (%)", "MFMA / total VALU * 100"),
        ("packed_ops_count", "Packed instructions", "Count of v_pk_* instructions"),
        ("fma_count", "FMA instructions", "Count of v_fma_* instructions"),
        ("barrier_count", "Barriers", "Count of s_barrier instructions"),
        ("full_stall_count", "Full stalls", "Count of waitcnt 0 instructions"),
        ("global_load_count", "Global loads", "Count of global_load_* instructions"),
        ("global_store_count", "Global stores", "Count of global_store_* instructions"),
        ("lds_ops_count", "LDS operations", "Count of ds_read/write instructions"),
        ("theoretical_occupancy", "Max waves/CU", "Limited by VGPR/SGPR/LDS"),
    )

    out = [
        "Available Metrics for Kernel Scope Analysis",
        "=" * 60,
        "",
    ]

    for name, description, derivation in catalog:
        out += [
            f" {name:<25} {description}",
            f" {'':<25} Derivation: {derivation}",
            "",
        ]

    out += [
        "Instruction Categories:",
        " VALU - Vector ALU (v_add_*, v_mul_*, v_fma_*)",
        " SALU - Scalar ALU (s_add_*, s_mul_*)",
        " VMEM - Vector memory (global_load_*, global_store_*)",
        " SMEM - Scalar memory (s_load_*, s_buffer_load_*)",
        " LDS - Local Data Share (ds_read_*, ds_write_*)",
        " MFMA - Matrix FMA (v_mfma_f32_*, v_mfma_f16_*)",
        " SYNC - Synchronization (s_barrier, s_waitcnt)",
        " SPILL - Spill operations (scratch_store_*, scratch_load_*)",
    ]

    return "\n".join(out)
161
+
162
+
163
def targets_command() -> str:
    """Return a formatted table of the GPU targets kernel-scope supports.

    Columns come from the per-target spec objects (register files, LDS,
    wave limits). Unknown architectures fall back to default values.
    """
    from wafer_core.lib.kernel_scope.targets import SUPPORTED_TARGETS, get_target_specs

    rows = [
        "Supported GPU Targets",
        "=" * 60,
        "",
        f"{'Architecture':<12} {'Series':<10} {'VGPRs/CU':<10} "
        f"{'SGPRs/CU':<10} {'LDS/CU':<10} {'Max Waves':<10}",
        "-" * 60,
    ]

    for target in SUPPORTED_TARGETS:
        spec = get_target_specs(target)
        rows.append(
            f"{spec.name:<12} {spec.series:<10} {spec.vgprs_per_cu:<10} "
            f"{spec.sgprs_per_cu:<10} {spec.lds_per_cu:<10} {spec.max_waves_per_cu:<10}"
        )

    rows += [
        "",
        "Note: Default values are used for unknown architectures.",
    ]

    return "\n".join(rows)
192
+
193
+
194
+ def _format_single_result(result, json_output: bool, csv_output: bool) -> str:
195
+ """Format a single analysis result."""
196
+ if json_output:
197
+ return result.to_json()
198
+
199
+ if csv_output:
200
+ return _result_to_csv(result)
201
+
202
+ return _result_to_text(result)
203
+
204
+
205
+ def _format_batch_result(batch_result, json_output: bool, csv_output: bool) -> str:
206
+ """Format batch analysis results."""
207
+ if json_output:
208
+ return batch_result.to_json()
209
+
210
+ if csv_output:
211
+ return _batch_to_csv(batch_result)
212
+
213
+ return _batch_to_text(batch_result)
214
+
215
+
216
+ def _result_to_text(result) -> str:
217
+ """Format single result as human-readable text."""
218
+ lines = []
219
+
220
+ if result.isa_analysis:
221
+ a = result.isa_analysis
222
+ lines.extend([
223
+ f"Kernel: {a.kernel_name}",
224
+ f"Architecture: {a.architecture}",
225
+ "",
226
+ "=== Registers ===",
227
+ f" VGPRs: {a.vgpr_count}",
228
+ f" SGPRs: {a.sgpr_count}",
229
+ f" AGPRs: {a.agpr_count}",
230
+ ])
231
+
232
+ if a.spill_count > 0:
233
+ lines.extend([
234
+ "",
235
+ "!!! SPILLS DETECTED !!!",
236
+ f" Total spills: {a.spill_count}",
237
+ f" VGPR spills: {a.vgpr_spill_count}",
238
+ f" SGPR spills: {a.sgpr_spill_count}",
239
+ ])
240
+ else:
241
+ lines.append(" Spills: None (good)")
242
+
243
+ lines.extend([
244
+ "",
245
+ "=== Memory ===",
246
+ f" LDS: {a.lds_size} bytes",
247
+ f" Scratch: {a.scratch_size} bytes",
248
+ f" Global loads: {a.global_load_count}",
249
+ f" Global stores: {a.global_store_count}",
250
+ f" LDS ops: {a.lds_ops_count}",
251
+ "",
252
+ "=== Instructions ===",
253
+ f" MFMA: {a.mfma_count} ({a.mfma_density_pct:.1f}% density)",
254
+ f" FMA: {a.fma_count}",
255
+ f" Packed (v_pk_*): {a.packed_ops_count}",
256
+ f" Barriers: {a.barrier_count}",
257
+ f" Full stalls: {a.full_stall_count}",
258
+ "",
259
+ "=== Instruction Mix ===",
260
+ f" VALU: {a.instruction_mix.valu_count}",
261
+ f" SALU: {a.instruction_mix.salu_count}",
262
+ f" VMEM: {a.instruction_mix.vmem_count}",
263
+ f" SMEM: {a.instruction_mix.smem_count}",
264
+ f" LDS: {a.instruction_mix.lds_count}",
265
+ f" MFMA: {a.instruction_mix.mfma_count}",
266
+ f" Sync: {a.instruction_mix.sync_count}",
267
+ f" Total: {a.instruction_mix.total_count}",
268
+ "",
269
+ "=== Occupancy ===",
270
+ f" Max waves (VGPR): {a.max_waves_vgpr}",
271
+ f" Max waves (SGPR): {a.max_waves_sgpr}",
272
+ f" Max waves (LDS): {a.max_waves_lds}",
273
+ f" Theoretical: {a.theoretical_occupancy} waves/CU",
274
+ ])
275
+
276
+ if a.warnings:
277
+ lines.extend([
278
+ "",
279
+ "=== Warnings ===",
280
+ ])
281
+ for warning in a.warnings:
282
+ lines.append(f" {warning}")
283
+
284
+ elif result.ttgir_analysis:
285
+ a = result.ttgir_analysis
286
+ lines.extend([
287
+ "TTGIR Analysis",
288
+ "",
289
+ "=== Operations ===",
290
+ f" tt.dot: {a.dot_count}",
291
+ f" tt.load: {a.load_count}",
292
+ f" tt.store: {a.store_count}",
293
+ f" tt.reduce: {a.reduce_count}",
294
+ f" Barriers: {a.barrier_count}",
295
+ ])
296
+
297
+ if a.tile_info:
298
+ lines.extend([
299
+ "",
300
+ "=== Tiling ===",
301
+ f" BLOCK_M: {a.tile_info.block_m}",
302
+ f" BLOCK_N: {a.tile_info.block_n}",
303
+ f" BLOCK_K: {a.tile_info.block_k}",
304
+ f" num_warps: {a.tile_info.num_warps}",
305
+ f" num_stages: {a.tile_info.num_stages}",
306
+ ])
307
+
308
+ if a.has_software_pipelining:
309
+ lines.append(" Software pipelining: enabled")
310
+
311
+ if a.estimated_compute_intensity:
312
+ lines.append(f" Compute intensity: {a.estimated_compute_intensity:.1f} FLOPs/byte")
313
+
314
+ elif result.llvm_ir_analysis:
315
+ a = result.llvm_ir_analysis
316
+ lines.extend([
317
+ "LLVM-IR Analysis",
318
+ "",
319
+ f" Functions: {a.function_count}",
320
+ f" Total instructions: {a.total_instructions}",
321
+ f" Functions with loops: {a.functions_with_loops}",
322
+ f" Has vector ops: {a.has_vector_ops}",
323
+ ])
324
+
325
+ if a.kernel_functions:
326
+ lines.append(f" Kernel functions: {', '.join(a.kernel_functions)}")
327
+
328
+ return "\n".join(lines)
329
+
330
+
331
+ def _result_to_csv(result) -> str:
332
+ """Format single result as CSV."""
333
+ if result.isa_analysis:
334
+ a = result.isa_analysis
335
+ header = "kernel_name,architecture,vgpr_count,sgpr_count,spill_count,mfma_count,mfma_density_pct,occupancy"
336
+ row = f"{a.kernel_name},{a.architecture},{a.vgpr_count},{a.sgpr_count},{a.spill_count},{a.mfma_count},{a.mfma_density_pct:.2f},{a.theoretical_occupancy}"
337
+ return f"{header}\n{row}"
338
+
339
+ return "# Unsupported format for CSV"
340
+
341
+
342
+ def _batch_to_text(batch_result) -> str:
343
+ """Format batch results as text."""
344
+ lines = [
345
+ f"Analyzed {batch_result.total_files} files",
346
+ f" Successful: {batch_result.successful}",
347
+ f" Failed: {batch_result.failed}",
348
+ "",
349
+ ]
350
+
351
+ if batch_result.summary:
352
+ lines.extend([
353
+ "=== Summary ===",
354
+ f" Avg VGPRs: {batch_result.summary.get('total_vgpr_avg', 0):.1f}",
355
+ f" Avg SGPRs: {batch_result.summary.get('total_sgpr_avg', 0):.1f}",
356
+ f" Total spills: {batch_result.summary.get('total_spills', 0)}",
357
+ f" Files with spills: {batch_result.summary.get('files_with_spills', 0)}",
358
+ f" Total MFMA: {batch_result.summary.get('total_mfma', 0)}",
359
+ f" Avg MFMA density: {batch_result.summary.get('avg_mfma_density', 0):.1f}%",
360
+ "",
361
+ ])
362
+
363
+ # Show individual results
364
+ for result in batch_result.results:
365
+ if result.success and result.isa_analysis:
366
+ a = result.isa_analysis
367
+ status = "⚠️" if a.spill_count > 0 else "✓"
368
+ lines.append(
369
+ f" {status} {result.file_path}: "
370
+ f"VGPRs={a.vgpr_count}, spills={a.spill_count}, MFMA={a.mfma_count}"
371
+ )
372
+ elif not result.success:
373
+ lines.append(f" ✗ {result.file_path}: {result.error}")
374
+
375
+ return "\n".join(lines)
376
+
377
+
378
+ def _batch_to_csv(batch_result) -> str:
379
+ """Format batch results as CSV."""
380
+ lines = ["file_path,kernel_name,architecture,vgpr_count,sgpr_count,spill_count,mfma_count,mfma_density_pct,occupancy"]
381
+
382
+ for result in batch_result.results:
383
+ if result.success and result.isa_analysis:
384
+ a = result.isa_analysis
385
+ lines.append(
386
+ f"{result.file_path},{a.kernel_name},{a.architecture},"
387
+ f"{a.vgpr_count},{a.sgpr_count},{a.spill_count},"
388
+ f"{a.mfma_count},{a.mfma_density_pct:.2f},{a.theoretical_occupancy}"
389
+ )
390
+
391
+ return "\n".join(lines)
392
+
393
+
394
+ def _apply_filter(batch_result, filter_expr: str):
395
+ """Apply filter expression to batch results."""
396
+ # Simple filter parsing: "metric op value"
397
+ # Supported: spills > 0, vgpr_count > 128, mfma_count == 0
398
+ import re
399
+
400
+ match = re.match(r"(\w+)\s*(>|<|>=|<=|==|!=)\s*(\d+)", filter_expr)
401
+ if not match:
402
+ print(f"Warning: Invalid filter expression: {filter_expr}", file=sys.stderr)
403
+ return batch_result
404
+
405
+ metric = match.group(1)
406
+ op = match.group(2)
407
+ value = int(match.group(3))
408
+
409
+ # Map common aliases
410
+ metric_map = {
411
+ "spills": "spill_count",
412
+ "vgpr": "vgpr_count",
413
+ "sgpr": "sgpr_count",
414
+ "mfma": "mfma_count",
415
+ "occupancy": "theoretical_occupancy",
416
+ }
417
+ metric = metric_map.get(metric, metric)
418
+
419
+ # Filter function
420
+ def passes_filter(result):
421
+ if not result.success or not result.isa_analysis:
422
+ return False
423
+
424
+ actual = getattr(result.isa_analysis, metric, None)
425
+ if actual is None:
426
+ return False
427
+
428
+ if op == ">":
429
+ return actual > value
430
+ elif op == "<":
431
+ return actual < value
432
+ elif op == ">=":
433
+ return actual >= value
434
+ elif op == "<=":
435
+ return actual <= value
436
+ elif op == "==":
437
+ return actual == value
438
+ elif op == "!=":
439
+ return actual != value
440
+
441
+ return False
442
+
443
+ filtered_results = [r for r in batch_result.results if passes_filter(r)]
444
+
445
+ from wafer_core.lib.kernel_scope.api import BatchAnalysisResult
446
+
447
+ return BatchAnalysisResult(
448
+ total_files=len(filtered_results),
449
+ successful=sum(1 for r in filtered_results if r.success),
450
+ failed=sum(1 for r in filtered_results if not r.success),
451
+ results=tuple(filtered_results),
452
+ summary=batch_result.summary,
453
+ )