tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,580 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import ast
4
+ import importlib
5
+ import importlib.util
6
+ import json
7
+ import logging
8
+ import sys
9
+ from datetime import datetime
10
+ from functools import lru_cache
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import triton.language as tl
15
+ from tritonparse.tools.load_tensor import load_tensor
16
+ from tritonparse.tp_logger import logger
17
+
18
# True only when the optional ``triton_kernels`` package and its ``tensor``
# submodule can both be located. ``all`` stops at the first missing module, so
# the dotted submodule (whose lookup needs the parent package) is never probed
# when the parent package is absent.
TRITON_KERNELS_CUSTOM_TYPES = all(
    importlib.util.find_spec(module_name) is not None
    for module_name in ("triton_kernels", "triton_kernels.tensor")
)
22
+
23
# Mapping from a Triton dtype's string representation (the "repr" field of a
# serialized "dtype" argument) to the corresponding triton.language dtype object.
TRITON_DTYPE_MAP = {
    # Signed integers
    "int8": tl.int8,
    "int16": tl.int16,
    "int32": tl.int32,
    "int64": tl.int64,
    # 1-bit integer (Triton's boolean type) and unsigned integers
    "int1": tl.int1,
    "uint8": tl.uint8,
    "uint16": tl.uint16,
    "uint32": tl.uint32,
    "uint64": tl.uint64,
    # Standard floating point types
    "fp16": tl.float16,
    "bf16": tl.bfloat16,
    "fp32": tl.float32,
    "fp64": tl.float64,
    # FP8 variants
    "fp8e4b15": tl.float8e4b15,
    "fp8e4nv": tl.float8e4nv,
    "fp8e4b8": tl.float8e4b8,
    "fp8e5": tl.float8e5,
    "fp8e5b16": tl.float8e5b16,
}
48
+
49
+
50
@lru_cache(maxsize=1)
def _get_triton_tensor_types():
    """
    Return the ``(Tensor, Storage, StridedLayout)`` classes of ``triton_kernels.tensor``.

    The optional module is imported lazily on the first successful call and the
    resulting tuple is memoized by ``lru_cache``. An import failure propagates
    to the caller and is not cached.
    """
    tensor_module = importlib.import_module("triton_kernels.tensor")
    return tuple(
        getattr(tensor_module, class_name)
        for class_name in ("Tensor", "Storage", "StridedLayout")
    )
58
+
59
+
60
def create_args_from_json_file(json_path):
    """
    Read a reproducer JSON file and build the kernel grid and arguments.

    The file is decoded with ``json.load`` and the result is handed to
    ``create_args_from_json``.

    Args:
        json_path (str): Path to the JSON file describing the kernel launch.

    Returns:
        tuple[list, dict]: Grid specification list and map of argument name to value.
    """
    with open(json_path, "r") as handle:
        payload = json.load(handle)
    return create_args_from_json(payload)
73
+
74
+
75
def create_args_from_json(data):
    """
    Build the kernel grid and argument dictionary from parsed reproducer JSON.

    ``data`` may be the launch dict itself or a single-element list wrapping
    it. Any other shape prints an error and terminates the process with exit
    status 1.

    Args:
        data (dict | list): JSON data describing the kernel launch.

    Returns:
        tuple[list, dict]: The ``grid`` entry (``[]`` if absent) and a map from
        argument name to the object built by ``_create_arg_from_info`` for each
        entry of ``extracted_args``.
    """
    if isinstance(data, list):
        if len(data) != 1:
            print(
                f"Error: Expected single element list, got list with {len(data)} elements"
            )
            sys.exit(1)
        (data,) = data
    elif not isinstance(data, dict):
        print(f"Error: Expected list or dict, got {type(data)}")
        sys.exit(1)

    grid = data.get("grid", [])
    args_dict = {
        arg_name: _create_arg_from_info(arg_info)
        for arg_name, arg_info in data.get("extracted_args", {}).items()
    }
    return grid, args_dict
