wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ """Same kernel analysis - comparing identical kernel names across platforms.
2
+
3
+ Identifies kernels where AMD and NVIDIA use the same kernel name/pattern
4
+ and compares their performance directly.
5
+ """
6
+
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass, field
9
+ from typing import Any
10
+
11
+ from .aligner import KernelPair, LayerAlignment
12
+
13
+
14
@dataclass
class SameKernelComparison:
    """Head-to-head timing of one kernel name observed on both platforms.

    Captures the per-layer average latencies plus the derived ratio/gap so
    downstream reporting can rank the largest cross-platform differences.
    """

    layer: int  # transformer layer index this comparison belongs to
    kernel_name: str  # shared kernel name (taken from the AMD side of the pair)
    operation: str  # logical operation the kernel implements
    amd_avg_us: float  # average AMD duration, microseconds
    nvidia_avg_us: float  # average NVIDIA duration, microseconds
    ratio: float  # cross-platform timing ratio (<1.0 counts as AMD faster in summaries)
    gap_us: float  # absolute timing gap, microseconds
    amd_count: int  # number of AMD invocations averaged
    nvidia_count: int  # number of NVIDIA invocations averaged
27
+
28
+
29
@dataclass
class SameKernelAnalysis:
    """Full result of a same-kernel comparison pass.

    ``kernels`` holds one entry per matched kernel pair; ``summary`` carries
    aggregate statistics (total count, average ratio, per-platform win tallies).
    """

    kernels: list[SameKernelComparison] = field(default_factory=list)
    summary: dict[str, Any] = field(default_factory=dict)
35
+
36
+
37
def analyze_same_kernels(
    layer_alignments: list[LayerAlignment],
) -> SameKernelAnalysis:
    """Collect and summarize kernels whose names match on both platforms.

    Args:
        layer_alignments: List of aligned layers

    Returns:
        SameKernelAnalysis with per-kernel comparisons and summary stats
    """
    comparisons = [
        SameKernelComparison(
            layer=alignment.layer,
            kernel_name=pair.amd_kernel,
            operation=pair.operation,
            amd_avg_us=pair.amd_avg_us,
            nvidia_avg_us=pair.nvidia_avg_us,
            ratio=pair.ratio,
            gap_us=pair.gap_us,
            amd_count=pair.amd_count,
            nvidia_count=pair.nvidia_count,
        )
        for alignment in layer_alignments
        for pair in alignment.kernel_pairs
        if pair.is_same_kernel and pair.amd_kernel and pair.nvidia_kernel
    ]

    # Infinite ratios are excluded from the mean; an all-infinite (or empty)
    # set of comparisons falls back to the neutral ratio 1.0.
    finite_ratios = [c.ratio for c in comparisons if c.ratio != float("inf")]
    avg_ratio = sum(finite_ratios) / len(finite_ratios) if finite_ratios else 1.0
    amd_wins = sum(1 for c in comparisons if c.ratio < 1.0)
    nvidia_wins = sum(1 for c in comparisons if c.ratio > 1.0)

    return SameKernelAnalysis(
        kernels=comparisons,
        summary={
            "total_same_kernels": len(comparisons),
            "avg_ratio": avg_ratio,
            "kernels_where_amd_faster": amd_wins,
            "kernels_where_nvidia_faster": nvidia_wins,
        },
    )
86
+
87
+
88
def analyze_same_kernels_from_alignment(
    layer_alignments: list[LayerAlignment],
) -> dict[str, Any]:
    """Run same-kernel analysis and return a plain-dict view (API compatibility).

    Args:
        layer_alignments: List of aligned layers

    Returns:
        Dictionary with same kernel analysis results
    """
    analysis = analyze_same_kernels(layer_alignments)

    # Serialize each comparison dataclass into a JSON-friendly dict; the
    # field order below matches the dataclass declaration order.
    field_names = (
        "layer",
        "kernel_name",
        "operation",
        "amd_avg_us",
        "nvidia_avg_us",
        "ratio",
        "gap_us",
        "amd_count",
        "nvidia_count",
    )
    kernel_dicts = [
        {name: getattr(comparison, name) for name in field_names}
        for comparison in analysis.kernels
    ]

    return {
        "kernels": kernel_dicts,
        "summary": analysis.summary,
    }
