tritonparse 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -15,6 +15,7 @@ from collections.abc import Mapping
  from dataclasses import asdict, is_dataclass
  from datetime import date, datetime
  from enum import Enum
+ from functools import partial
  from pathlib import Path
  from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -756,7 +757,7 @@ def maybe_trace_triton(
 
      # Add timing information if available
      if times:
-         trace_data["times"] = times
+         trace_data["metadata"]["times"] = times
      # Log the collected information through the tracing system
      trace_structured_triton(
          event_type,
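
Note: launch timing now lives under the event's `metadata` key rather than at the top level. A minimal sketch of the access-path change for downstream readers (the event shape and values here are invented for illustration):

```python
# Hypothetical trace event shaped per the diff above; values are invented.
event = {"metadata": {"times": {"total_ns": 123456}}}

times = event["metadata"].get("times")  # 0.2.0: nested under metadata
# times = event.get("times")            # 0.1.1: top-level (removed)
```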
@@ -819,10 +820,18 @@ def extract_arg_info(arg_dict):
      return extracted_args
 
 
- def add_launch_metadata(grid, metadata, arg_dict):
+ def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
      # Extract detailed argument information
      extracted_args = extract_arg_info(arg_dict)
-     return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)}
+     extracted_inductor_args = extract_arg_info(inductor_args) if inductor_args else {}
+     return {
+         "launch_metadata_tritonparse": (
+             grid,
+             metadata._asdict(),
+             extracted_args,
+             extracted_inductor_args,
+         )
+     }
 
 
  class JITHookImpl(JITHook):
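
Reviewer note: the hook's payload grows from a 3-tuple to a 4-tuple, with Inductor-provided arguments extracted the same way as kernel arguments. A self-contained sketch of the new shape (the `Meta` namedtuple and the `extract_arg_info` body are simplified stand-ins, not package code):

```python
from collections import namedtuple

Meta = namedtuple("Meta", ["name", "num_warps"])  # stand-in for Triton's metadata

def extract_arg_info(arg_dict):
    # Simplified stand-in for tritonparse's real extractor.
    return {k: {"type": type(v).__name__, "value": v} for k, v in arg_dict.items()}

def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
    extracted_args = extract_arg_info(arg_dict)
    extracted_inductor_args = extract_arg_info(inductor_args) if inductor_args else {}
    return {
        "launch_metadata_tritonparse": (
            grid, metadata._asdict(), extracted_args, extracted_inductor_args
        )
    }

payload = add_launch_metadata((64,), Meta("add_kernel", 4), {"n": 1024}, {"XBLOCK": 128})
_, _, args, inductor_args = payload["launch_metadata_tritonparse"]
print(inductor_args)  # {'XBLOCK': {'type': 'int', 'value': 128}}
```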
@@ -848,6 +857,7 @@ class JITHookImpl(JITHook):
          compile,
          is_manual_warmup: bool,
          already_compiled: bool,
+         inductor_args: Optional[Dict[str, Any]] = None,
      ) -> Optional[bool]:
          """
          Override or set the launch_metadata function for the JIT-compiled kernel.
@@ -875,12 +885,16 @@ class JITHookImpl(JITHook):
              return True
 
          # Get the current launch_metadata function if it exists
-         current_launch_metadata = getattr(fn.jit_function, "launch_metadata", None)
+         function = getattr(fn, "jit_function", fn)
+
+         current_launch_metadata = getattr(function, "launch_metadata", None)
          if current_launch_metadata is not None:
              log.warning(
                  f"fn {fn} launch_metadata is not None: {current_launch_metadata}. It will be overridden by tritonparse."
              )
-         fn.jit_function.launch_metadata = add_launch_metadata
+         function.launch_metadata = partial(
+             add_launch_metadata, inductor_args=inductor_args
+         )
          return True
 
 
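The `functools.partial` wrapper pre-binds `inductor_args` so the stored hook keeps the `(grid, metadata, arg_dict)` call shape Triton already uses. A toy illustration of that binding (not package code):

```python
from functools import partial

def hook(grid, metadata, arg_dict, inductor_args=None):
    # Callers still pass only (grid, metadata, arg_dict);
    # inductor_args was frozen when the hook was installed.
    return {"grid": grid, "inductor_args": inductor_args}

bound = partial(hook, inductor_args={"XBLOCK": 128})
print(bound((64,), {}, {}))  # {'grid': (64,), 'inductor_args': {'XBLOCK': 128}}
```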
@@ -944,6 +958,7 @@ class LaunchHookImpl(LaunchHook):
              trace_data["extracted_args"] = launch_metadata_tritonparse[
                  2
              ]  # Now contains detailed arg info
+             trace_data["extracted_inductor_args"] = launch_metadata_tritonparse[3]
              trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
 
 
@@ -1011,6 +1026,15 @@ def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
      knobs.compilation.listener = maybe_trace_triton
 
 
+ def init_with_env():
+     """
+     Initialize TritonParse from the environment variables TRITON_TRACE_FOLDER and
+     TRITON_TRACE_LAUNCH. It is only supposed to be used in OSS Triton's source code.
+     """
+     if triton_trace_folder:
+         init(triton_trace_folder, enable_trace_launch=TRITON_TRACE_LAUNCH)
+
+
  def clear_logging_config():
      """
      Clear all configurations made by init() and init_basic().
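
With `init_with_env`, tracing can be enabled without touching user code, assuming `triton_trace_folder` and `TRITON_TRACE_LAUNCH` are module-level values parsed from the environment. A hedged sketch of the intended flow (paths and values are examples):

```python
import os

# Set these before Triton is imported, so the OSS Triton bootstrap (which is
# expected to call tritonparse's init_with_env()) can pick them up.
os.environ["TRITON_TRACE_FOLDER"] = "/tmp/tritonparse_logs"  # example path
os.environ["TRITON_TRACE_LAUNCH"] = "1"
```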
@@ -0,0 +1 @@
+ The tool scripts in this folder are used separately. They are not part of the main tritonparse functionality.
@@ -0,0 +1,139 @@
+ Metadata-Version: 2.4
+ Name: tritonparse
+ Version: 0.2.0
+ Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
+ Author-email: Yueming Hao <yhao@meta.com>
+ License-Expression: BSD-3-Clause
+ Project-URL: Homepage, https://github.com/meta-pytorch/tritonparse
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: triton>3.3.1
+ Provides-Extra: test
+ Requires-Dist: coverage>=7.0.0; extra == "test"
+ Dynamic: license-file
+
+ # TritonParse
+
+ [![License: BSD-3](https://img.shields.io/badge/License-BSD--3-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
+ [![GitHub Pages](https://img.shields.io/badge/GitHub%20Pages-Deploy-brightgreen)](https://meta-pytorch.org/tritonparse/)
+
+ **A comprehensive visualization and analysis tool for Triton IR files** — helping developers analyze, debug, and understand Triton kernel compilation processes.
+
+ 🌐 **[Try it online →](https://meta-pytorch.org/tritonparse/?json_url=https://meta-pytorch.org/tritonparse/dedicated_log_triton_trace_findhao__mapped.ndjson.gz)**
+
+ ## ✨ Key Features
+
+ - **🚀 Launch Difference Analysis** - Automatically detect and visualize variations in kernel launch parameters, helping you pinpoint performance bottlenecks and debug launch configurations.
+ - **🔍 Interactive Visualization** - Explore Triton kernels with detailed metadata and stack traces
+ - **📊 Multi-format IR Support** - View TTGIR, TTIR, LLIR, PTX, and AMDGCN in one place
+ - **🔄 Side-by-side Comparison** - Compare IR stages with synchronized highlighting
+ - **📝 Structured Logging** - Capture detailed compilation and launch events with source mapping
+ - **🌐 Ready-to-use Interface** - No installation required, works in your browser
+ - **🔒 Privacy-first** - All processing happens locally in your browser, no data uploaded
+
+ ## 🚀 Quick Start
+
+ ### 1. Generate Traces
+
+ ```python
+ import tritonparse.structured_logging
+
+ # Initialize logging with launch tracing enabled
+ tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)
+
+ # Your Triton/PyTorch code here
+ # ... your kernels ...
+
+ # Parse and generate trace files
+ import tritonparse.utils
+ tritonparse.utils.unified_parse("./logs/")
+ ```
+ The example terminal output is:
+ ```bash
+ tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
+ INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output
+
+ ================================================================================
+ 📁 TRITONPARSE PARSING RESULTS
+ ================================================================================
+ 📂 Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
+ 📊 Total files generated: 2
+
+ 📄 Generated files:
+ --------------------------------------------------
+    1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
+    2. 📝 log_file_list.json (181B)
+ ================================================================================
+ ✅ Parsing completed successfully!
+ ================================================================================
+ ```
+
+ ### 2. Visualize Results
+
+ **Visit [https://meta-pytorch.org/tritonparse/](https://meta-pytorch.org/tritonparse/?json_url=https://meta-pytorch.org/tritonparse/dedicated_log_triton_trace_findhao__mapped.ndjson.gz)** and open your local trace files (.ndjson.gz format).
+
+ > **🔒 Privacy Note**: Your trace files are processed entirely in your browser - nothing is uploaded to any server!
+
+ ## 🛠️ Installation
+
+ **For basic usage (trace generation):**
+ ```bash
+ git clone https://github.com/meta-pytorch/tritonparse.git
+ cd tritonparse
+ pip install -e .
+ ```
+
+ **Prerequisites:** Python ≥ 3.10, Triton ≥ 3.4.0, GPU required (NVIDIA/AMD)
+
+ TritonParse relies on new features in Triton. Please install the latest version of Triton:
+ ```bash
+ pip install triton
+ ```
+
+ ## 📚 Complete Documentation
+
+ | 📖 Guide | Description |
+ |----------|-------------|
+ | **[🏠 Wiki Home](https://github.com/meta-pytorch/tritonparse/wiki)** | Complete documentation and navigation |
+ | **[📦 Installation Guide](https://github.com/meta-pytorch/tritonparse/wiki/01.-Installation)** | Detailed setup for all scenarios |
+ | **[📋 Usage Guide](https://github.com/meta-pytorch/tritonparse/wiki/02.-Usage-Guide)** | Complete workflow and examples |
+ | **[🌐 Web Interface Guide](https://github.com/meta-pytorch/tritonparse/wiki/03.-Web-Interface-Guide)** | Master the visualization interface |
+ | **[🔧 Developer Guide](https://github.com/meta-pytorch/tritonparse/wiki/04.-Developer-Guide)** | Contributing and development setup |
+ | **[❓ FAQ](https://github.com/meta-pytorch/tritonparse/wiki/06.-FAQ)** | Frequently asked questions |
+
+ ## 🛠️ Tech Stack
+
+ - **Frontend**: React 19, TypeScript, Vite, Tailwind CSS, Monaco Editor
+ - **Backend**: Python with Triton integration, structured logging
+ - **Deployment**: GitHub Pages, automatic deployment
+
+ ## 📊 Understanding Triton Compilation
+
+ TritonParse visualizes the complete Triton compilation pipeline:
+
+ **Python Source** → **TTIR** → **TTGIR** → **LLIR** → **PTX/AMDGCN**
+
+ Each stage can be inspected and compared to understand optimization transformations.
+
+ ## 🤝 Contributing
+
+ We welcome contributions! Please see our **[Developer Guide](https://github.com/meta-pytorch/tritonparse/wiki/04.-Developer-Guide)** for:
+ - Development setup
+ - Code formatting standards
+ - Pull request process
+ - Architecture overview
+
+ ## 📞 Support & Community
+
+ - **🐛 Report Issues**: [GitHub Issues](https://github.com/meta-pytorch/tritonparse/issues)
+ - **💬 Discussions**: [GitHub Discussions](https://github.com/meta-pytorch/tritonparse/discussions)
+ - **📚 Documentation**: [TritonParse Wiki](https://github.com/meta-pytorch/tritonparse/wiki)
+
+ ## 📄 License
+
+ This project is licensed under the BSD-3 License - see the [LICENSE](LICENSE) file for details.
+
+ ---
+
+ **✨ Ready to get started?** Visit our **[Installation Guide](https://github.com/meta-pytorch/tritonparse/wiki/01.-Installation)** or try the **[online tool](https://meta-pytorch.org/tritonparse/)** directly!
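
The compilation pipeline the README describes (Python Source → TTIR → TTGIR → LLIR → PTX/AMDGCN) can be exercised with any small kernel; a minimal vector-add sketch in standard Triton (illustrative, not taken from the package):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 4096
x = torch.randn(n, device="cuda")
y = torch.randn(n, device="cuda")
out = torch.empty_like(x)
add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)
```

Running such a kernel with `tritonparse.structured_logging.init(...)` active is what produces the compilation and launch events the viewer consumes.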
@@ -0,0 +1,24 @@
+ tritonparse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tritonparse/common.py,sha256=VWKFZJG7msIExMH0QNYCO8SHKqUBFtdRuLaCy_K8DFI,13725
+ tritonparse/event_diff.py,sha256=yOD6uNxLJroatfx2nEGr-erw24ObOrHU9P6V5pzr8do,4907
+ tritonparse/extract_source_mappings.py,sha256=Z6UxFj2cCE5NCWLQTYPKqUpLfbYhqP8xgCl5mvud9KI,1451
+ tritonparse/ir_parser.py,sha256=1j1tP9jpUN7wH3e01bKUkUPgTMlNXUdp8LKRCC-WTro,9324
+ tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
+ tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
+ tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
+ tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
+ tritonparse/structured_logging.py,sha256=mrApsVigIHZln6De2ElwNTUSZfO61OHXgd08o2X26sM,40430
+ tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
+ tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
+ tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
+ tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
+ tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
+ tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
+ tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
+ tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
+ tritonparse-0.2.0.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
+ tritonparse-0.2.0.dist-info/METADATA,sha256=kHgvrwCWkFfyy2EvE1pRR8OyfP4OamUvaajLhJE7hyo,6151
+ tritonparse-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tritonparse-0.2.0.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
+ tritonparse-0.2.0.dist-info/RECORD,,
@@ -1,21 +0,0 @@
- """Reproducer subpackage: generate runnable Triton repro scripts from traces.
-
- Contains:
- - ingestion.ndjson: parse NDJSON and build a context bundle
- - orchestrator: LLM-based code generation with optional execute/repair
- - providers: LLM provider protocol and Gemini provider
- - prompts: simple prompt loader and templates
- - runtime.executor: helper to run generated Python scripts
- - param_generator: synthesize tensor/scalar allocations to reduce LLM burden
- """
-
- from .ingestion.ndjson import build_context_bundle
- from .orchestrator import generate_from_ndjson
- from .param_generator import generate_allocation_snippet, generate_kwargs_dict
-
- __all__ = [
-     "build_context_bundle",
-     "generate_from_ndjson",
-     "generate_allocation_snippet",
-     "generate_kwargs_dict",
- ]
@@ -1,81 +0,0 @@
- import argparse
- import sys
-
-
- def main() -> None:
-     p = argparse.ArgumentParser(
-         description=(
-             "Generate a runnable Triton repro script from a tritonparse NDJSON" " trace"
-         )
-     )
-     p.add_argument("--ndjson", required=True, help="Path to NDJSON trace file")
-     p.add_argument(
-         "--launch-index",
-         type=int,
-         default=0,
-         help="Launch index to reproduce",
-     )
-     p.add_argument("--out", default="repro.py", help="Output Python file path")
-     p.add_argument(
-         "--execute",
-         action="store_true",
-         help="Execute the generated script",
-     )
-     p.add_argument(
-         "--retries",
-         type=int,
-         default=0,
-         help="Auto-repair attempts if execution fails",
-     )
-     p.add_argument(
-         "--temperature",
-         type=float,
-         help="Override sampling temperature",
-     )
-     p.add_argument(
-         "--max-tokens",
-         type=int,
-         help="Override max tokens for generation",
-     )
-     args = p.parse_args()
-
-     # Lazy imports to allow `--help` without optional deps installed
-     from .config import load_config
-     from .orchestrator import generate_from_ndjson
-
-     try:
-         from .factory import make_gemini_provider
-     except Exception:  # pragma: no cover
-         print(
-             "Failed to import provider factory. Ensure optional deps are installed (e.g. google-genai).",
-             file=sys.stderr,
-         )
-         raise
-
-     cfg = load_config()
-     try:
-         provider = make_gemini_provider()
-     except ModuleNotFoundError:  # pragma: no cover
-         print(
-             "Gemini provider requires 'google-genai'. Install via: pip install google-genai",
-             file=sys.stderr,
-         )
-         sys.exit(2)
-     temperature = args.temperature if args.temperature is not None else cfg.temperature
-     max_tokens = args.max_tokens if args.max_tokens is not None else cfg.max_tokens
-
-     res = generate_from_ndjson(
-         args.ndjson,
-         provider,
-         launch_index=args.launch_index,
-         out_py=args.out,
-         execute=args.execute,
-         retries=args.retries,
-         temperature=temperature,
-         max_tokens=max_tokens,
-     )
-     print(res)
-
-
- if __name__ == "__main__":  # pragma: no cover
-     main()
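
Since this file sat at `tritonparse/reproducer/__main__.py` (see the 0.1.1 RECORD below), it would have been driven roughly as follows; the invocation is inferred from the file's location and the argparse flags above, not documented by the package:

```python
# Hypothetical 0.1.1-era invocation, inferred from the argparse definitions:
#   python -m tritonparse.reproducer --ndjson trace.ndjson --out repro.py \
#       --execute --retries 2 --temperature 0.1
```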
@@ -1,37 +0,0 @@
- import argparse
-
- from .config import load_config
- from .factory import make_gemini_provider
- from .orchestrator import generate_from_ndjson
-
-
- def add_reproducer_subparser(parser: argparse.ArgumentParser) -> None:
-     sub = parser.add_subparsers(dest="subcommand")
-     repro = sub.add_parser(
-         "repro",
-         help="Generate a runnable Triton repro script from NDJSON",
-     )
-     repro.add_argument("--ndjson", required=True)
-     repro.add_argument("--launch-index", type=int, default=0)
-     repro.add_argument("--out", default="repro.py")
-     repro.add_argument("--execute", action="store_true")
-     repro.add_argument("--retries", type=int, default=0)
-
-
- def maybe_handle_reproducer(args: argparse.Namespace) -> bool:
-     if getattr(args, "subcommand", None) != "repro":
-         return False
-     cfg = load_config()
-     provider = make_gemini_provider()
-     res = generate_from_ndjson(
-         args.ndjson,
-         provider,
-         launch_index=args.launch_index,
-         out_py=args.out,
-         execute=args.execute,
-         retries=args.retries,
-         temperature=cfg.temperature,
-         max_tokens=cfg.max_tokens,
-     )
-     print(res)
-     return True
1
- import os
2
- from dataclasses import dataclass
3
-
4
-
5
- @dataclass
6
- class ReproducerConfig:
7
- project: str = os.getenv("GOOGLE_CLOUD_PROJECT", "")
8
- location: str = os.getenv("GOOGLE_LOCATION", "us-central1")
9
- model: str = os.getenv("TP_REPRO_MODEL", "gemini-2.5-pro")
10
- temperature: float = float(os.getenv("TP_REPRO_TEMPERATURE", "0.1"))
11
- max_tokens: int = int(os.getenv("TP_REPRO_MAX_TOKENS", "10240"))
12
-
13
-
14
- def load_config() -> ReproducerConfig:
15
- return ReproducerConfig()
@@ -1,16 +0,0 @@
- """Provider factory for reproducer.
-
- Currently supports Gemini only.
- """
-
- from .config import load_config
- from .providers.gemini import GeminiProvider
-
-
- def make_gemini_provider() -> GeminiProvider:
-     cfg = load_config()
-     return GeminiProvider(
-         project=cfg.project,
-         location=cfg.location,
-         model=cfg.model,
-     )
@@ -1,6 +0,0 @@
- """Ingestion utilities for reproducer.
-
- Currently supports NDJSON trace parsing.
- """
-
- __all__ = []
@@ -1,165 +0,0 @@
- import json
- from typing import Any, Dict, List
-
-
- def _iter_events(path: str):
-     with open(path, "r", encoding="utf-8") as f:
-         for line in f:
-             line = line.strip()
-             if not line:
-                 continue
-             try:
-                 yield json.loads(line)
-             except json.JSONDecodeError:
-                 # skip malformed lines
-                 continue
-
-
- def _index_compilations(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
-     idx = {}
-     for e in events:
-         if e.get("event_type") != "compilation":
-             continue
-         payload = e.get("payload") or {}
-         meta = payload.get("metadata") or {}
-         h = meta.get("hash")
-         if h:
-             idx[h] = e
-     return idx
-
-
- def _get_launches(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-     return [e for e in events if e.get("event_type") == "launch"]
-
-
- def _resolve_kernel_source(
-     launch: Dict[str, Any], comp_idx: Dict[str, Dict[str, Any]]
- ) -> str:
-     # In new format, launch has top-level compilation_metadata, not payload.*
-     comp_meta = (
-         launch.get("compilation_metadata")
-         or launch.get("payload", {}).get("compilation_metadata")
-         or {}
-     )
-     h = comp_meta.get("hash")
-     if not h:
-         return ""
-     comp = comp_idx.get(h, {})
-     payload = comp.get("payload") or {}
-     py = payload.get("python_source") or {}
-     return py.get("code", "")
-
-
- def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
-     packed = {}
-     for k, v in args.items():
-         t = v.get("type") if isinstance(v, dict) else None
-         if t == "tensor":
-             packed[k] = {
-                 "type": "tensor",
-                 "shape": v.get("shape") if isinstance(v, dict) else None,
-                 "dtype": v.get("dtype") if isinstance(v, dict) else None,
-                 "device": v.get("device") if isinstance(v, dict) else None,
-                 "stride": v.get("stride") if isinstance(v, dict) else None,
-                 "is_contiguous": (
-                     v.get("is_contiguous") if isinstance(v, dict) else None
-                 ),
-                 "numel": v.get("numel") if isinstance(v, dict) else None,
-             }
-         else:
-             # scalar / NoneType etc
-             if isinstance(v, dict):
-                 packed[k] = {
-                     "type": v.get("type"),
-                     "value": v.get("value", v.get("repr")),
-                 }
-             else:
-                 packed[k] = {
-                     "type": None,
-                     "value": v,
-                 }
-     return packed
-
-
- # Sentinel and helper to normalize extracted argument values
- _SKIP = object()
-
-
- def _decode_arg(raw: Any):
-     if not isinstance(raw, dict):
-         return raw
-     t = raw.get("type")
-     if t == "tensor":
-         return _SKIP
-     if t == "NoneType":
-         return None
-     return raw.get("value", raw.get("repr"))
-
-
- def build_context_bundle(ndjson_path: str, launch_index: int = 0) -> Dict[str, Any]:
-     events = list(_iter_events(ndjson_path))
-     launches = _get_launches(events)
-     if not launches:
-         raise RuntimeError("No launch events found in NDJSON.")
-     if launch_index < 0 or launch_index >= len(launches):
-         raise IndexError(
-             f"launch_index out of range: {launch_index} (total {len(launches)})"
-         )
-     launch = launches[launch_index]
-     comp_idx = _index_compilations(events)
-     kernel_source = _resolve_kernel_source(launch, comp_idx)
-     # find '@triton.jit' and slice the string
-     jit_marker = "@triton.jit"
-     jit_pos = kernel_source.find(jit_marker)
-     if jit_pos != -1:
-         kernel_source = kernel_source[jit_pos:]
-
-     # flatten launch fields (support both formats)
-     grid = launch.get("grid") or (launch.get("payload", {})).get("grid")
-     comp_meta = (
-         launch.get("compilation_metadata")
-         or (launch.get("payload", {})).get("compilation_metadata")
-         or {}
-     )
-     extracted_args = (
-         launch.get("extracted_args")
-         or (launch.get("payload", {})).get("extracted_args")
-         or {}
-     )
-
-     # compile metadata subset we care about
-     compile_block = {
-         "num_warps": comp_meta.get("num_warps"),
-         "num_stages": comp_meta.get("num_stages"),
-         "arch": comp_meta.get("arch"),
-         "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
-         "triton_version": comp_meta.get("triton_version"),
-         "hash": comp_meta.get("hash"),
-     }
-
-     # kwargs: include constexpr + explicit scalars used for launch (skip tensor args)
-     kwargs = {}
-     for k, v in extracted_args.items():
-         val = _decode_arg(v)
-         if val is _SKIP:
-             continue
-         kwargs[k] = val
-
-     # tensor args: only tensors
-     tensor_args = {
-         k: v
-         for k, v in extracted_args.items()
-         if isinstance(v, dict) and v.get("type") == "tensor"
-     }
-
-     bundle = {
-         "kernel_source": kernel_source,
-         "compile": compile_block,
-         "launch": {
-             "grid": grid,
-             "kwargs": kwargs,
-         },
-         "args": _pack_args(extracted_args),
-         "tensor_args": _pack_args(tensor_args),
-     }
-     return bundle
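
For reference, the event shapes `build_context_bundle` consumed, written out as two illustrative NDJSON lines (field names follow the code above; the values are invented):

```python
import json

lines = [
    # compilation event, indexed by its metadata hash
    json.dumps({"event_type": "compilation",
                "payload": {"metadata": {"hash": "abc123"},
                            "python_source": {"code": "@triton.jit\ndef k(): ..."}}}),
    # launch event referencing the compilation via compilation_metadata.hash
    json.dumps({"event_type": "launch",
                "compilation_metadata": {"hash": "abc123"},
                "grid": [64, 1, 1],
                "extracted_args": {"n": {"type": "int", "value": 1024}}}),
]
events = [json.loads(line) for line in lines]
assert events[1]["compilation_metadata"]["hash"] == "abc123"
```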
@@ -1,65 +0,0 @@
- from pathlib import Path
- from typing import Any, Dict
-
- from .ingestion.ndjson import build_context_bundle
- from .param_generator import generate_allocation_snippet, generate_kwargs_dict
- from .prompts.loader import render_prompt
- from .providers.base import LLMProvider
- from .runtime.executor import run_python
-
-
- def _excerpt(s: str, n: int = 160):
-     lines = s.splitlines()
-     return "\n".join(lines[:n])
-
-
- def generate_from_ndjson(
-     ndjson_path: str,
-     provider: LLMProvider,
-     *,
-     launch_index=0,
-     out_py="repro.py",
-     execute=False,
-     retries: int = 0,
-     **gen_kwargs,
- ) -> Dict[str, Any]:
-     bundle = build_context_bundle(ndjson_path, launch_index=launch_index)
-     # Augment bundle with pre-generated parameter allocation code to reduce LLM burden
-     allocation_snippet = generate_allocation_snippet(bundle)
-     kwargs_dict = generate_kwargs_dict(bundle)
-     context = {
-         **bundle,
-         "allocation_snippet": allocation_snippet,
-         "kwargs_dict": kwargs_dict,
-     }
-     system_prompt = render_prompt("system.txt", context)
-     user_prompt = render_prompt("generate_one_shot.txt", context)
-
-     code = provider.generate_code(system_prompt, user_prompt, **gen_kwargs)
-     Path(out_py).write_text(code, encoding="utf-8")
-
-     if not execute:
-         return {"path": out_py}
-
-     # Execute and optionally repair
-     rc, out, err = run_python(out_py)
-     attempt = 0
-     while rc != 0 and attempt < retries:
-         attempt += 1
-         # Build repair prompt
-         repair_ctx = {
-             "prev_code_excerpt": _excerpt(code, 200),
-             "error_text": err[-4000:] if err else "(no stderr)",
-         }
-         repair_prompt = render_prompt("repair_loop.txt", repair_ctx)
-         code = provider.generate_code(system_prompt, repair_prompt, **gen_kwargs)
-         Path(out_py).write_text(code, encoding="utf-8")
-         rc, out, err = run_python(out_py)
-
-     return {
-         "path": out_py,
-         "returncode": rc,
-         "stdout": out,
-         "stderr": err,
-         "retries_used": attempt,
-     }
@@ -1,142 +0,0 @@
- """Parameter generator: produce deterministic allocation code from a bundle.
-
- This module reduces LLM burden by emitting Python code that:
- - selects a device
- - seeds RNG
- - allocates tensors with the exact shape/dtype/device/stride
- - prepares scalar/constexpr kwargs
-
- The generated code is intended to be inserted into the final repro script.
- """
-
- import json
- from typing import Any, Dict, List, Optional
-
-
- def _torch_dtype_expr(dtype: str) -> str:
-     mapping = {
-         "float16": "torch.float16",
-         "bfloat16": "torch.bfloat16",
-         "float32": "torch.float32",
-         "float": "torch.float32",
-         "float64": "torch.float64",
-         "half": "torch.float16",
-         "bf16": "torch.bfloat16",
-         "fp16": "torch.float16",
-         "fp32": "torch.float32",
-         "fp64": "torch.float64",
-         "int8": "torch.int8",
-         "int16": "torch.int16",
-         "int32": "torch.int32",
-         "int64": "torch.int64",
-         "long": "torch.int64",
-         "bool": "torch.bool",
-     }
-     return mapping.get(str(dtype).lower(), "torch.float32")
-
-
- def _compute_storage_numel(shape: List[int], stride: Optional[List[int]]) -> int:
-     if not shape:
-         return 1
-     if not stride:
-         # contiguous default
-         numel = 1
-         for s in shape:
-             numel *= int(s)
-         return numel
-     # minimal storage size (in elements) to support the given logical shape/stride
-     max_index = 0
-     for dim, (sz, st) in enumerate(zip(shape, stride)):
-         if sz <= 0:
-             continue
-         max_index = max(max_index, (int(sz) - 1) * int(st))
-     return int(max_index) + 1
-
-
- def _emit_tensor_alloc(name: str, spec: Dict[str, Any]) -> str:
-     shape = spec.get("shape") or []
-     dtype = _torch_dtype_expr(spec.get("dtype"))
-     device = spec.get("device") or "cuda:0"
-     stride = spec.get("stride")
-
-     # ensure ints
-     shape = [int(s) for s in shape]
-     if stride is not None:
-         stride_list = [int(x) for x in stride]
-     else:
-         stride_list = None
-
-     lines: List[str] = []
-     # allocate backing storage
-     storage_numel = _compute_storage_numel(shape, stride_list)
-     lines.append(
-         f"# {name}: shape={shape}, dtype={dtype}, device={device}, stride={stride_list}"
-     )
-     lines.append(
-         f"_storage_{name} = torch.empty(({storage_numel},), dtype={dtype}, device=device)"
-     )
-     if stride_list:
-         # Create an as_strided view over the 1D storage
-         sizes_expr = str(tuple(shape))
-         strides_expr = str(tuple(stride_list))
-         lines.append(
-             f"{name} = _storage_{name}.as_strided(size={sizes_expr}, stride={strides_expr})"
-         )
-     else:
-         # contiguous allocation
-         size_expr = str(tuple(shape))
-         lines.append(f"{name} = torch.empty({size_expr}, dtype={dtype}, device=device)")
-     return "\n".join(lines)
-
-
- def _emit_scalar(name: str, spec: Dict[str, Any]) -> str:
-     value = spec.get("value")
-     # Preserve JSON-serializable value as-is
-     return f"{name} = {json.dumps(value)}"
-
-
- def generate_allocation_snippet(bundle: Dict[str, Any]) -> str:
-     """Generate a self-contained code snippet that:
-     - imports torch
-     - sets device
-     - seeds RNG
-     - allocates tensors and defines scalars for all args
-     Returns Python source as a string.
-     """
-     tensor_args: Dict[str, Any] = bundle.get("tensor_args", {}) or {}
-     args_all: Dict[str, Any] = bundle.get("args", {}) or {}
-
-     # Pick device from any tensor arg, fallback to cuda:0
-     device = "cuda:0"
-     for spec in tensor_args.values():
-         dev = spec.get("device")
-         if dev:
-             device = str(dev)
-             break
-
-     lines: List[str] = []
-     lines.append("import torch")
-     lines.append(f"device = '{device}'")
-     lines.append("torch.manual_seed(0)")
-     lines.append("if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)")
-     lines.append("")
-
-     # Emit tensors first for names with type==tensor in args_all
-     for name, spec in args_all.items():
-         if isinstance(spec, dict) and spec.get("type") == "tensor":
-             lines.append(_emit_tensor_alloc(name, spec))
-             lines.append("")
-
-     # Emit non-tensor scalars next
-     for name, spec in args_all.items():
-         if not isinstance(spec, dict) or spec.get("type") == "tensor":
-             continue
-         lines.append(_emit_scalar(name, spec))
-     return "\n".join(lines)
-
-
- def generate_kwargs_dict(bundle: Dict[str, Any]) -> Dict[str, Any]:
-     """Return a kwargs dict derived from bundle['launch']['kwargs'] suitable for kernel call."""
-     launch = bundle.get("launch", {}) or {}
-     kwargs = launch.get("kwargs", {}) or {}
-     return kwargs
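
The minimal-storage rule in `_compute_storage_numel` is worth a worked example: for shape (4, 6) with stride (8, 1), the largest reachable flat index is (4-1)*8 + (6-1)*1 = 29, so 30 elements of backing storage suffice. Verified directly with PyTorch (illustrative values):

```python
import torch

storage = torch.empty(30, dtype=torch.float32)       # (4-1)*8 + (6-1)*1 + 1 elements
t = storage.as_strided(size=(4, 6), stride=(8, 1))   # non-contiguous view over storage
assert t.shape == (4, 6) and t.stride() == (8, 1)
assert not t.is_contiguous()                         # contiguous stride would be (6, 1)
```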
@@ -1 +0,0 @@
- __all__ = []
@@ -1,18 +0,0 @@
- import json
- from pathlib import Path
- from typing import Any, Dict
-
- PROMPTS_DIR = Path(__file__).parent
-
-
- def render_prompt(name: str, context: Dict[str, Any]) -> str:
-     text = (PROMPTS_DIR / name).read_text(encoding="utf-8")
-     # very simple {{key}} replacement for top-level keys; JSON for dicts
-     for k, v in context.items():
-         token = "{{ " + k + " }}"
-         if token in text:
-             if isinstance(v, (dict, list)):
-                 text = text.replace(token, json.dumps(v, ensure_ascii=False, indent=2))
-             else:
-                 text = text.replace(token, str(v))
-     return text
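
A small usage illustration of the `{{ key }}` substitution rule (template inlined here; the real loader read templates from files beside `loader.py`):

```python
import json

template = "Compile options:\n{{ compile }}\nKernel follows."
context = {"compile": {"num_warps": 4, "num_stages": 3}}

text = template
for k, v in context.items():
    token = "{{ " + k + " }}"
    if token in text:
        # Dicts and lists are pretty-printed as JSON; everything else via str().
        text = text.replace(
            token,
            json.dumps(v, ensure_ascii=False, indent=2)
            if isinstance(v, (dict, list))
            else str(v),
        )
print(text)  # the dict appears as indented JSON in place of the token
```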
@@ -1 +0,0 @@
- __all__ = []
@@ -1,14 +0,0 @@
- from typing import Any, Dict, List, Optional, Protocol
-
-
- class LLMProvider(Protocol):
-     def generate_code(
-         self,
-         system_prompt: str,
-         user_prompt: str,
-         *,
-         temperature: float = 0.2,
-         max_tokens: int = 8192,
-         stop: Optional[List[str]] = None,
-         extra: Optional[Dict[str, Any]] = None,
-     ) -> str: ...
@@ -1,47 +0,0 @@
- import os
- import re
- from typing import Any, Dict, List, Optional
-
- from google.genai import Client
-
-
- def _extract_python_block(s: str) -> str:
-     m = re.search(r"""```python\s+(.*?)```""", s, flags=re.S)
-     return m.group(1).strip() if m else ""
-
-
- class GeminiProvider:
-     def __init__(
-         self, project: str, location: str = "us-central1", model: str = "gemini-2.5-pro"
-     ):
-         # Expect GOOGLE_APPLICATION_CREDENTIALS to be set
-         if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
-             raise EnvironmentError("GOOGLE_APPLICATION_CREDENTIALS not set.")
-         self.client = Client(vertexai=True, project=project, location=location)
-         self.model = model
-
-     def generate_code(
-         self,
-         system_prompt: str,
-         user_prompt: str,
-         *,
-         temperature: float = 0.2,
-         max_tokens: int = 8192,
-         stop: Optional[List[str]] = None,
-         extra: Optional[Dict[str, Any]] = None,
-     ) -> str:
-         # Gemini doesn't have a 'system' role in this SDK; prepend system to user
-         full_prompt = f"{system_prompt.strip()}\n\n---\n\n{user_prompt.strip()}"
-         resp = self.client.models.generate_content(
-             model=self.model,
-             contents=full_prompt,
-             config={
-                 "temperature": temperature,
-                 "max_output_tokens": max_tokens,
-             },
-         )
-         text = getattr(resp, "text", "") or ""
-         code = _extract_python_block(text) or text
-         if not code.strip():
-             raise RuntimeError(f"Empty response from Gemini. Raw: {text[:2000]}")
-         return code
@@ -1 +0,0 @@
- __all__ = []
@@ -1,13 +0,0 @@
- import subprocess
- import sys
-
-
- def run_python(path: str, timeout: int = 60):
-     p = subprocess.Popen(
-         [sys.executable, path],
-         stdout=subprocess.PIPE,
-         stderr=subprocess.PIPE,
-         text=True,
-     )
-     out, err = p.communicate(timeout=timeout)
-     return p.returncode, out, err
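
One caveat in the removed helper: `Popen.communicate(timeout=...)` raises `TimeoutExpired` without killing the child process. A hedged variant using `subprocess.run`, which kills and reaps the child on timeout before re-raising:

```python
import subprocess
import sys

def run_python(path: str, timeout: int = 60):
    # subprocess.run kills and waits for the child if the timeout expires,
    # then re-raises TimeoutExpired; Popen.communicate alone does not kill it.
    p = subprocess.run(
        [sys.executable, path],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return p.returncode, p.stdout, p.stderr
```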
@@ -1,6 +0,0 @@
- from pathlib import Path
-
-
- def write_text(path: str, content: str, *, encoding="utf-8"):
-     Path(path).parent.mkdir(parents=True, exist_ok=True)
-     Path(path).write_text(content, encoding=encoding)
@@ -1,10 +0,0 @@
- Metadata-Version: 2.4
- Name: tritonparse
- Version: 0.1.1
- Project-URL: Homepage, https://github.com/meta-pytorch/tritonparse
- Requires-Python: >=3.10
- License-File: LICENSE
- Requires-Dist: triton>3.3.1
- Provides-Extra: test
- Requires-Dist: coverage>=7.0.0; extra == "test"
- Dynamic: license-file
@@ -1,40 +0,0 @@
- tritonparse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tritonparse/common.py,sha256=VWKFZJG7msIExMH0QNYCO8SHKqUBFtdRuLaCy_K8DFI,13725
- tritonparse/event_diff.py,sha256=yOD6uNxLJroatfx2nEGr-erw24ObOrHU9P6V5pzr8do,4907
- tritonparse/extract_source_mappings.py,sha256=Z6UxFj2cCE5NCWLQTYPKqUpLfbYhqP8xgCl5mvud9KI,1451
- tritonparse/ir_parser.py,sha256=1j1tP9jpUN7wH3e01bKUkUPgTMlNXUdp8LKRCC-WTro,9324
- tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
- tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
- tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
- tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
- tritonparse/structured_logging.py,sha256=qWGkAr2oZT8Adxj-EfQqQdX2l-jq9xKi7WhBecYq2bg,39600
- tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
- tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
- tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
- tritonparse/reproducer/__init__.py,sha256=VcCpYVUUmclWotkQmPLlDu5iFOUE4N-4FzcbzXwIow0,773
- tritonparse/reproducer/__main__.py,sha256=ydLMbWx7SFlvAb1erObvKcJ-uxhHNShGRHRrZO6-5ww,2266
- tritonparse/reproducer/cli.py,sha256=nnMgdT4tQzBWjPowOy2a_5QRsnsTMEAA3uegYpLEyRE,1165
- tritonparse/reproducer/config.py,sha256=-hmE5ZqEtYo2WKjXbwMi6k6XzzfQZAbL50UURPvcF3A,478
- tritonparse/reproducer/factory.py,sha256=sFcIjIayfHAqPqMVT8Rnsz9tpMmQXBzoOlKprS1P_1g,341
- tritonparse/reproducer/orchestrator.py,sha256=iD7zZZHE4FU3nNOwNV9SUY2WUcpv_Amg0SvnRxrseEQ,2045
- tritonparse/reproducer/param_generator.py,sha256=m-C_Z1TLd1ZX49EpsWELVfB6tkwOfi-ZHma7wXwz2g4,4654
- tritonparse/reproducer/ingestion/__init__.py,sha256=2AQHxWlUl5JXM4a8F033wzxVnjCVPBEf-4H99kep-OA,99
- tritonparse/reproducer/ingestion/ndjson.py,sha256=_E_dXXjxu438OYomQ1zFFk3jV9Wr1jNoXHiP2gJG7_4,5172
- tritonparse/reproducer/prompts/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
- tritonparse/reproducer/prompts/loader.py,sha256=n6Of98eEXNz9mI7ZH073X5FihNZD7tI-ehfjN_4yEl0,610
- tritonparse/reproducer/providers/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
- tritonparse/reproducer/providers/base.py,sha256=DgP_4AdrEf48kstOfBJFvK3pndcHH0vRUGjp6k1bdsY,362
- tritonparse/reproducer/providers/gemini.py,sha256=VlOCdTGRTQdr3c2HMclKFIk-133puGSAjhK_6m6Zj9g,1609
- tritonparse/reproducer/runtime/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
- tritonparse/reproducer/runtime/executor.py,sha256=AqBFnoEqURoMGDdLC2G3WpHIP3Y4wWGJHEZrjS-NQFM,304
- tritonparse/reproducer/utils/io.py,sha256=95NF9QCGawl-5p5c5yCQHynVBNKS_B_7nIrqnRvAt-E,200
- tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
- tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
- tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
- tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
- tritonparse-0.1.1.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
- tritonparse-0.1.1.dist-info/METADATA,sha256=qETEInGJRT7fzf-Rl8cAf6QnEa5eJgiYo4rbwBA63yc,287
- tritonparse-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- tritonparse-0.1.1.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
- tritonparse-0.1.1.dist-info/RECORD,,