105
+
106
+
107
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
    """
    Re-lay out ``tensor`` with a recorded stride and storage offset.

    ``tensor`` itself is returned in two cases: when ``stride`` is None, or
    when ``stride`` equals the row-major contiguous stride for ``shape`` and
    ``storage_offset`` is 0. Otherwise the function:

    1. allocates a new 1-D buffer just large enough to hold every element the
       requested view can address;
    2. takes an ``as_strided`` view over it with the requested geometry;
    3. copies the first ``numel`` elements of ``tensor`` (flattened) into that
       view.

    Args:
        tensor: Base tensor supplying values, dtype and device.
        shape: Target shape of the view.
        stride: Target stride, or None to keep ``tensor`` unchanged.
        storage_offset: Element offset of the view into its buffer.

    Returns:
        torch.Tensor: ``tensor`` unchanged, or a strided view over a new buffer.
    """
    if stride is None:
        return tensor

    # Row-major strides: the last dimension moves fastest.
    contiguous_stride = []
    step = 1
    for size in reversed(shape):
        contiguous_stride.append(step)
        step *= size
    contiguous_stride.reverse()

    if tuple(stride) == tuple(contiguous_stride) and storage_offset == 0:
        return tensor

    # The buffer must reach the furthest element the view can address;
    # zero-sized dimensions contribute nothing.
    if len(shape) > 0 and len(stride) > 0:
        furthest = storage_offset + sum(
            dim_stride * (size - 1)
            for dim_stride, size in zip(stride, shape)
            if size > 0
        )
    else:
        furthest = storage_offset
    backing = torch.empty(furthest + 1, dtype=tensor.dtype, device=tensor.device)

    view = backing.as_strided(size=shape, stride=stride, storage_offset=storage_offset)
    view.copy_(tensor.flatten()[: view.numel()].view(shape))
    return view
156
+
157
+
158
def _tensor_from_stats(shape, torch_dtype, device, mean, std, min_val, max_val, integral):
    """Sample N(mean, std) in float32, clamp to [min_val, max_val], cast to ``torch_dtype``.

    A zero ``std`` or ``min_val == max_val`` yields a constant tensor filled
    with ``mean`` (truncated with ``int()`` when ``integral``). For integral
    dtypes the clamped samples are rounded before the cast.
    """
    if std == 0 or min_val == max_val:
        fill_value = int(mean) if integral else mean
        return torch.full(shape, fill_value, dtype=torch_dtype, device=device)
    sample = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
    sample = torch.clamp(sample, min=min_val, max=max_val)
    if integral:
        sample = torch.round(sample)
    return sample.to(torch_dtype)