@@ -0,0 +1,99 @@
1
+ """Warning detection and reporting for trace analysis.
2
+
3
+ Detects issues with trace data quality and provides actionable suggestions.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Literal
8
+
9
+
10
@dataclass(frozen=True)
class TraceWarning:
    """An immutable warning about trace data quality or analysis limitations."""

    code: str  # stable machine-readable identifier, e.g. "NO_PHASE_ANNOTATIONS"
    severity: Literal["info", "warning", "error"]  # how serious the issue is
    message: str  # human-readable description of the problem
    suggestion: str  # actionable remediation advice for the user
18
+
19
+
20
def detect_warnings(
    events: list[dict],
    kernel_names: list[str],
    phases: list[dict] | None = None,
    layers_detected: int = 0,
    total_kernels: int = 0,
) -> list[TraceWarning]:
    """Inspect trace data and report data-quality / coverage warnings.

    Args:
        events: All trace events
        kernel_names: List of all kernel names
        phases: Optional list of phase events (not consulted in this
            implementation; kept for API compatibility)
        layers_detected: Number of layers detected
        total_kernels: Total number of kernels

    Returns:
        List of warnings (possibly empty)
    """
    found: list[TraceWarning] = []

    def _has_event(predicate) -> bool:
        # Tiny helper: does any trace event satisfy the predicate?
        return any(predicate(ev) for ev in events)

    # vLLM marks prefill/decode phases with execute_context user annotations.
    if not _has_event(
        lambda ev: ev.get("cat") == "user_annotation"
        and ev.get("name", "").startswith("execute_context")
    ):
        found.append(
            TraceWarning(
                code="NO_PHASE_ANNOTATIONS",
                severity="warning",
                message="No vLLM phase annotations found. Phase analysis (prefill/decode) will be unavailable.",
                suggestion="Ensure you're using vLLM v1.0+ with profiling enabled. Re-profile with torch.profiler.profile() to capture phase markers.",
            )
        )

    # Python stack events enable CPU-operator -> kernel attribution.
    if not _has_event(lambda ev: ev.get("cat") == "python_function"):
        found.append(
            TraceWarning(
                code="NO_PYTHON_STACKS",
                severity="info",
                message="No Python stack traces available. CPU→kernel mapping will be limited.",
                suggestion="Re-profile with with_stack=True: torch.profiler.profile(with_stack=True) for better CPU operator identification.",
            )
        )

    # Too many unclassified kernels suggests a stale pattern registry.
    # NOTE(review): the count is taken over kernel_names while the divisor is
    # total_kernels — assumes both describe the same population; confirm at callers.
    if total_kernels > 0:
        unclassified = sum(
            1 for name in kernel_names if "unknown" in name.lower() or name == "Other"
        )
        pct_unknown = (unclassified / total_kernels) * 100
        if pct_unknown > 20:
            found.append(
                TraceWarning(
                    code="HIGH_UNKNOWN_KERNELS",
                    severity="warning",
                    message=f"{pct_unknown:.1f}% of kernels are classified as 'Unknown'. Kernel registry may be outdated.",
                    suggestion="Update kernel pattern registry or report unrecognized kernel patterns for support.",
                )
            )

    # A sizeable trace (>100 kernels) with zero detected layers means
    # layer-wise analysis has nothing to work with.
    if layers_detected == 0 and total_kernels > 100:
        found.append(
            TraceWarning(
                code="LAYER_DETECTION_FAILED",
                severity="warning",
                message="No transformer layers detected. Layer-wise analysis unavailable.",
                suggestion="This may indicate a non-transformer model (e.g., SSM/Mamba) or insufficient correlation data. Check model architecture.",
            )
        )

    return found
@@ -1,5 +1,43 @@
1
- """Re-export targets from utils for convenience."""
1
+ """Target system: specs (config) + targets (live resources) + reconciliation.
2
2
 
