tritonparse 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of tritonparse might be problematic.
Files changed (40)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/common.py +409 -0
  3. tritonparse/event_diff.py +120 -0
  4. tritonparse/extract_source_mappings.py +49 -0
  5. tritonparse/ir_parser.py +220 -0
  6. tritonparse/mapper.py +100 -0
  7. tritonparse/reproducer/__init__.py +21 -0
  8. tritonparse/reproducer/__main__.py +81 -0
  9. tritonparse/reproducer/cli.py +37 -0
  10. tritonparse/reproducer/config.py +15 -0
  11. tritonparse/reproducer/factory.py +16 -0
  12. tritonparse/reproducer/ingestion/__init__.py +6 -0
  13. tritonparse/reproducer/ingestion/ndjson.py +165 -0
  14. tritonparse/reproducer/orchestrator.py +65 -0
  15. tritonparse/reproducer/param_generator.py +142 -0
  16. tritonparse/reproducer/prompts/__init__.py +1 -0
  17. tritonparse/reproducer/prompts/loader.py +18 -0
  18. tritonparse/reproducer/providers/__init__.py +1 -0
  19. tritonparse/reproducer/providers/base.py +14 -0
  20. tritonparse/reproducer/providers/gemini.py +47 -0
  21. tritonparse/reproducer/runtime/__init__.py +1 -0
  22. tritonparse/reproducer/runtime/executor.py +13 -0
  23. tritonparse/reproducer/utils/io.py +6 -0
  24. tritonparse/shared_vars.py +9 -0
  25. tritonparse/source_type.py +56 -0
  26. tritonparse/sourcemap_utils.py +72 -0
  27. tritonparse/structured_logging.py +1046 -0
  28. tritonparse/tools/__init__.py +0 -0
  29. tritonparse/tools/decompress_bin_ndjson.py +118 -0
  30. tritonparse/tools/format_fix.py +149 -0
  31. tritonparse/tools/load_tensor.py +58 -0
  32. tritonparse/tools/prettify_ndjson.py +315 -0
  33. tritonparse/tp_logger.py +9 -0
  34. tritonparse/trace_processor.py +331 -0
  35. tritonparse/utils.py +156 -0
  36. tritonparse-0.1.1.dist-info/METADATA +10 -0
  37. tritonparse-0.1.1.dist-info/RECORD +40 -0
  38. tritonparse-0.1.1.dist-info/WHEEL +5 -0
  39. tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
  40. tritonparse-0.1.1.dist-info/top_level.txt +1 -0
tritonparse/ir_parser.py ADDED
@@ -0,0 +1,220 @@
+ import logging
+ import os
+ import re
+ from collections import defaultdict
+ from typing import Any, Dict, List, Optional
+
+ logger = logging.getLogger("SourceMapping")
+
+ # Definition of a #loc directive. These appear at the bottom of the IR files.
+ # Example: #loc2 = loc("/tmp/torchinductor_yhao/yp/abcdef.py":20:28)
+ LOC_PATTERN = re.compile(r'#loc(\d*) = loc\("([^"]+)":(\d+):(\d+)\)')
+
+ # Reference to a #loc directive. These appear at the end of IR lines.
+ # Example: loc(#loc2)
+ CODE_LOC_PATTERN = re.compile(r".*loc\(#loc(\d*)\)\s*$")
+
+ # Matches direct file references, e.g. on the function-arguments line.
+ DIRECT_FILE_PATTERN = re.compile(r'.*loc\("([^"]+)":(\d+):(\d+)\)')
+
+ # Definition of the PTX .loc directive.
+ # Example: .loc 1 0 50 // abcdef.py:0:50
+ PTX_LOC_PATTERN = re.compile(
+     r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)"
+ )
+
+ # Definition of the AMDGCN .loc directive.
+ # Examples: .loc 1 32 30 ; abcd.py:32:30
+ #           .loc 1 32 46 is_stmt 0 ; abcd.py:32:46
+ AMDGCN_LOC_PATTERN = re.compile(
+     r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)"
+ )
+
+
+ def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
+     """
+     Extracts location definitions from the given IR content.
+
+     This function searches for #loc directives in the provided IR content
+     string. It identifies the main #loc directive, a special case located at
+     the top of the IR file, and any subsequent #loc directives that define
+     source file locations.
+
+     Args:
+         ir_content (str): The content of the IR file as a string.
+
+     Returns:
+         Dict[str, Dict[str, Any]]: A dictionary mapping location IDs to their
+             corresponding file names, line numbers, and column numbers.
+     """
+     locations = {}
+     # The first #loc directive is a special case: it sits at the top of the IR file.
+     main_match = re.search(r'#loc = loc\("([^"]+)":(\d+):(\d+)\)', ir_content)
+     if main_match:
+         locations["1"] = {
+             "file": main_match.group(1),
+             "line": int(main_match.group(2)),
+             "column": int(main_match.group(3)),
+         }
+     # #loc1 = loc(unknown) is another special case; it is ignored here.
+     for loc_id, filename, line, col in LOC_PATTERN.findall(ir_content):
+         locations[loc_id] = {"file": filename, "line": int(line), "column": int(col)}
+     return locations
+
+
+ def extract_code_locations(ir_content: str) -> Dict[int, str]:
+     """
+     Extracts code location mappings from the given IR content.
+
+     This function scans the provided IR content line by line and identifies
+     lines that contain location references. It uses regular expressions to
+     match both #loc references and direct file references, and returns a
+     dictionary mapping line numbers to their corresponding location
+     identifiers.
+
+     Limitations:
+         The function-arguments line may carry several direct
+         loc("file":line:col) references (DIRECT_FILE_PATTERN); only the first
+         location reference on that line is used.
+
+     Args:
+         ir_content (str): The content of the IR file as a string.
+
+     Returns:
+         Dict[int, str]: A dictionary mapping line numbers to location
+             identifiers, which can be either a #loc identifier or a direct
+             file reference.
+     """
+     line_to_loc = {}
+     for i, line in enumerate(ir_content.split("\n"), start=1):
+         if m := CODE_LOC_PATTERN.search(line):
+             line_to_loc[i] = m.group(1) or "0"
+         elif m := DIRECT_FILE_PATTERN.search(line):
+             file_path, ln, col = m.groups()
+             line_to_loc[i] = f"direct:{file_path}:{ln}:{col}"
+     return line_to_loc
+
+
+ def extract_ptx_amdgcn_mappings(
+     content: str, other_mappings: Optional[List[Any]] = None, ir_type: str = "ptx"
+ ) -> Dict[str, Dict[str, Any]]:
+     """
+     Extract mappings from PTX/AMDGCN code, where `.loc` directives provide
+     source file and line information.
+
+     Only code between the function begin and end markers (e.g.
+     "// -- Begin function" and "// -- End function") is processed. The PTX
+     source-line mapping differs from that of the other IRs: the code is
+     segmented by `.loc` directives, each of which carries the mapping to a
+     source code line for the code that follows it.
+
+     This function:
+     1. Identifies the function boundary in the PTX/AMDGCN code
+     2. Only processes code within the function boundary
+     3. Maps lines with `.loc` directives to source files and line numbers
+     4. Associates subsequent code lines with the most recent `.loc` directive
+
+     Args:
+         content: The content of the PTX/AMDGCN file
+         other_mappings: Mappings from other IRs, used to resolve bare
+             filenames to full paths
+         ir_type: Either "ptx" or "amdgcn"
+
+     Returns:
+         Dictionary mapping PTX/AMDGCN line numbers to source location
+         information
+     """
+     mappings = {}
+     current_mapping = None
+
+     # Mark function scope
+     function_start_line = 0
+     function_end_line = 0
+     # filename: {file_path, ...}
+     referenced_files = defaultdict(set)
+     if other_mappings is None:
+         other_mappings = []
+     for other in other_mappings:
+         for _, info in other.items():
+             if "file" in info:
+                 file_name = os.path.basename(info["file"])
+                 referenced_files[file_name].add(info["file"])
+
+     def get_file_path(filename: str) -> str:
+         file_path = filename
+         if not os.path.isabs(filename):
+             logger.debug(
+                 f"Filename '{filename}' does not contain a path. Attempting to resolve."
+             )
+             # Attempt to resolve the filename to a full path using referenced_files
+             if filename in referenced_files:
+                 if len(referenced_files[filename]) > 1:
+                     logger.debug(
+                         f"Filename '{filename}' has multiple file paths. Using the first one."
+                     )
+                 file_path = list(referenced_files[filename])[0]
+                 logger.debug(f"Resolved filename '{filename}' to {file_path}")
+             else:
+                 logger.debug(f"Filename '{filename}' not found in referenced files.")
+         return file_path
+
+     # Regular expressions to match function start and end markers
+     # TODO: double-check whether the PTX content only ever contains one function
+     begin_func_pattern = re.compile(
+         r"(?:(?://|;)\s*(?:\.globl\s+\S+\s+)?|\.globl\s+\S+\s+;\s*)--\s*Begin function"
+     )
+     end_func_pattern = re.compile(r"(?://|;)\s*--\s*End function")
+
+     # First scan: find function boundaries
+     lines = content.split("\n")
+     for i, line in enumerate(lines, 1):
+         if begin_func_pattern.search(line) and function_start_line == 0:
+             function_start_line = i
+         elif end_func_pattern.search(line) and function_start_line > 0:
+             function_end_line = i
+             break
+
+     # If no function boundaries are found, return an empty mapping
+     if function_start_line == 0 or function_end_line == 0:
+         logger.warning(
+             f"Could not identify {ir_type} function boundaries. No {ir_type} mappings generated."
+         )
+         return mappings
+
+     logger.debug(
+         f"Processing {ir_type} function from line {function_start_line} to {function_end_line}"
+     )
+
+     is_ptx = ir_type == "ptx"
+     is_amdgcn = ir_type == "amdgcn"
+
+     tmp_loc_pattern = PTX_LOC_PATTERN if is_ptx else AMDGCN_LOC_PATTERN
+     # Second scan: process code within the function body.
+     # Note: `lines` is 0-indexed while `function_start_line` is 1-indexed, so
+     # the slice below starts at the line immediately after the begin marker.
+     for i, line in enumerate(
+         lines[function_start_line:function_end_line], start=function_start_line + 1
+     ):
+         try:
+             # Check .loc directive line
+             match = tmp_loc_pattern.match(line)
+             if match:
+                 if is_ptx:
+                     py_line, py_col, filename, _, _ = match.groups()
+                 elif is_amdgcn:
+                     py_file_index, py_line, py_col, filename, _, _ = match.groups()
+                 else:
+                     logger.error(f"Unknown IR type: {ir_type}")
+                     raise ValueError(f"Unknown IR type: {ir_type}")
+                 file_path = get_file_path(filename)
+                 # Create new mapping
+                 current_mapping = {
+                     "file": file_path,
+                     "line": int(py_line),
+                     "column": int(py_col),
+                     f"{ir_type}_line": i,
+                 }
+                 # Store mapping
+                 mappings[str(i)] = current_mapping
+             elif current_mapping:
+                 # Lines without their own .loc inherit the nearest preceding
+                 # .loc mapping. Only non-empty, non-comment code lines count.
+                 line_content = line.strip()
+                 if line_content and not (
+                     (is_ptx and line_content.startswith("//"))
+                     or (is_amdgcn and line_content.startswith(";"))
+                 ):
+                     mappings[str(i)] = {
+                         "file": current_mapping["file"],
+                         "line": current_mapping["line"],
+                         "column": current_mapping["column"],
+                         f"{ir_type}_line": i,
+                     }
+         except Exception as e:
+             logger.error(f"Error processing line {i}: {e}")
+             logger.error(f"Line content: {line}")
+             raise e
+     return mappings
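
A minimal sketch of how the two extractors above fit together (the IR snippet and file path are invented for illustration):

    from tritonparse.ir_parser import extract_code_locations, extract_loc_definitions

    ir = '''module {
      tt.func @kernel(%arg0: !tt.ptr<f32> loc("/tmp/example.py":10:0)) {
        %0 = tt.get_program_id x : i32 loc(#loc2)
      } loc(#loc)
    } loc(#loc)
    #loc = loc("/tmp/example.py":10:0)
    #loc2 = loc("/tmp/example.py":12:24)
    '''

    defs = extract_loc_definitions(ir)
    # defs["2"] == {"file": "/tmp/example.py", "line": 12, "column": 24}
    refs = extract_code_locations(ir)
    # refs[3] == "2"  (the get_program_id line references #loc2)
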
tritonparse/mapper.py ADDED
@@ -0,0 +1,100 @@
+ import logging
+ from collections import defaultdict
+ from typing import Any, Dict, List, Tuple
+
+
+ logger = logging.getLogger("SourceMapping")
+
+
+ def create_python_mapping(
+     ir_maps: List[Tuple[str, Dict[str, Dict[str, Any]]]],
+ ) -> Dict[int, Dict[str, List[int]]]:
+     """
+     Create a mapping from Python source code to IR mappings. We assume there
+     is only one Python source file for each Triton kernel.
+
+     Args:
+         ir_maps: A list of tuples containing the IR type and the IR mappings.
+
+     Returns:
+         A dictionary mapping Python source line numbers to their
+         corresponding IR mappings.
+     """
+     py_map = defaultdict(lambda: defaultdict(list))
+     for ir_type, ir_map in ir_maps:
+         for line_number, info in ir_map.items():
+             py_line_number: int = info["line"]
+             py_map[py_line_number][f"{ir_type}_lines"].append(line_number)
+     return {k: dict(v) for k, v in py_map.items()}
+
+
+ def create_ir_mapping(
+     source_map: Dict[str, Dict[str, Any]], target_map: Dict[str, Dict[str, Any]]
+ ) -> Dict[str, List[int]]:
+     """
+     Create a mapping from source IR lines to target IR lines.
+
+     This function takes two mappings, one for the source IR and one for the
+     target IR, and creates a new mapping that associates lines in the source
+     IR with corresponding lines in the target IR based on their file, line,
+     and column information.
+
+     Args:
+         source_map (Dict[str, Dict[str, Any]]): A dictionary mapping source
+             IR line numbers to their source file, line, and column information.
+         target_map (Dict[str, Dict[str, Any]]): A dictionary mapping target
+             IR line numbers to their source file, line, and column information.
+
+     Returns:
+         Dict[str, List[int]]: A dictionary mapping source IR line numbers to
+             lists of corresponding target IR line numbers.
+     """
+     source_to_target = defaultdict(list)
+
+     # Build a mapping from source file locations to target lines
+     for tgt_line, tgt_info in target_map.items():
+         if "file" in tgt_info and "line" in tgt_info:
+             key = f"{tgt_info['file']}:{tgt_info['line']}:{tgt_info.get('column', 0)}"
+             source_to_target[key].append(int(tgt_line))
+
+     # Map source lines to target lines
+     mapping = {}
+     for src_line, src_info in source_map.items():
+         if "file" in src_info and "line" in src_info:
+             key = f"{src_info['file']}:{src_info['line']}:{src_info.get('column', 0)}"
+             if key in source_to_target:
+                 mapping[src_line] = sorted(source_to_target[key])
+
+     return mapping
+
+
+ def create_bidirectional_mapping(
+     source_map: Dict[str, Dict[str, Any]],
+     target_map: Dict[str, Dict[str, Any]],
+     source_type: str,
+     target_type: str,
+ ) -> None:
+     """
+     Create bidirectional mappings between two IR types and update both
+     mapping dictionaries in place.
+
+     This function creates mappings from source IR to target IR and vice
+     versa, then updates both mapping dictionaries with the line references.
+
+     Args:
+         source_map: Dictionary mapping source IR line numbers to source locations
+         target_map: Dictionary mapping target IR line numbers to source locations
+         source_type: String identifier for the source IR type (e.g., 'ttir', 'ttgir', 'ptx')
+         target_type: String identifier for the target IR type (e.g., 'ttir', 'ttgir', 'ptx')
+     """
+     # Create forward mapping (source to target)
+     source_to_target = create_ir_mapping(source_map, target_map)
+
+     # Add target line references to source mappings
+     for source_line, target_lines in source_to_target.items():
+         if source_line in source_map and target_lines:
+             source_map[source_line][f"{target_type}_lines"] = target_lines
+
+     # Create reverse mapping (target to source)
+     target_to_source = create_ir_mapping(target_map, source_map)
+
+     # Add source line references to target mappings
+     for target_line, source_lines in target_to_source.items():
+         if target_line in target_map:
+             target_map[target_line][f"{source_type}_lines"] = source_lines
+
+     logger.debug(f"Created {source_type} to {target_type} mappings (and reverse)")
tritonparse/reproducer/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """Reproducer subpackage: generate runnable Triton repro scripts from traces.
+
+ Contains:
+ - ingestion.ndjson: parse NDJSON and build a context bundle
+ - orchestrator: LLM-based code generation with optional execute/repair
+ - providers: LLM provider protocol and Gemini provider
+ - prompts: simple prompt loader and templates
+ - runtime.executor: helper to run generated Python scripts
+ - param_generator: synthesize tensor/scalar allocations to reduce LLM burden
+ """
+
+ from .ingestion.ndjson import build_context_bundle
+ from .orchestrator import generate_from_ndjson
+ from .param_generator import generate_allocation_snippet, generate_kwargs_dict
+
+ __all__ = [
+     "build_context_bundle",
+     "generate_from_ndjson",
+     "generate_allocation_snippet",
+     "generate_kwargs_dict",
+ ]
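
The exported helpers can also be used directly; a minimal sketch (the trace path is hypothetical):

    from tritonparse.reproducer import build_context_bundle

    # Parse an NDJSON trace and inspect the context bundle for the first launch.
    bundle = build_context_bundle("trace.ndjson", launch_index=0)
    print(bundle["compile"]["hash"], bundle["launch"]["grid"])
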
tritonparse/reproducer/__main__.py ADDED
@@ -0,0 +1,81 @@
+ import argparse
+ import sys
+
+
+ def main() -> None:
+     p = argparse.ArgumentParser(
+         description=(
+             "Generate a runnable Triton repro script from a tritonparse NDJSON trace"
+         )
+     )
+     p.add_argument("--ndjson", required=True, help="Path to NDJSON trace file")
+     p.add_argument(
+         "--launch-index",
+         type=int,
+         default=0,
+         help="Launch index to reproduce",
+     )
+     p.add_argument("--out", default="repro.py", help="Output Python file path")
+     p.add_argument(
+         "--execute",
+         action="store_true",
+         help="Execute the generated script",
+     )
+     p.add_argument(
+         "--retries",
+         type=int,
+         default=0,
+         help="Auto-repair attempts if execution fails",
+     )
+     p.add_argument(
+         "--temperature",
+         type=float,
+         help="Override sampling temperature",
+     )
+     p.add_argument(
+         "--max-tokens",
+         type=int,
+         help="Override max tokens for generation",
+     )
+     args = p.parse_args()
+
+     # Lazy imports to allow `--help` without optional deps installed
+     from .config import load_config
+     from .orchestrator import generate_from_ndjson
+
+     try:
+         from .factory import make_gemini_provider
+     except Exception:  # pragma: no cover
+         print(
+             "Failed to import provider factory. Ensure optional deps are installed (e.g. google-genai).",
+             file=sys.stderr,
+         )
+         raise
+
+     cfg = load_config()
+     try:
+         provider = make_gemini_provider()
+     except ModuleNotFoundError:  # pragma: no cover
+         print(
+             "Gemini provider requires 'google-genai'. Install via: pip install google-genai",
+             file=sys.stderr,
+         )
+         sys.exit(2)
+     temperature = args.temperature if args.temperature is not None else cfg.temperature
+     max_tokens = args.max_tokens if args.max_tokens is not None else cfg.max_tokens
+
+     res = generate_from_ndjson(
+         args.ndjson,
+         provider,
+         launch_index=args.launch_index,
+         out_py=args.out,
+         execute=args.execute,
+         retries=args.retries,
+         temperature=temperature,
+         max_tokens=max_tokens,
+     )
+     print(res)
+
+
+ if __name__ == "__main__":  # pragma: no cover
+     main()
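
For a quick in-process smoke test of this entry point, the CLI can be driven by patching sys.argv; a sketch, assuming the optional google-genai dependency and credentials are configured and that trace.ndjson is a hypothetical input:

    import sys

    from tritonparse.reproducer.__main__ import main

    # Same as: python -m tritonparse.reproducer --ndjson trace.ndjson --out repro.py
    sys.argv = ["tritonparse-repro", "--ndjson", "trace.ndjson", "--out", "repro.py"]
    main()
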
tritonparse/reproducer/cli.py ADDED
@@ -0,0 +1,37 @@
+ import argparse
+
+ from .config import load_config
+ from .factory import make_gemini_provider
+ from .orchestrator import generate_from_ndjson
+
+
+ def add_reproducer_subparser(parser: argparse.ArgumentParser) -> None:
+     sub = parser.add_subparsers(dest="subcommand")
+     repro = sub.add_parser(
+         "repro",
+         help="Generate a runnable Triton repro script from NDJSON",
+     )
+     repro.add_argument("--ndjson", required=True)
+     repro.add_argument("--launch-index", type=int, default=0)
+     repro.add_argument("--out", default="repro.py")
+     repro.add_argument("--execute", action="store_true")
+     repro.add_argument("--retries", type=int, default=0)
+
+
+ def maybe_handle_reproducer(args: argparse.Namespace) -> bool:
+     if getattr(args, "subcommand", None) != "repro":
+         return False
+     cfg = load_config()
+     provider = make_gemini_provider()
+     res = generate_from_ndjson(
+         args.ndjson,
+         provider,
+         launch_index=args.launch_index,
+         out_py=args.out,
+         execute=args.execute,
+         retries=args.retries,
+         temperature=cfg.temperature,
+         max_tokens=cfg.max_tokens,
+     )
+     print(res)
+     return True
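
These helpers are meant to be embedded in a host argparse tool; a minimal sketch of that wiring (the host program name is invented):

    import argparse

    from tritonparse.reproducer.cli import add_reproducer_subparser, maybe_handle_reproducer

    parser = argparse.ArgumentParser(prog="tritonparse")
    add_reproducer_subparser(parser)
    args = parser.parse_args()
    if not maybe_handle_reproducer(args):
        pass  # fall through to the host tool's own handling
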
tritonparse/reproducer/config.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class ReproducerConfig:
+     project: str = os.getenv("GOOGLE_CLOUD_PROJECT", "")
+     location: str = os.getenv("GOOGLE_LOCATION", "us-central1")
+     model: str = os.getenv("TP_REPRO_MODEL", "gemini-2.5-pro")
+     temperature: float = float(os.getenv("TP_REPRO_TEMPERATURE", "0.1"))
+     max_tokens: int = int(os.getenv("TP_REPRO_MAX_TOKENS", "10240"))
+
+
+ def load_config() -> ReproducerConfig:
+     return ReproducerConfig()
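
Because every field defaults from an environment variable, the configuration can be steered without code changes. Note that the os.getenv calls run when the class body is evaluated, i.e. at import time, so the variables must be set before the module is first imported; a sketch:

    import os

    # Must be set before tritonparse.reproducer.config is imported.
    os.environ["TP_REPRO_MODEL"] = "gemini-2.5-pro"
    os.environ["TP_REPRO_TEMPERATURE"] = "0.0"

    from tritonparse.reproducer.config import load_config

    cfg = load_config()
    print(cfg.model, cfg.temperature)  # gemini-2.5-pro 0.0
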
tritonparse/reproducer/factory.py ADDED
@@ -0,0 +1,16 @@
+ """Provider factory for reproducer.
+
+ Currently supports Gemini only.
+ """
+
+ from .config import load_config
+ from .providers.gemini import GeminiProvider
+
+
+ def make_gemini_provider() -> GeminiProvider:
+     cfg = load_config()
+     return GeminiProvider(
+         project=cfg.project,
+         location=cfg.location,
+         model=cfg.model,
+     )
tritonparse/reproducer/ingestion/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Ingestion utilities for reproducer.
+
+ Currently supports NDJSON trace parsing.
+ """
+
+ __all__ = []
tritonparse/reproducer/ingestion/ndjson.py ADDED
@@ -0,0 +1,165 @@
+ import json
+ from typing import Any, Dict, List
+
+
+ def _iter_events(path: str):
+     with open(path, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 continue
+             try:
+                 yield json.loads(line)
+             except json.JSONDecodeError:
+                 # skip malformed lines
+                 continue
+
+
+ def _index_compilations(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
+     idx = {}
+     for e in events:
+         if e.get("event_type") != "compilation":
+             continue
+         payload = e.get("payload") or {}
+         meta = payload.get("metadata") or {}
+         h = meta.get("hash")
+         if h:
+             idx[h] = e
+     return idx
+
+
+ def _get_launches(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     return [e for e in events if e.get("event_type") == "launch"]
+
+
+ def _resolve_kernel_source(
+     launch: Dict[str, Any], comp_idx: Dict[str, Dict[str, Any]]
+ ) -> str:
+     # In the new format, launch has a top-level compilation_metadata, not payload.*
+     comp_meta = (
+         launch.get("compilation_metadata")
+         or launch.get("payload", {}).get("compilation_metadata")
+         or {}
+     )
+     h = comp_meta.get("hash")
+     if not h:
+         return ""
+     comp = comp_idx.get(h, {})
+     payload = comp.get("payload") or {}
+     py = payload.get("python_source") or {}
+     return py.get("code", "")
+
+
+ def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
+     packed = {}
+     for k, v in args.items():
+         t = v.get("type") if isinstance(v, dict) else None
+         if t == "tensor":
+             # v is guaranteed to be a dict here, since t was read from it
+             packed[k] = {
+                 "type": "tensor",
+                 "shape": v.get("shape"),
+                 "dtype": v.get("dtype"),
+                 "device": v.get("device"),
+                 "stride": v.get("stride"),
+                 "is_contiguous": v.get("is_contiguous"),
+                 "numel": v.get("numel"),
+             }
+         elif isinstance(v, dict):
+             # scalar / NoneType etc.
+             packed[k] = {
+                 "type": v.get("type"),
+                 "value": v.get("value", v.get("repr")),
+             }
+         else:
+             packed[k] = {
+                 "type": None,
+                 "value": v,
+             }
+     return packed
+
+
+ # Sentinel and helper to normalize extracted argument values
+ _SKIP = object()
+
+
+ def _decode_arg(raw: Any):
+     if not isinstance(raw, dict):
+         return raw
+     t = raw.get("type")
+     if t == "tensor":
+         return _SKIP
+     if t == "NoneType":
+         return None
+     return raw.get("value", raw.get("repr"))
+
+
+ def build_context_bundle(ndjson_path: str, launch_index: int = 0) -> Dict[str, Any]:
+     events = list(_iter_events(ndjson_path))
+     launches = _get_launches(events)
+     if not launches:
+         raise RuntimeError("No launch events found in NDJSON.")
+     if launch_index < 0 or launch_index >= len(launches):
+         raise IndexError(
+             f"launch_index out of range: {launch_index} (total {len(launches)})"
+         )
+     launch = launches[launch_index]
+     comp_idx = _index_compilations(events)
+     kernel_source = _resolve_kernel_source(launch, comp_idx)
+     # Find '@triton.jit' and slice so the source starts at the decorator
+     jit_marker = "@triton.jit"
+     jit_pos = kernel_source.find(jit_marker)
+     if jit_pos != -1:
+         kernel_source = kernel_source[jit_pos:]
+
+     # Flatten launch fields (support both formats)
+     grid = launch.get("grid") or launch.get("payload", {}).get("grid")
+     comp_meta = (
+         launch.get("compilation_metadata")
+         or launch.get("payload", {}).get("compilation_metadata")
+         or {}
+     )
+     extracted_args = (
+         launch.get("extracted_args")
+         or launch.get("payload", {}).get("extracted_args")
+         or {}
+     )
+
+     # Compile metadata subset we care about
+     compile_block = {
+         "num_warps": comp_meta.get("num_warps"),
+         "num_stages": comp_meta.get("num_stages"),
+         "arch": comp_meta.get("arch"),
+         "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
+         "triton_version": comp_meta.get("triton_version"),
+         "hash": comp_meta.get("hash"),
+     }
+
+     # kwargs: include constexpr + explicit scalars used for the launch (skip tensor args)
+     kwargs = {}
+     for k, v in extracted_args.items():
+         val = _decode_arg(v)
+         if val is _SKIP:
+             continue
+         kwargs[k] = val
+
+     # Tensor args: only tensors
+     tensor_args = {
+         k: v
+         for k, v in extracted_args.items()
+         if isinstance(v, dict) and v.get("type") == "tensor"
+     }
+
+     bundle = {
+         "kernel_source": kernel_source,
+         "compile": compile_block,
+         "launch": {
+             "grid": grid,
+             "kwargs": kwargs,
+         },
+         "args": _pack_args(extracted_args),
+         "tensor_args": _pack_args(tensor_args),
+     }
+     return bundle
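
To see the bundle construction end to end, the sketch below fabricates a two-event trace in the shapes this parser accepts (all values invented):

    import json
    import tempfile

    from tritonparse.reproducer.ingestion.ndjson import build_context_bundle

    events = [
        {
            "event_type": "compilation",
            "payload": {
                "metadata": {"hash": "abc123", "num_warps": 4},
                "python_source": {"code": "import triton\n@triton.jit\ndef k(x, n): ..."},
            },
        },
        {
            "event_type": "launch",
            "compilation_metadata": {"hash": "abc123", "num_warps": 4},
            "grid": [128, 1, 1],
            "extracted_args": {
                "x": {"type": "tensor", "shape": [512], "dtype": "torch.float32"},
                "n": {"type": "int", "value": 512},
            },
        },
    ]

    with tempfile.NamedTemporaryFile("w", suffix=".ndjson", delete=False) as f:
        f.write("\n".join(json.dumps(e) for e in events))

    bundle = build_context_bundle(f.name)
    assert bundle["launch"]["kwargs"] == {"n": 512}           # tensor arg skipped
    assert bundle["kernel_source"].startswith("@triton.jit")  # sliced at the decorator
    assert bundle["compile"]["num_warps"] == 4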