def _create_base_tensor(arg_info) -> torch.Tensor:
    """
    Create a base tensor without stride/offset modifications.

    If ``arg_info["blob_path"]`` is set, the saved tensor is loaded with
    ``load_tensor``. Otherwise synthetic data of the recorded dtype and shape
    is generated. When ``mean``, ``std``, ``min`` and ``max`` are all present,
    floating and (non-bool) integer tensors are drawn from those statistics;
    otherwise a dtype-specific random fill is used.

    Args:
        arg_info (dict): Argument information including ``dtype`` (e.g.
            ``"torch.float16"``), ``shape``, ``device`` and optional statistics.

    Returns:
        torch.Tensor: The created base tensor.

    Raises:
        NotImplementedError: If no generation strategy exists for the dtype.
    """
    if arg_info.get("blob_path"):
        return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))

    # Resolve "torch.<name>" to a torch.dtype. Anything that does not name a
    # real dtype (missing, misspelled, or a non-dtype torch attribute such as
    # "torch.Size") falls back to float32 instead of failing further down.
    dtype_str = arg_info.get("dtype")
    torch_dtype = getattr(torch, str(dtype_str).split(".")[-1], None)
    if not isinstance(torch_dtype, torch.dtype):
        logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
        torch_dtype = torch.float32

    shape = arg_info.get("shape", [])
    device = arg_info.get("device", "cpu")
    # Normalize any CUDA device string (e.g. "cuda:3") to cuda:0.
    if isinstance(device, str) and device.startswith("cuda"):
        device = "cuda:0"

    # Statistics are used only when all four are recorded.
    mean = arg_info.get("mean")
    std = arg_info.get("std")
    min_val = arg_info.get("min")
    max_val = arg_info.get("max")
    has_stats = all(value is not None for value in (mean, std, min_val, max_val))

    if arg_info.get("tensor_capture_error", False):
        logging.error(
            f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
        )

    # A zero-element tensor is a cheap way to query properties of the dtype.
    tensor_props = torch.empty(0, dtype=torch_dtype)

    # Case 1: Floating point types
    if tensor_props.is_floating_point():
        if has_stats:
            return _tensor_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, integral=False
            )
        # float8 dtypes are sampled in float32 and cast rather than filled via
        # random_(). Matching on the dtype name covers every float8 variant
        # (e4m3fn, e5m2, e4m3fnuz, e5m2fnuz, ...), not only two of them.
        if str(torch_dtype).startswith("torch.float8"):
            return torch.rand(shape, dtype=torch.float32, device=device).to(torch_dtype)
        return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 2: Integer types (and bool)
    if torch_dtype in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    ):
        if has_stats and torch_dtype != torch.bool:
            return _tensor_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, integral=True
            )
        return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 3: Complex numbers: independent uniform real/imaginary parts.
    # TODO: Could be improved to use statistical info if available
    if tensor_props.is_complex():
        float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
        real_part = torch.rand(shape, dtype=float_dtype, device=device)
        imag_part = torch.rand(shape, dtype=float_dtype, device=device)
        return torch.complex(real_part, imag_part)

    # Case 4: Other unsigned integers (e.g. uint32) are filled with randint
    # instead of random_().
    if "uint" in str(torch_dtype):
        if has_stats:
            return _tensor_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, integral=True
            )
        return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)

    # Case 5: No generation strategy for this dtype.
    raise NotImplementedError(
        f"Random data generation not implemented for dtype: {torch_dtype}"
    )