3
+ New API (preferred):
4
+ from wafer_core.targets import Target, TargetSpec, ReconcileResult
5
+ from wafer_core.targets.providers import get_provider
6
+ from wafer_core.targets.reconcile import reconcile
7
+ from wafer_core.targets.spec_store import load_spec, list_spec_names
8
+ from wafer_core.targets.state_cache import get_binding_hints
9
+
10
+ Legacy API (still works, will be deprecated):
11
+ from wafer_core.targets import RunPodTarget, runpod_ssh_context, ...
12
+ """
13
+
14
+ # ── New types ────────────────────────────────────────────────────────────────
15
+ from wafer_core.targets.digitalocean import (
16
+ DigitalOceanError,
17
+ DigitalOceanSSHInfo,
18
+ cleanup_all_droplets,
19
+ cleanup_digitalocean_target,
20
+ digitalocean_ssh_context,
21
+ get_droplet_state,
22
+ list_running_droplets,
23
+ )
24
+ from wafer_core.targets.runpod import (
25
+ RunPodError,
26
+ RunPodSSHInfo,
27
+ cleanup_all_pods,
28
+ cleanup_target,
29
+ get_pod_state,
30
+ list_running_pods,
31
+ runpod_ssh_context,
32
+ )
33
+ from wafer_core.targets.types import (
34
+ ReconcileResult,
35
+ Target,
36
+ TargetProvider,
37
+ TargetSpec,
38
+ )
39
+
40
+ # ── Legacy re-exports (unchanged, for backwards compatibility) ───────────────
3
41
  from wafer_core.utils.kernel_utils.targets import (
4
42
  BaremetalTarget,
5
43
  DigitalOceanTarget,
@@ -18,26 +56,14 @@ from wafer_core.utils.kernel_utils.targets import (
18
56
  select_target_for_operation,
19
57
  target_to_deployment_config,
20
58
  )
21
- from wafer_core.targets.runpod import (
22
- RunPodError,
23
- RunPodSSHInfo,
24
- cleanup_all_pods,
25
- cleanup_target,
26
- get_pod_state,
27
- list_running_pods,
28
- runpod_ssh_context,
29
- )
30
- from wafer_core.targets.digitalocean import (
31
- DigitalOceanError,
32
- DigitalOceanSSHInfo,
33
- cleanup_all_droplets,
34
- cleanup_digitalocean_target,
35
- digitalocean_ssh_context,
36
- get_droplet_state,
37
- list_running_droplets,
38
- )
39
59
 
40
60
  __all__ = [
61
+ # New API
62
+ "Target",
63
+ "TargetSpec",
64
+ "TargetProvider",
65
+ "ReconcileResult",
66
+ # Legacy: target config types
41
67
  "BaremetalTarget",
42
68
  "VMTarget",
43
69
  "ModalTarget",
@@ -54,7 +80,7 @@ __all__ = [
54
80
  "check_target_available",
55
81
  "find_free_gpu",
56
82
  "run_operation_on_target",
57
- # RunPod provisioning
83
+ # Legacy: RunPod provisioning
58
84
  "RunPodError",
59
85
  "RunPodSSHInfo",
60
86
  "runpod_ssh_context",
@@ -62,7 +88,7 @@ __all__ = [
62
88
  "cleanup_all_pods",
63
89
  "list_running_pods",
64
90
  "get_pod_state",
65
- # DigitalOcean provisioning
91
+ # Legacy: DigitalOcean provisioning
66
92
  "DigitalOceanError",
67
93
  "DigitalOceanSSHInfo",
68
94
  "digitalocean_ssh_context",
@@ -0,0 +1,181 @@
1
+ """Pool queries: filter live targets by GPU type, provider, and labels.
2
+
3
+ A pool is a predicate over live targets, not a hardcoded list.
4
+ Pool queries are defined in ~/.wafer/config.toml:
5
+
6
+ [pools.mi300x]
7
+ gpu_type = "MI300X"
8
+
9
+ [pools.mi300x-rocm7]
10
+ gpu_type = "MI300X"
11
+ labels.rocm_version = "7.0.2"
12
+
13
+ [pools.runpod-only]
14
+ provider = "runpod"
15
+
16
+ Matching: a target matches a pool query if all specified fields match.
17
+ Fields not specified in the query are ignored (match anything).
18
+ Label matching is AND — all required labels must be present and equal.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass, field
24
+ from pathlib import Path
25
+
26
+ from wafer_core.targets.types import Target
27
+
28
+ WAFER_DIR = Path.home() / ".wafer"
29
+ CONFIG_FILE = WAFER_DIR / "config.toml"
30
+
31
+ # Fields on PoolQuery that map directly to Target fields
32
+ _TARGET_FIELDS = ("gpu_type", "provider", "status")
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class PoolQuery:
37
+ """Predicate for filtering live targets.
38
+
39
+ All specified fields must match (AND semantics).
40
+ None means "don't care" for that field.
41
+ """
42
+
43
+ gpu_type: str | None = None
44
+ provider: str | None = None
45
+ status: str | None = "running"
46
+ labels: dict[str, str] = field(default_factory=dict)
47
+
48
+
49
def match_targets(query: PoolQuery, targets: list[Target]) -> list[Target]:
    """Return the subset of ``targets`` satisfying ``query``. Pure function."""
    return [candidate for candidate in targets if _matches(query, candidate)]


