tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
tritonparse/trace_processor.py ADDED
@@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import gzip
import json
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional

from .event_diff import _generate_launch_diff
from .ir_analysis import _generate_ir_analysis
from .ir_parser import (
    extract_code_locations,
    extract_loc_definitions,
    extract_ptx_amdgcn_mappings,
)
from .mapper import create_bidirectional_mapping, create_python_mapping
from .sourcemap_utils import get_file_extension, load_ir_contents

logger = logging.getLogger("SourceMapping")


def generate_source_mappings(
    ir_content: str, ir_type: str, other_mappings: Optional[List[Any]] = None
) -> Dict[str, Dict[str, Any]]:
    """
    Generate source mappings from intermediate representation (IR) content to the source file.

    Example:
        loc definition: Line 39 in ttir: #loc2 = loc("/tmp/torchinductor_yhao/yp/abcdef.py":20:28)
        loc reference: Line 9 in ttir: %0 = tt.get_program_id x : i32 loc(#loc2)
        Then, the output will be:
        {
            "9": {
                "file": "/tmp/torchinductor_yhao/yp/abcdef.py",
                "line": 20,
                "column": 28,
                "ttir_line": 9
            },
        }

    Args:
        ir_content (str): The content of the intermediate representation.
        ir_type (str): The type of the intermediate representation (e.g., 'ttir').
        other_mappings (Optional[List[Any]]): A collection of additional mappings,
            primarily used for PTX mappings, since PTX's location annotations
            reference the file name instead of the complete path.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary mapping line numbers to their
            corresponding source file, line, column, and the line number in the IR.
    """
    if other_mappings is None:
        other_mappings = []
    if ir_type == "ptx" or ir_type == "amdgcn":
        return extract_ptx_amdgcn_mappings(ir_content, other_mappings, ir_type)

    loc_defs = extract_loc_definitions(ir_content)
    logger.debug(f"Found {len(loc_defs)} #loc definitions")

    loc_refs = extract_code_locations(ir_content)
    logger.debug(f"Found {len(loc_refs)} loc references")

    mappings = {}
    for ln, loc_id in loc_refs.items():
        if loc_id.startswith("direct:"):
            _, file_path, line, col = loc_id.split(":", 3)
            mappings[str(ln)] = {
                "file": file_path,
                "line": int(line),
                "column": int(col),
                f"{ir_type}_line": ln,
            }
        elif loc_id in loc_defs:
            info = loc_defs[loc_id]
            entry = {
                "file": info["file"],
                "line": info["line"],
                "column": info["column"],
                f"{ir_type}_line": ln,
            }
            # Propagate callsite metadata if present
            if info.get("is_callsite"):
                entry["is_callsite"] = True
                entry["callsite_callee"] = info["callsite_callee"]
                entry["callsite_caller"] = info["callsite_caller"]
            # Propagate alias metadata if present
            if "alias_name" in info:
                entry["alias_name"] = info["alias_name"]
            if "alias_of" in info:
                entry["loc_id"] = loc_id
            mappings[str(ln)] = entry

    # Add separate entries for loc definition lines
    for loc_id, info in loc_defs.items():
        if "def_line" not in info:
            continue
        def_ln = info["def_line"]
        # Only create a mapping if this line doesn't already have one
        if str(def_ln) not in mappings:
            entry = {
                "file": info["file"],
                "line": info["line"],
                "column": info["column"],
                f"{ir_type}_line": def_ln,
                "kind": "loc_def",
            }
            if "alias_name" in info:
                entry["alias_name"] = info["alias_name"]
            if "alias_of" in info:
                entry["loc_id"] = loc_id
            mappings[str(def_ln)] = entry

    return mappings


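# A minimal sketch (hypothetical TTIR fragment and path, for illustration only)
# of what the mapping above produces for a loc reference plus its definition:
#
#     ttir = (
#         '%0 = tt.get_program_id x : i32 loc(#loc2)\n'  # TTIR line 1
#         '#loc2 = loc("/tmp/kernel.py":20:28)\n'        # TTIR line 2
#     )
#     generate_source_mappings(ttir, "ttir")
#     # roughly: {"1": {"file": "/tmp/kernel.py", "line": 20, "column": 28,
#     #                 "ttir_line": 1},
#     #           "2": {..., "kind": "loc_def"}}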
def process_ir(
    key: Optional[str],
    file_content: Dict[str, str],
    file_path: Dict[str, str],
    other_mappings: Optional[List[Any]] = None,
):
    ir_content = load_ir_contents(key, file_content, file_path)
    if not ir_content:
        return {}
    mapping = generate_source_mappings(ir_content, key.split(".")[1], other_mappings)
    logger.debug(f"Generated source mapping for {key}")
    return mapping


def parse_single_trace_content(trace_content: str) -> str:
    """
    Process a single trace content and extract source code mappings.

    This function takes a trace content as input, extracts the IR files, generates source mappings,
    creates bidirectional mappings between different IR types, and updates the payload with the mappings.

    Args:
        trace_content (str): The content of the trace file as a string.

    Returns:
        str: The updated trace content with source mappings as a JSON string.
    """

    entry = json.loads(trace_content)
    if entry.get("event_type") == "compilation":
        payload = entry.setdefault("payload", {})
        file_content = payload.get("file_content", {})
        file_path = payload.get("file_path", {})

        # Find the IR file keys
        ttir_key = next((k for k in file_content if k.endswith(".ttir")), None)
        ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None)
        ptx_key = next((k for k in file_content if k.endswith(".ptx")), None)
        amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None)
        # Skip if no IR files found
        if not (ttir_key or ttgir_key or ptx_key or amdgcn_key):
            logger.warning("No IR files found in the payload.")
            return trace_content

        # Generate ttir->source, ttgir->source, ptx->source, and amdgcn->source
        ttir_map = process_ir(ttir_key, file_content, file_path)
        ttgir_map = process_ir(ttgir_key, file_content, file_path)
        ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map])
        amdgcn_map = process_ir(
            amdgcn_key, file_content, file_path, [ttir_map, ttgir_map]
        )

        # Create bidirectional mappings between all IR types
        ir_maps = {
            "ttir": ttir_map,
            "ttgir": ttgir_map,
            "ptx": ptx_map,
            "amdgcn": amdgcn_map,
        }

        # Create mappings between all pairs of IR types
        ir_types = list(ir_maps.keys())
        for i, src_type in enumerate(ir_types):
            for tgt_type in ir_types[i + 1 :]:
                if ir_maps[src_type] and ir_maps[tgt_type]:
                    create_bidirectional_mapping(
                        ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type
                    )
                    logger.debug(
                        f"Created bidirectional mapping between {src_type} and {tgt_type}"
                    )

        py_map = {}

        if "python_source" in payload:
            logger.debug(
                f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})"
            )

            # Create Python-source-to-IR mappings, keyed by the original line
            # numbers in the Python source code.
            # Build a list of valid IR mappings, filtering out None keys.
            ir_mappings = []
            ir_keys_and_maps = [
                (ttir_key, ttir_map),
                (ttgir_key, ttgir_map),
                (ptx_key, ptx_map),
                (amdgcn_key, amdgcn_map),
            ]

            for key, mapping in ir_keys_and_maps:
                if key:
                    ir_mappings.append((get_file_extension(key), mapping))

            py_map = create_python_mapping(ir_mappings)

        # Store the mappings in the payload
        payload["source_mappings"] = {
            "ttir": ttir_map,
            "ttgir": ttgir_map,
            **({"ptx": ptx_map} if ptx_map else {}),
            **({"amdgcn": amdgcn_map} if amdgcn_map else {}),
            "python": py_map,
        }
    # NDJSON format requires a newline at the end of each line
    return json.dumps(entry, separators=(",", ":")) + "\n"


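# A hedged round-trip sketch (hypothetical event, for illustration only):
# wrap an IR payload in a minimal compilation event, run it through
# parse_single_trace_content, and read the mappings off the NDJSON line:
#
#     event = {
#         "event_type": "compilation",
#         "payload": {"file_content": {"k.ttir": ttir}, "file_path": {}},
#     }
#     line = parse_single_trace_content(json.dumps(event))
#     json.loads(line)["payload"]["source_mappings"]["ttir"]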
def parse_single_file(
    file_path: str,
    output_dir: Optional[str] = None,
    split_inductor_compilations: bool = True,
):
    """
    Process a single file, correctly group events by kernel, and extract mappings.

    This function reads a trace file, groups compilation and launch events by
    their kernel hash, generates a launch_diff event for each kernel, and writes
    the processed data to output files.

    Args:
        file_path (str): The path to the file to be processed.
        output_dir (str, optional): Directory to save the output files.
        split_inductor_compilations (bool, optional): Whether to split
            output files by frame_id, compile_id, attempt_id, and compiled_autograd_id.
            Defaults to True. This rule follows tlparse's behavior.
    """
    kernels_by_hash = defaultdict(
        lambda: {"compilation": None, "launches": [], "output_file": None}
    )

    output_dir = output_dir or os.path.dirname(file_path)
    is_compressed_input = file_path.endswith(".bin.ndjson")
    file_handle = (
        gzip.open(file_path, "rt", encoding="utf-8")
        if is_compressed_input
        else open(file_path, "r")
    )

    with file_handle as f:
        file_name = os.path.basename(file_path)
        file_name_without_extension = (
            file_name[:-11] if is_compressed_input else os.path.splitext(file_name)[0]
        )

        for i, line in enumerate(f):
            logger.debug(f"Processing line {i + 1} in {file_path}")
            json_str = line.strip()
            if not json_str:
                continue

            # We don't need to generate full mappings for every line here,
            # just enough to get the event type and necessary IDs.
            try:
                parsed_json = json.loads(json_str)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON on line {i + 1} in {file_path}")
                continue

            event_type = parsed_json.get("event_type", None)
            payload = parsed_json.get("payload", {})

            if event_type == "compilation":
                kernel_hash = payload.get("metadata", {}).get("hash")
                if not kernel_hash:
                    continue

                # Split inductor compilations into separate files.
                # This rule follows tlparse's behavior.
                if split_inductor_compilations:
                    pt_info = payload.get("pt_info", {})
                    frame_id = pt_info.get("frame_id")
                    frame_compile_id = pt_info.get("frame_compile_id")
                    attempt_id = pt_info.get("attempt_id", 0)
                    cai = pt_info.get("compiled_autograd_id", "-")
                    if frame_id is not None or frame_compile_id is not None:
                        fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson"
                    else:
                        fname = f"{file_name_without_extension}_mapped.ndjson"
                else:
                    fname = f"{file_name_without_extension}_mapped.ndjson"

                output_file = os.path.join(output_dir, fname)
                # The full processing is deferred until the final write.
                kernels_by_hash[kernel_hash]["compilation"] = json_str
                kernels_by_hash[kernel_hash]["output_file"] = output_file

            elif event_type == "launch":
                kernel_hash = parsed_json.get("compilation_metadata", {}).get("hash")
                if kernel_hash:
                    kernels_by_hash[kernel_hash]["launches"].append((parsed_json, i))

    # Organize lines for final output, keyed by output file path
    all_output_lines = defaultdict(list)
    for _kernel_hash, data in kernels_by_hash.items():
        compilation_json_str = data["compilation"]
        launches_with_indices = data["launches"]
        output_file = data["output_file"]

        if not output_file:
            logger.warning(f"No output file for kernel hash {_kernel_hash}, skipping.")
            continue

        # Process the compilation event now to include source mappings
        if compilation_json_str:
            processed_compilation_line = parse_single_trace_content(
                compilation_json_str
            )
            all_output_lines[output_file].append(processed_compilation_line)
            compilation_event = json.loads(processed_compilation_line)
        else:
            compilation_event = None

        for launch_event, _ in launches_with_indices:
            all_output_lines[output_file].append(
                json.dumps(launch_event, separators=(",", ":")) + "\n"
            )

        if compilation_event:
            ir_analysis = _generate_ir_analysis(compilation_event)
            if ir_analysis:
                ir_analysis_event = {
                    "event_type": "ir_analysis",
                    "hash": _kernel_hash,
                    "ir_analysis": ir_analysis,
                }
                all_output_lines[output_file].append(
                    json.dumps(ir_analysis_event, separators=(",", ":")) + "\n"
                )

        if compilation_event and launches_with_indices:
            sames, diffs, launch_index_map = _generate_launch_diff(
                launches_with_indices
            )
            launch_diff_event = {
                "event_type": "launch_diff",
                "hash": _kernel_hash,
                "name": compilation_event.get("payload", {})
                .get("metadata", {})
                .get("name"),
                "total_launches": len(launches_with_indices),
                "launch_index_map": launch_index_map,
                "diffs": diffs,
                "sames": sames,
            }
            all_output_lines[output_file].append(
                json.dumps(launch_diff_event, separators=(",", ":")) + "\n"
            )

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for output_file, final_lines in all_output_lines.items():
        with open(output_file, "w") as out:
            out.writelines(final_lines)
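
Taken together, this module turns a raw NDJSON trace into per-kernel output files: each kernel's compilation event is enriched with source_mappings, followed by its launch events, an ir_analysis event, and a launch_diff summary. A minimal driving sketch (the trace path and output directory are hypothetical; the module path follows the file list above):

    from tritonparse.trace_processor import parse_single_file

    # Groups events by kernel hash and writes *_mapped.ndjson (or, for inductor
    # compilations, f<frame>_fc<compile>_a<attempt>_cai<id>.ndjson) into ./parsed.
    parse_single_file("./logs/trace.ndjson", output_dir="./parsed")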
tritonparse/utils.py ADDED
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
import os
import shutil
from pathlib import Path
from typing import Optional

from .common import (
    copy_local_to_tmpdir,
    is_fbcode,
    parse_logs,
    print_parsed_files_summary,
    RankConfig,
    save_logs,
)
from .source_type import Source, SourceType


def _add_parse_args(parser: argparse.ArgumentParser) -> None:
    """Add common 'parse' subcommand arguments to a parser."""
    parser.add_argument(
        "source",
        help=(
            "Source of torch logs to be analyzed. It is expected to be a path to a "
            "local directory or log file."
        ),
    )
    parser.add_argument(
        "-o",
        "--out",
        help="Output directory.",
        type=str,
    )
    parser.add_argument(
        "--overwrite",
        help=(
            "Delete the output directory if it already exists. Only has an effect "
            "if --out is set."
        ),
        action="store_true",
    )
    parser.add_argument("-r", "--rank", help="Rank of logs to be analyzed", type=int)
    parser.add_argument(
        "--all-ranks",
        help="Analyze all ranks",
        action="store_true",
    )
    parser.add_argument("-v", "--verbose", help="Verbose logging", action="store_true")
    if is_fbcode():
        from tritonparse.fb.utils import append_parser

        append_parser(parser)


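# A hedged sketch of the CLI these flags back (subcommand name taken from the
# docstring above; console-script name and paths are assumptions based on this
# wheel's entry points):
#
#     tritonparse parse ./logs -o ./parsed_output --overwrite --all-ranks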
def oss_run(
    source: str,
    out: Optional[str] = None,
    overwrite: Optional[bool] = False,
    rank: Optional[int] = None,
    all_ranks: bool = False,
    verbose: bool = False,
    split_inductor_compilations: bool = True,
):
    """
    Main function for tritonparse. It is for OSS only.

    Args:
        source: Source of torch logs to be analyzed (required)
        out: Output directory
        overwrite: Delete the output directory if it already exists
        rank: Rank of logs to be analyzed
        all_ranks: Analyze all ranks
        verbose: Verbose logging
        split_inductor_compilations: Whether to split output files by inductor
            compilation IDs (see parse_single_file). Defaults to True.
    """
    source = Source(source, verbose)
    rank_config = RankConfig.from_cli_args(rank, all_ranks, source.type)

    # Check the output directory early if specified
    if out is not None:
        out_dir = Path(out)
        if out_dir.exists():
            if not overwrite:
                raise RuntimeError(
                    f"{out_dir} already exists, pass --overwrite to overwrite"
                )
            shutil.rmtree(out_dir)
        os.makedirs(out_dir, exist_ok=True)

    # For signpost logging (not implemented in the Python version)

    if source.type == SourceType.LOCAL:
        local_path = source.value
        # Copy the results to a temp directory, then parse them
        logs = copy_local_to_tmpdir(local_path, verbose)

    elif source.type == SourceType.LOCAL_FILE:
        local_path = source.value
        # Copy the single file to a temp directory, then parse it
        logs = copy_local_to_tmpdir(local_path, verbose)

    parsed_log_dir, _ = parse_logs(
        logs,
        rank_config,
        verbose,
        split_inductor_compilations=split_inductor_compilations,
    )
    if out is not None:
        save_logs(Path(out), parsed_log_dir, overwrite, verbose)
    # Print a summary of all parsed files
    if out is not None:
        out_dir = str(Path(out).absolute())
    else:
        out_dir = str(Path(parsed_log_dir).absolute())
    print_parsed_files_summary(out_dir)
    return None


def unified_parse(
    source: str,
    out: Optional[str] = None,
    overwrite: Optional[bool] = False,
    rank: Optional[int] = None,
    all_ranks: bool = False,
    verbose: bool = False,
    split_inductor_compilations: bool = True,
    **kwargs,
):
    """
    Unified parse function that provides a flexible interface for parsing triton logs.

    Args:
        source: Input directory containing logs to parse.
        out: Output directory for parsed results. By default, parsed logs will be
            saved to a temporary directory.
        overwrite: Whether to overwrite an existing output directory
        rank: Specific rank to analyze
        all_ranks: Whether to analyze all ranks
        verbose: Whether to enable verbose logging
        split_inductor_compilations: Whether to split output files by inductor
            compilation IDs. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to the underlying parse function.
    """
    # Choose the appropriate parse function
    if is_fbcode():
        from tritonparse.fb.utils import fb_run as parse
    else:
        parse = oss_run

    output = parse(
        source=source,
        out=out,
        overwrite=overwrite,
        rank=rank,
        all_ranks=all_ranks,
        verbose=verbose,
        split_inductor_compilations=split_inductor_compilations,
        **kwargs,
    )
    return output
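
unified_parse is the programmatic entry point this wheel exposes for parsing: it dispatches to fb_run inside fbcode and to oss_run otherwise. A minimal usage sketch (the log and output paths are hypothetical):

    from tritonparse.utils import unified_parse

    # Parse a directory of tritonparse NDJSON traces and write the mapped
    # results to ./parsed_output, replacing any previous run.
    unified_parse(source="./logs", out="./parsed_output", overwrite=True)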