272
+
273
+
274
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Build the tensor described by ``arg_info``, including its memory layout.

    The values come from ``_create_base_tensor``. The recorded ``shape``
    (default ``[]``), ``stride`` (optional) and ``storage_offset`` (default 0)
    are then applied through ``_apply_stride_and_offset``.

    Args:
        arg_info (dict): Argument information including dtype, shape, stride, etc.

    Returns:
        torch.Tensor: The created tensor with stride/offset applied.
    """
    base = _create_base_tensor(arg_info)
    return _apply_stride_and_offset(
        base,
        shape=arg_info.get("shape", []),
        stride=arg_info.get("stride"),
        storage_offset=arg_info.get("storage_offset", 0),
    )
291
+
292
+
293
def _create_arg_from_info(arg_info):
    """
    Recursively construct a kernel argument from its JSON description.

    Args:
        arg_info (dict): JSON object describing one argument. Its ``type``
            field selects the construction strategy. Other fields (``value``,
            ``dtype``, ``shape``, ``device``, ``storage``, ``data``,
            ``layout``, ``initial_shape``, ``repr``, ...) are read as that
            strategy needs them.

    Returns:
        Any: One of the following, depending on ``type``:
        - ``None``;
        - the stored scalar/str ``value``;
        - a ``torch.Tensor``;
        - a ``triton_kernels.tensor`` object;
        - a Triton dtype.
        Unknown types print a warning and return ``None``.

    Raises:
        RuntimeError: If a ``triton_kernels.tensor`` type is requested but the
            optional dependency is not installed.
        NotImplementedError: If a Triton dtype repr is not in ``TRITON_DTYPE_MAP``.
    """
    arg_type = arg_info.get("type")

    if arg_type == "NoneType":
        return None
    if arg_type in ("int", "bool", "str", "float"):
        # Scalars are stored verbatim in the JSON.
        return arg_info.get("value")
    if arg_type == "tensor":
        return _create_tensor(arg_info)

    if arg_type in (
        "triton_kernels.tensor.Tensor",
        "triton_kernels.tensor.Storage",
        "StridedLayout",
    ):
        # Short class name, used both for dispatch and in the error message.
        class_name = arg_type.rsplit(".", 1)[-1]
        if not TRITON_KERNELS_CUSTOM_TYPES:
            raise RuntimeError(
                "Optional dependency 'triton_kernels.tensor' is not installed; "
                f"cannot construct {class_name}."
            )
        tensor_cls, storage_cls, layout_cls = _get_triton_tensor_types()

        if class_name == "Tensor":
            storage = _create_arg_from_info(arg_info.get("storage"))
            torch_dtype = getattr(torch, arg_info.get("dtype").split(".")[-1])
            return tensor_cls(
                storage=storage,
                shape=arg_info.get("shape"),
                shape_max=arg_info.get("shape_max"),
                dtype=torch_dtype,
            )
        if class_name == "Storage":
            data = _create_arg_from_info(arg_info.get("data"))
            layout = _create_arg_from_info(arg_info.get("layout"))
            return storage_cls(data=data, layout=layout)
        return layout_cls(shape=arg_info.get("initial_shape"))

    if arg_type == "dtype":
        dtype_repr = arg_info.get("repr")
        if dtype_repr not in TRITON_DTYPE_MAP:
            raise NotImplementedError(f"Unsupported Triton dtype: {dtype_repr}")
        return TRITON_DTYPE_MAP[dtype_repr]

    print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
    return None
363
+
364
+
365
def determine_output_paths(out_dir: str, kernel_name: str, template: str):
    """
    Determine output file paths for reproducer script and context data.

    Both files go in ``<out_dir>/<kernel_name>/``, which is created (with
    parents) if missing. ``out_dir`` is used exactly as given. There is no
    special default: an empty string is a relative path, so the files land in
    ``./<kernel_name>/`` under the current working directory.

    File names:
    - Script: ``repro_<template>_<timestamp>.py``. The template part is
      omitted when ``template`` is exactly ``"example"``.
    - Context: ``repro_context_<timestamp>.json``.

    ``<timestamp>`` is local time formatted as ``%Y%m%d%H%M%S``. Two calls for
    the same kernel and template within the same second therefore return the
    same paths.

    Args:
        out_dir: Output directory path (may be empty; see above).
        kernel_name: Name of the kernel; used as the per-kernel subdirectory.
        template: Template name or path. If it contains a forward or backward
            slash, only the file stem is used in the script filename.

    Returns:
        tuple[Path, Path]: (python_script_path, json_context_path).
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_directory = Path(out_dir) / kernel_name
    output_directory.mkdir(parents=True, exist_ok=True)

    # A template given as a file path contributes only its stem to the name.
    is_path = "/" in template or "\\" in template
    template_name = Path(template).stem if is_path else template

    filename_parts = ["repro"]
    if template != "example":
        filename_parts.append(template_name)
    filename_parts.append(timestamp)
    filename = "_".join(filename_parts) + ".py"

    out_py_path = output_directory / filename
    temp_json_path = output_directory / f"repro_context_{timestamp}.json"
    return out_py_path, temp_json_path