def _matches(query: PoolQuery, target: Target) -> bool:
    """True when every constraint the query specifies holds for the target."""
    # Scalar constraints: a None query field means "don't care".
    constraints = (
        (query.gpu_type, target.gpu_type),
        (query.provider, target.provider),
        (query.status, target.status),
    )
    if any(wanted is not None and wanted != actual for wanted, actual in constraints):
        return False

    # Label constraints: every required key must be present with the same value.
    return all(target.labels.get(key) == value for key, value in query.labels.items())
74
+
75
+
76
def load_pool_query(name: str) -> PoolQuery:
    """Build a PoolQuery from the [pools] section of ~/.wafer/config.toml.

    Raises KeyError if the pool is not defined.
    """
    pools = _load_pools_section()
    try:
        raw = pools[name]
    except KeyError:
        available = ", ".join(sorted(pools)) if pools else "(none)"
        raise KeyError(f"Pool {name!r} not found. Available: {available}") from None

    assert isinstance(raw, dict), f"Pool {name!r} must be a table, got {type(raw).__name__}"

    labels_raw = raw.get("labels", {})
    assert isinstance(labels_raw, dict), (
        f"Pool {name!r} labels must be a table, got {type(labels_raw).__name__}"
    )

    # Coerce label keys/values to str so downstream comparisons are uniform.
    return PoolQuery(
        gpu_type=raw.get("gpu_type"),
        provider=raw.get("provider"),
        status=raw.get("status", "running"),
        labels={str(k): str(v) for k, v in labels_raw.items()},
    )
100
+
101
+
102
def list_pool_names() -> list[str]:
    """Return all pool names defined in config.toml, alphabetically sorted."""
    return sorted(_load_pools_section())
106
+
107
+
108
def is_query_pool(name: str) -> bool:
    """Check if a pool is defined as a PoolQuery (new format) vs target list (old format).

    Old format: [pools.name] targets = ["t1", "t2"]
    New format: [pools.name] gpu_type = "MI300X"

    Returns False if pool doesn't exist or is old format.
    """
    raw = _load_pools_section().get(name)
    if not isinstance(raw, dict):
        # Missing pool, or a non-table value: not a query pool.
        return False
    # A "targets" list of names marks the legacy format.
    return "targets" not in raw
124
+
125
+
126
async def resolve_pool(name: str) -> list[Target]:
    """Resolve a pool query to live targets.

    Queries all cloud providers concurrently, hydrates cached labels, and
    filters by the pool query. Results are sorted by resource_id so repeated
    calls are deterministic.

    Raises KeyError if pool not found.
    """
    from dataclasses import replace

    from wafer_core.targets.providers import get_all_cloud_providers
    from wafer_core.targets.state_cache import load_all_labels
    from wafer_core.targets.types import TargetProvider

    import trio

    query = load_pool_query(name)

    live: list[Target] = []

    async def _collect(provider: TargetProvider, sink: list[Target]) -> None:
        # Best-effort: a provider with a missing API key (or an outage) is
        # skipped rather than failing the whole resolution.
        try:
            sink.extend(await provider.list_targets())
        except Exception:
            pass

    # Fan out one fetch task per cloud provider.
    async with trio.open_nursery() as nursery:
        for _, provider in get_all_cloud_providers():
            nursery.start_soon(_collect, provider, live)

    # Overlay cached software labels onto the targets that have them.
    cached = load_all_labels()
    hydrated = [
        replace(t, labels=cached[t.resource_id]) if t.resource_id in cached else t
        for t in live
    ]

    return sorted(match_targets(query, hydrated), key=lambda t: t.resource_id)
171
+
172
+
173
def _load_pools_section() -> dict:
    """Read the [pools] section from config.toml.

    Returns an empty dict if the file or the section is missing.
    """
    if not CONFIG_FILE.exists():
        return {}

    import tomllib

    # TOML is UTF-8 by spec; read the file in binary mode and let tomllib
    # decode it, instead of read_text() which uses the platform's locale
    # encoding (e.g. cp1252 on Windows) and could mis-decode the config.
    with CONFIG_FILE.open("rb") as fh:
        data = tomllib.load(fh)
    return data.get("pools", {})
@@ -0,0 +1,113 @@
1
+ """SSH probe: detect software labels on a live target.
2
+
3
+ Runs a Python script on the target via SSH that reports installed
4
+ software versions. Returns a flat dict[str, str] of labels.
5
+
6
+ Only called at provision time or manually via `wafer targets probe`.
7
+ Results are cached in target_state.json — probe is never implicit.
8
+
9
+ Uses subprocess ssh (not asyncssh) to match existing codebase patterns.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import logging
16
+ import subprocess
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Probe script executed on the target machine via SSH (fed on stdin).
# Prints a single JSON dict to stdout. Must work with stock Python 3.10+
# and must not assume torch/triton are installed — those are best-effort.
_PROBE_SCRIPT = r"""
import json, shutil, subprocess, sys

def probe():
    result = {}

    # Python version
    result["python_version"] = ".".join(map(str, sys.version_info[:2]))

    # ROCm version from filesystem
    try:
        with open("/opt/rocm/.info/version") as f:
            result["rocm_version"] = f.read().strip().split("-")[0]
    except Exception:
        pass

    # CUDA version from nvcc
    nvcc = shutil.which("nvcc")
    if nvcc:
        try:
            out = subprocess.check_output([nvcc, "--version"], text=True)
            for line in out.split("\n"):
                if "release" in line.lower():
                    parts = line.split("release")
                    if len(parts) > 1:
                        result["cuda_version"] = parts[1].split(",")[0].strip()
                    break
        except Exception:
            pass

    # PyTorch version
    try:
        import torch
        result["pytorch_version"] = torch.__version__.split("+")[0]
    except ImportError:
        pass

    # Triton version
    try:
        import triton
        result["triton_version"] = triton.__version__
    except ImportError:
        pass

    print(json.dumps(result))

probe()
"""
70
+
71
+
72
+ def probe_target_labels(
73
+ host: str,
74
+ port: int,
75
+ username: str,
76
+ ssh_key_path: str | None = None,
77
+ timeout: int = 60,
78
+ ) -> dict[str, str]:
79
+ """SSH into a target and probe installed software. Returns labels dict.
80
+
81
+ Raises on SSH failure — caller decides how to handle.
82
+ """
83
+ ssh_args = [
84
+ "ssh",
85
+ "-p", str(port),
86
+ "-o", "StrictHostKeyChecking=no",
87
+ "-o", "UserKnownHostsFile=/dev/null",
88
+ "-o", "LogLevel=ERROR",
89
+ "-o", "ConnectTimeout=10",
90
+ ]
91
+ if ssh_key_path:
92
+ ssh_args.extend(["-i", ssh_key_path])
93
+
94
+ ssh_args.append(f"{username}@{host}")
95
+ ssh_args.append("python3")
96
+
97
+ result = subprocess.run(
98
+ ssh_args,
99
+ input=_PROBE_SCRIPT,
100
+ capture_output=True,
101
+ text=True,
102
+ timeout=timeout,
103
+ )
104
+
105
+ if result.returncode != 0:
106
+ stderr = result.stderr.strip()
107
+ raise RuntimeError(f"Probe failed (exit {result.returncode}): {stderr}")
108
+
109
+ stdout = result.stdout.strip()
110
+ labels = json.loads(stdout)
111
+ assert isinstance(labels, dict), f"Probe returned {type(labels).__name__}, expected dict"
112
+
113
+ return {str(k): str(v) for k, v in labels.items()}
@@ -0,0 +1,46 @@
1
+ """Provider registry for GPU resource management.
2
+
3
+ Each provider implements TargetProvider protocol: list, get, provision, terminate.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from wafer_core.targets.providers.baremetal import BaremetalProvider
9
+ from wafer_core.targets.providers.digitalocean import DigitalOceanProvider
10
+ from wafer_core.targets.providers.runpod import RunPodProvider
11
+ from wafer_core.targets.types import TargetProvider
12
+
13
+ _PROVIDERS: dict[str, type] = {
14
+ "runpod": RunPodProvider,
15
+ "digitalocean": DigitalOceanProvider,
16
+ "baremetal": BaremetalProvider,
17
+ }
18
+
19
+
20
+ def get_provider(name: str) -> TargetProvider:
21
+ """Get a provider instance by name.
22
+
23
+ Raises KeyError if provider is not registered.
24
+ """
25
+ cls = _PROVIDERS.get(name)
26
+ if cls is None:
27
+ raise KeyError(f"Unknown provider: {name!r}. Available: {', '.join(sorted(_PROVIDERS))}")
28
+ return cls()
29
+
30
+
31
+ def get_all_cloud_providers() -> list[tuple[str, TargetProvider]]:
32
+ """Get all cloud providers that can list remote resources.
33
+
34
+ Excludes baremetal (no remote API to query).
35
+ Returns list of (name, provider) tuples.
36
+ """
37
+ return [(name, cls()) for name, cls in _PROVIDERS.items() if name != "baremetal"]
38
+
39
+
40
+ __all__ = [
41
+ "BaremetalProvider",
42
+ "DigitalOceanProvider",
43
+ "RunPodProvider",
44
+ "get_all_cloud_providers",
45
+ "get_provider",
46
+ ]