395
+
396
+
397
def _generate_import_statements(kernel_info) -> tuple[str, str]:
    """
    Generate (sys.path insertion statement, import statement) for the kernel.

    Strategy:
    - Always add the kernel file's parent directory to sys.path.
    - If the filename (without .py) is a valid identifier, import using that
      module name: `from <stem> import <func> as imported_kernel_function`.
    - Otherwise, fall back to dynamic import via importlib.util and bind
      `imported_kernel_function` from the loaded module.

    Args:
        kernel_info: Object exposing ``file_path`` and ``function_name``.

    Returns:
        tuple[str, str]: Code that prepends the kernel's directory to
        ``sys.path``, and code that binds ``imported_kernel_function``.

    Raises:
        ValueError: If ``file_path`` or ``function_name`` is missing or empty.
    """
    raw_file_path = kernel_info.file_path
    function_name = kernel_info.function_name

    # Validate the raw values before wrapping in Path. A Path object is always
    # truthy (Path("") is Path(".")), and Path(None) raises TypeError, so
    # checking the Path could never detect a missing file path.
    if not raw_file_path or not function_name:
        raise ValueError("Kernel file path or function name missing from context.")

    file_path = Path(raw_file_path)

    # NOTE(review): paths are embedded in raw single-quoted literals; a path
    # containing a single quote would produce invalid generated code.
    sys_stmt = (
        "import sys; p = r'" + str(file_path.parent) + "';\n"
        "if p not in sys.path: sys.path.insert(0, p)"
    )

    module_name = file_path.with_suffix("").name
    if module_name.isidentifier():
        import_stmt = (
            f"from {module_name} import {function_name} as imported_kernel_function"
        )
        logger.debug("Generated direct import statement: %s", import_stmt)
        return sys_stmt, import_stmt

    # Fallback: dynamic import when filename is not a valid identifier
    import_stmt = (
        "import importlib.util\n"
        f"_spec = importlib.util.spec_from_file_location('kernel_mod', r'{str(file_path)}')\n"
        "_mod = importlib.util.module_from_spec(_spec)\n"
        "_spec.loader.exec_module(_mod)\n"
        f"imported_kernel_function = getattr(_mod, '{function_name}')"
    )
    logger.debug("Generated dynamic import for file: %s", file_path)
    return sys_stmt, import_stmt
438
+
439
+
440
def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]:
    """
    Split a kernel's parameters into positional and keyword arguments.

    The source is parsed with Python's ast module. This handles decorators
    (e.g. @triton.jit), return and complex type annotations, positional-only
    parameters (before /) and keyword-only parameters (after *). The first
    function definition yielded by ``ast.walk`` is used.

    Classification:
    - Positional-only parameters are always positional, since they cannot be
      passed by keyword.
    - Other parameters without a default are positional.
    - Parameters with a default, and keyword-only parameters, are keyword
      arguments.
    - ``*args`` / ``**kwargs`` are not reported.

    Args:
        kernel_source_code: Python source code containing the kernel function

    Returns:
        tuple[list[str], list[str]]: (positional_args, keyword_args)

    Raises:
        ValueError: If parsing fails or no function definition is found
    """
    try:
        tree = ast.parse(kernel_source_code)

        # Find the first function definition
        func_def = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                func_def = node
                break

        if not func_def:
            raise ValueError("No function definition found in source code")

        positional_args = []
        keyword_args = []
        args = func_def.args

        # ``defaults`` is right-aligned against positional-only + regular
        # parameters combined, so both lists must be indexed together.
        ordered_args = [*args.posonlyargs, *args.args]
        num_positional = len(ordered_args) - len(args.defaults)
        num_posonly = len(args.posonlyargs)

        for i, arg in enumerate(ordered_args):
            if i < num_positional or i < num_posonly:
                positional_args.append(arg.arg)
            else:
                keyword_args.append(arg.arg)

        # Keyword-only arguments (after *)
        keyword_args.extend(arg.arg for arg in args.kwonlyargs)

        logger.debug("Parsed positional args: %s", positional_args)
        logger.debug("Parsed keyword args: %s", keyword_args)
        return positional_args, keyword_args

    except SyntaxError as e:
        raise ValueError(
            f"Invalid Python syntax in kernel source at line {e.lineno}: {e.msg}"
        ) from e
    except Exception as e:
        raise ValueError(f"Failed to parse kernel signature: {e}") from e
510
+
511
+
512
def _generate_invocation_snippet(
    positional_args: list[str], keyword_args: list[str]
) -> str:
    """
    Build a one-line call of ``imported_kernel_function`` over ``grid``.

    Each positional name becomes ``args_dict["<name>"]``. Each keyword name
    becomes ``<name>=args_dict["<name>"]``. Positional arguments come first,
    and everything is joined with ", ".
    """
    call_args = [f'args_dict["{name}"]' for name in positional_args]
    call_args += [f'{name}=args_dict["{name}"]' for name in keyword_args]
    return f"imported_kernel_function[tuple(grid)]({', '.join(call_args)})"
531
+
532
+
533
def format_python_code(code: str) -> str:
    """
    Tidy generated Python source with isort and black, when they are installed.

    The two passes are independent and best-effort:
    - isort moves and organizes imports (``float_to_top``, black profile,
      88 columns). If isort is missing or fails, the code passes through
      unchanged to the next step.
    - black then formats the result. If black is missing, rejects the input,
      or fails in any other way, the code from the previous step is returned.

    Args:
        code: Python code string to format.

    Returns:
        The formatted code, or the best intermediate result available.
    """
    # --- Pass 1: organize imports (optional isort) ---
    try:
        import isort

        code = isort.code(
            code,
            profile="black",  # Compatible with black formatter
            line_length=88,
            float_to_top=True,  # Move imports to the top
            remove_redundant_aliases=True,  # e.g. 'import torch as torch'
            force_single_line=False,
            treat_comments_as_code=[],  # Don't treat comments as code barriers
            treat_all_comments_as_code=False,  # Allow imports to move past comments
        )
        logger.debug("Successfully organized imports")
    except ImportError:
        logger.debug("isort library not available, import organization will be skipped")
    except Exception as exc:
        logger.warning("Failed to organize imports: %s", exc)

    # --- Pass 2: format (optional black) ---
    try:
        import black

        result = black.format_str(code, mode=black.Mode())
        logger.debug("Successfully formatted generated code")
        return result
    except ImportError:
        logger.debug("black library not available, code formatting will be skipped")
        return code
    except black.InvalidInput as exc:
        logger.warning("Failed to format generated code: %s", exc)
        return code
    except Exception as exc:
        logger.warning("Unexpected error while formatting code: %s", exc)
        return code
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # We want to keep the structured logging module and the tritonparse module as separate as possible, so the shared variables live here.
3
+ import os
4
+
5
# Base filename prefix for trace files, without any user component.
DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER = "dedicated_log_triton_trace_"

# Per-user trace filename prefix. Including $USER (or "unknown" when unset)
# avoids permission clashes between users on shared servers.
DEFAULT_TRACE_FILE_PREFIX = (
    f"{DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER}{os.getenv('USER', 'unknown')}_"
)

# True if test outputs (e.g., temp dirs) should be preserved; enabled by
# TEST_KEEP_OUTPUT set to "1", "true" or "True".
TEST_KEEP_OUTPUT = os.getenv("TEST_KEEP_OUTPUT", "0") in ("1", "true", "True")
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from enum import Enum
4
+ from pathlib import Path
5
+ from typing import Tuple
6
+
7
+
8
class SourceType(str, Enum):
    """Kinds of log sources supported by the OSS build."""

    LOCAL = "local"  # a directory on the local filesystem
    LOCAL_FILE = "local_file"  # a single file on the local filesystem

    @classmethod
    def _missing_(cls, value: object) -> "SourceType":
        """
        Reject an unknown value with a ValueError that lists the valid ones.

        Args:
            value: The value that did not match any member.

        Raises:
            ValueError: Always.
        """
        allowed = ", ".join(member.value for member in cls)
        raise ValueError(f"Invalid source type '{value}'. Valid types are: {allowed}")
29
+
30
+
31
class Source:
    """A log source (a local directory or a single local file) to be parsed."""

    def __init__(self, source_str: str, verbose: bool = False):
        """
        Classify ``source_str`` and store its absolute path.

        Args:
            source_str: Path to a directory or file containing logs.
            verbose: Whether to print verbose information (stored only).

        Raises:
            ValueError: If ``source_str`` is neither an existing directory
                nor an existing file.
        """
        self.source_str = source_str
        self.verbose = verbose
        self.type, self.value = self._parse_source()

    def _parse_source(self) -> Tuple[SourceType, str]:
        """Return ``(kind, absolute_path)`` for ``self.source_str``."""
        candidate = Path(self.source_str)
        if candidate.is_dir():
            kind = SourceType.LOCAL
        elif candidate.is_file():
            kind = SourceType.LOCAL_FILE
        else:
            raise ValueError(
                f"Source '{self.source_str}' is not a valid directory or file"
            )
        return kind, str(candidate.absolute())