tritonparse 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Potentially problematic release: this version of tritonparse might be problematic.
- tritonparse/structured_logging.py +29 -5
- tritonparse/tools/readme.md +1 -0
- tritonparse-0.2.0.dist-info/METADATA +139 -0
- tritonparse-0.2.0.dist-info/RECORD +24 -0
- tritonparse/reproducer/__init__.py +0 -21
- tritonparse/reproducer/__main__.py +0 -81
- tritonparse/reproducer/cli.py +0 -37
- tritonparse/reproducer/config.py +0 -15
- tritonparse/reproducer/factory.py +0 -16
- tritonparse/reproducer/ingestion/__init__.py +0 -6
- tritonparse/reproducer/ingestion/ndjson.py +0 -165
- tritonparse/reproducer/orchestrator.py +0 -65
- tritonparse/reproducer/param_generator.py +0 -142
- tritonparse/reproducer/prompts/__init__.py +0 -1
- tritonparse/reproducer/prompts/loader.py +0 -18
- tritonparse/reproducer/providers/__init__.py +0 -1
- tritonparse/reproducer/providers/base.py +0 -14
- tritonparse/reproducer/providers/gemini.py +0 -47
- tritonparse/reproducer/runtime/__init__.py +0 -1
- tritonparse/reproducer/runtime/executor.py +0 -13
- tritonparse/reproducer/utils/io.py +0 -6
- tritonparse-0.1.1.dist-info/METADATA +0 -10
- tritonparse-0.1.1.dist-info/RECORD +0 -40
- {tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/WHEEL +0 -0
- {tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/top_level.txt +0 -0
tritonparse/structured_logging.py CHANGED
@@ -15,6 +15,7 @@ from collections.abc import Mapping
 from dataclasses import asdict, is_dataclass
 from datetime import date, datetime
 from enum import Enum
+from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -756,7 +757,7 @@ def maybe_trace_triton(
 
     # Add timing information if available
     if times:
-        trace_data["times"] = times
+        trace_data["metadata"]["times"] = times
     # Log the collected information through the tracing system
     trace_structured_triton(
         event_type,
@@ -819,10 +820,18 @@ def extract_arg_info(arg_dict):
     return extracted_args
 
 
-def add_launch_metadata(grid, metadata, arg_dict):
+def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
     # Extract detailed argument information
     extracted_args = extract_arg_info(arg_dict)
-    return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)}
+    extracted_inductor_args = extract_arg_info(inductor_args) if inductor_args else {}
+    return {
+        "launch_metadata_tritonparse": (
+            grid,
+            metadata._asdict(),
+            extracted_args,
+            extracted_inductor_args,
+        )
+    }
 
 
 class JITHookImpl(JITHook):
@@ -848,6 +857,7 @@ class JITHookImpl(JITHook):
         compile,
         is_manual_warmup: bool,
         already_compiled: bool,
+        inductor_args: Optional[Dict[str, Any]] = None,
     ) -> Optional[bool]:
         """
        Override or set the launch_metadata function for the JIT-compiled kernel.
@@ -875,12 +885,16 @@ class JITHookImpl(JITHook):
             return True
 
         # Get the current launch_metadata function if it exists
-        current_launch_metadata = getattr(fn, "launch_metadata", None)
+        function = getattr(fn, "jit_function", fn)
+
+        current_launch_metadata = getattr(function, "launch_metadata", None)
         if current_launch_metadata is not None:
             log.warning(
                 f"fn {fn} launch_metadata is not None: {current_launch_metadata}. It will be overridden by tritonparse."
             )
-        fn.launch_metadata = add_launch_metadata
+        function.launch_metadata = partial(
+            add_launch_metadata, inductor_args=inductor_args
+        )
         return True
 
 
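The `functools.partial` wrapper pre-binds `inductor_args` while preserving the positional `(grid, metadata, arg_dict)` signature that Triton invokes the hook with. A minimal standalone sketch of the mechanism (the hook body is simplified and the bound value is hypothetical):

```python
from functools import partial

def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
    # Simplified body: the real hook also runs extract_arg_info and metadata._asdict()
    return {"launch_metadata_tritonparse": (grid, metadata, arg_dict, inductor_args)}

# The extra keyword is baked in ahead of time; Triton still calls the hook
# as hook(grid, metadata, args), unaware of the additional parameter.
hook = partial(add_launch_metadata, inductor_args={"xnumel": 1024})
print(hook((8, 1, 1), {"name": "k"}, {"x_ptr": 0}))
```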
@@ -944,6 +958,7 @@ class LaunchHookImpl(LaunchHook):
             trace_data["extracted_args"] = launch_metadata_tritonparse[
                 2
             ]  # Now contains detailed arg info
+            trace_data["extracted_inductor_args"] = launch_metadata_tritonparse[3]
             trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
 
 
@@ -1011,6 +1026,15 @@ def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
     knobs.compilation.listener = maybe_trace_triton
 
 
+def init_with_env():
+    """
+    This function is used to initialize TritonParse with the environment variable TRITON_TRACE_FOLDER and TRITON_TRACE_LAUNCH specifically.
+    It is only supposed to be used in OSS triton's source code.
+    """
+    if triton_trace_folder:
+        init(triton_trace_folder, enable_trace_launch=TRITON_TRACE_LAUNCH)
+
+
 def clear_logging_config():
     """
     Clear all configurations made by init() and init_basic().
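A usage sketch for the new `init_with_env` entry point, assuming the module reads `TRITON_TRACE_FOLDER` and `TRITON_TRACE_LAUNCH` into module-level globals at import time (as the globals referenced in the diff suggest):

```python
import os

# Assumption: the variables must be set before the import, since the globals
# that init_with_env consults are populated when the module is loaded.
os.environ["TRITON_TRACE_FOLDER"] = "/tmp/triton_traces"
os.environ["TRITON_TRACE_LAUNCH"] = "1"

import tritonparse.structured_logging as tsl

# Intended to be called from OSS Triton's own source; per the diff it is a
# no-op when TRITON_TRACE_FOLDER is unset.
tsl.init_with_env()
```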
tritonparse/tools/readme.md ADDED
@@ -0,0 +1 @@
+The tool scripts in this folder are used separately. They are not part of the main tritonparse functionality.
tritonparse-0.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,139 @@
+Metadata-Version: 2.4
+Name: tritonparse
+Version: 0.2.0
+Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
+Author-email: Yueming Hao <yhao@meta.com>
+License-Expression: BSD-3-Clause
+Project-URL: Homepage, https://github.com/meta-pytorch/tritonparse
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: triton>3.3.1
+Provides-Extra: test
+Requires-Dist: coverage>=7.0.0; extra == "test"
+Dynamic: license-file
+
+# TritonParse
+
+[](https://opensource.org/licenses/BSD-3-Clause)
+[](https://meta-pytorch.org/tritonparse/)
+
+**A comprehensive visualization and analysis tool for Triton IR files** — helping developers analyze, debug, and understand Triton kernel compilation processes.
+
+🌐 **[Try it online →](https://meta-pytorch.org/tritonparse/?json_url=https://meta-pytorch.org/tritonparse/dedicated_log_triton_trace_findhao__mapped.ndjson.gz)**
+
+## ✨ Key Features
+
+- **🚀 Launch Difference Analysis** - Automatically detect and visualize variations in kernel launch parameters, helping you pinpoint performance bottlenecks and debug launch configurations.
+- **🔍 Interactive Visualization** - Explore Triton kernels with detailed metadata and stack traces
+- **📊 Multi-format IR Support** - View TTGIR, TTIR, LLIR, PTX, and AMDGCN in one place
+- **🔄 Side-by-side Comparison** - Compare IR stages with synchronized highlighting
+- **📝 Structured Logging** - Capture detailed compilation and launch events with source mapping
+- **🌐 Ready-to-use Interface** - No installation required, works in your browser
+- **🔒 Privacy-first** - All processing happens locally in your browser, no data uploaded
+
+## 🚀 Quick Start
+
+### 1. Generate Traces
+
+```python
+import tritonparse.structured_logging
+
+# Initialize logging with launch tracing enabled
+tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)
+
+# Your Triton/PyTorch code here
+# ... your kernels ...
+
+# Parse and generate trace files
+import tritonparse.utils
+tritonparse.utils.unified_parse("./logs/")
+```
+The example terminal output is:
+```bash
+tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
+INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output
+
+================================================================================
+📁 TRITONPARSE PARSING RESULTS
+================================================================================
+📂 Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
+📊 Total files generated: 2
+
+📄 Generated files:
+--------------------------------------------------
+1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
+2. 📝 log_file_list.json (181B)
+================================================================================
+✅ Parsing completed successfully!
+================================================================================
+```
+
+### 2. Visualize Results
+
+**Visit [https://meta-pytorch.org/tritonparse/](https://meta-pytorch.org/tritonparse/?json_url=https://meta-pytorch.org/tritonparse/dedicated_log_triton_trace_findhao__mapped.ndjson.gz)** and open your local trace files (.ndjson.gz format).
+
+> **🔒 Privacy Note**: Your trace files are processed entirely in your browser - nothing is uploaded to any server!
+
+## 🛠️ Installation
+
+**For basic usage (trace generation):**
+```bash
+git clone https://github.com/meta-pytorch/tritonparse.git
+cd tritonparse
+pip install -e .
+```
+
+**Prerequisites:** Python ≥ 3.10, Triton ≥ 3.4.0, GPU required (NVIDIA/AMD)
+
+TritonParse relies on new features in Triton. Please install the latest version of Triton:
+```bash
+pip install triton
+```
+
+## 📚 Complete Documentation
+
+| 📖 Guide | Description |
+|----------|-------------|
+| **[🏠 Wiki Home](https://github.com/meta-pytorch/tritonparse/wiki)** | Complete documentation and navigation |
+| **[📦 Installation Guide](https://github.com/meta-pytorch/tritonparse/wiki/01.-Installation)** | Detailed setup for all scenarios |
+| **[📋 Usage Guide](https://github.com/meta-pytorch/tritonparse/wiki/02.-Usage-Guide)** | Complete workflow and examples |
+| **[🌐 Web Interface Guide](https://github.com/meta-pytorch/tritonparse/wiki/03.-Web-Interface-Guide)** | Master the visualization interface |
+| **[🔧 Developer Guide](https://github.com/meta-pytorch/tritonparse/wiki/04.-Developer-Guide)** | Contributing and development setup |
+| **[❓ FAQ](https://github.com/meta-pytorch/tritonparse/wiki/06.-FAQ)** | Frequently asked questions |
+
+## 🛠️ Tech Stack
+
+- **Frontend**: React 19, TypeScript, Vite, Tailwind CSS, Monaco Editor
+- **Backend**: Python with Triton integration, structured logging
+- **Deployment**: GitHub Pages, automatic deployment
+
+## 📊 Understanding Triton Compilation
+
+TritonParse visualizes the complete Triton compilation pipeline:
+
+**Python Source** → **TTIR** → **TTGIR** → **LLIR** → **PTX/AMDGCN**
+
+Each stage can be inspected and compared to understand optimization transformations.
+
+## 🤝 Contributing
+
+We welcome contributions! Please see our **[Developer Guide](https://github.com/meta-pytorch/tritonparse/wiki/04.-Developer-Guide)** for:
+- Development setup
+- Code formatting standards
+- Pull request process
+- Architecture overview
+
+## 📞 Support & Community
+
+- **🐛 Report Issues**: [GitHub Issues](https://github.com/meta-pytorch/tritonparse/issues)
+- **💬 Discussions**: [GitHub Discussions](https://github.com/meta-pytorch/tritonparse/discussions)
+- **📚 Documentation**: [TritonParse Wiki](https://github.com/meta-pytorch/tritonparse/wiki)
+
+## 📄 License
+
+This project is licensed under the BSD-3 License - see the [LICENSE](LICENSE) file for details.
+
+---
+
+**✨ Ready to get started?** Visit our **[Installation Guide](https://github.com/meta-pytorch/tritonparse/wiki/01.-Installation)** or try the **[online tool](https://meta-pytorch.org/tritonparse/)** directly!
tritonparse-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
+tritonparse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tritonparse/common.py,sha256=VWKFZJG7msIExMH0QNYCO8SHKqUBFtdRuLaCy_K8DFI,13725
+tritonparse/event_diff.py,sha256=yOD6uNxLJroatfx2nEGr-erw24ObOrHU9P6V5pzr8do,4907
+tritonparse/extract_source_mappings.py,sha256=Z6UxFj2cCE5NCWLQTYPKqUpLfbYhqP8xgCl5mvud9KI,1451
+tritonparse/ir_parser.py,sha256=1j1tP9jpUN7wH3e01bKUkUPgTMlNXUdp8LKRCC-WTro,9324
+tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
+tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
+tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
+tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
+tritonparse/structured_logging.py,sha256=mrApsVigIHZln6De2ElwNTUSZfO61OHXgd08o2X26sM,40430
+tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
+tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
+tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
+tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
+tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
+tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
+tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
+tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
+tritonparse-0.2.0.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
+tritonparse-0.2.0.dist-info/METADATA,sha256=kHgvrwCWkFfyy2EvE1pRR8OyfP4OamUvaajLhJE7hyo,6151
+tritonparse-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tritonparse-0.2.0.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
+tritonparse-0.2.0.dist-info/RECORD,,
tritonparse/reproducer/__init__.py DELETED
@@ -1,21 +0,0 @@
-"""Reproducer subpackage: generate runnable Triton repro scripts from traces.
-
-Contains:
-- ingestion.ndjson: parse NDJSON and build a context bundle
-- orchestrator: LLM-based code generation with optional execute/repair
-- providers: LLM provider protocol and Gemini provider
-- prompts: simple prompt loader and templates
-- runtime.executor: helper to run generated Python scripts
-- param_generator: synthesize tensor/scalar allocations to reduce LLM burden
-"""
-
-from .ingestion.ndjson import build_context_bundle
-from .orchestrator import generate_from_ndjson
-from .param_generator import generate_allocation_snippet, generate_kwargs_dict
-
-__all__ = [
-    "build_context_bundle",
-    "generate_from_ndjson",
-    "generate_allocation_snippet",
-    "generate_kwargs_dict",
-]
tritonparse/reproducer/__main__.py DELETED
@@ -1,81 +0,0 @@
-import argparse
-import sys
-
-
-def main() -> None:
-    p = argparse.ArgumentParser(
-        description=(
-            "Generate a runnable Triton repro script from a tritonparse NDJSON" " trace"
-        )
-    )
-    p.add_argument("--ndjson", required=True, help="Path to NDJSON trace file")
-    p.add_argument(
-        "--launch-index",
-        type=int,
-        default=0,
-        help="Launch index to reproduce",
-    )
-    p.add_argument("--out", default="repro.py", help="Output Python file path")
-    p.add_argument(
-        "--execute",
-        action="store_true",
-        help="Execute the generated script",
-    )
-    p.add_argument(
-        "--retries",
-        type=int,
-        default=0,
-        help="Auto-repair attempts if execution fails",
-    )
-    p.add_argument(
-        "--temperature",
-        type=float,
-        help="Override sampling temperature",
-    )
-    p.add_argument(
-        "--max-tokens",
-        type=int,
-        help="Override max tokens for generation",
-    )
-    args = p.parse_args()
-
-    # Lazy imports to allow `--help` without optional deps installed
-    from .config import load_config
-    from .orchestrator import generate_from_ndjson
-
-    try:
-        from .factory import make_gemini_provider
-    except Exception:  # pragma: no cover
-        print(
-            "Failed to import provider factory. Ensure optional deps are installed (e.g. google-genai).",
-            file=sys.stderr,
-        )
-        raise
-
-    cfg = load_config()
-    try:
-        provider = make_gemini_provider()
-    except ModuleNotFoundError:  # pragma: no cover
-        print(
-            "Gemini provider requires 'google-genai'. Install via: pip install google-genai",
-            file=sys.stderr,
-        )
-        sys.exit(2)
-    temperature = args.temperature if args.temperature is not None else cfg.temperature
-    max_tokens = args.max_tokens if args.max_tokens is not None else cfg.max_tokens
-
-    res = generate_from_ndjson(
-        args.ndjson,
-        provider,
-        launch_index=args.launch_index,
-        out_py=args.out,
-        execute=args.execute,
-        retries=args.retries,
-        temperature=temperature,
-        max_tokens=max_tokens,
-    )
-    print(res)
-
-
-if __name__ == "__main__":  # pragma: no cover
-    main()
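For reference, the removed subpackage could also be driven programmatically rather than via `python -m tritonparse.reproducer`. A minimal sketch against the 0.1.1 API (assumes `google-genai` is installed, `GOOGLE_APPLICATION_CREDENTIALS` is configured, and the trace path is hypothetical):

```python
from tritonparse.reproducer.factory import make_gemini_provider
from tritonparse.reproducer.orchestrator import generate_from_ndjson

provider = make_gemini_provider()  # reads GOOGLE_CLOUD_PROJECT etc. from the env
result = generate_from_ndjson(
    "trace.ndjson",       # hypothetical NDJSON trace path
    provider,
    launch_index=0,       # which launch event to reproduce
    out_py="repro.py",
    execute=False,        # True would run (and optionally repair) the script
)
print(result)             # {"path": "repro.py"} when execute=False
```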
tritonparse/reproducer/cli.py DELETED
@@ -1,37 +0,0 @@
-import argparse
-
-from .config import load_config
-from .factory import make_gemini_provider
-from .orchestrator import generate_from_ndjson
-
-
-def add_reproducer_subparser(parser: argparse.ArgumentParser) -> None:
-    sub = parser.add_subparsers(dest="subcommand")
-    repro = sub.add_parser(
-        "repro",
-        help="Generate a runnable Triton repro script from NDJSON",
-    )
-    repro.add_argument("--ndjson", required=True)
-    repro.add_argument("--launch-index", type=int, default=0)
-    repro.add_argument("--out", default="repro.py")
-    repro.add_argument("--execute", action="store_true")
-    repro.add_argument("--retries", type=int, default=0)
-
-
-def maybe_handle_reproducer(args: argparse.Namespace) -> bool:
-    if getattr(args, "subcommand", None) != "repro":
-        return False
-    cfg = load_config()
-    provider = make_gemini_provider()
-    res = generate_from_ndjson(
-        args.ndjson,
-        provider,
-        launch_index=args.launch_index,
-        out_py=args.out,
-        execute=args.execute,
-        retries=args.retries,
-        temperature=cfg.temperature,
-        max_tokens=cfg.max_tokens,
-    )
-    print(res)
-    return True
tritonparse/reproducer/config.py DELETED
@@ -1,15 +0,0 @@
-import os
-from dataclasses import dataclass
-
-
-@dataclass
-class ReproducerConfig:
-    project: str = os.getenv("GOOGLE_CLOUD_PROJECT", "")
-    location: str = os.getenv("GOOGLE_LOCATION", "us-central1")
-    model: str = os.getenv("TP_REPRO_MODEL", "gemini-2.5-pro")
-    temperature: float = float(os.getenv("TP_REPRO_TEMPERATURE", "0.1"))
-    max_tokens: int = int(os.getenv("TP_REPRO_MAX_TOKENS", "10240"))
-
-
-def load_config() -> ReproducerConfig:
-    return ReproducerConfig()
tritonparse/reproducer/factory.py DELETED
@@ -1,16 +0,0 @@
-"""Provider factory for reproducer.
-
-Currently supports Gemini only.
-"""
-
-from .config import load_config
-from .providers.gemini import GeminiProvider
-
-
-def make_gemini_provider() -> GeminiProvider:
-    cfg = load_config()
-    return GeminiProvider(
-        project=cfg.project,
-        location=cfg.location,
-        model=cfg.model,
-    )
tritonparse/reproducer/ingestion/ndjson.py DELETED
@@ -1,165 +0,0 @@
-import json
-from typing import Any, Dict, List
-
-
-def _iter_events(path: str):
-    with open(path, "r", encoding="utf-8") as f:
-        for line in f:
-            line = line.strip()
-            if not line:
-                continue
-            try:
-                yield json.loads(line)
-            except json.JSONDecodeError:
-                # skip malformed lines
-                continue
-
-
-def _index_compilations(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
-    idx = {}
-    for e in events:
-        if e.get("event_type") != "compilation":
-            continue
-        payload = e.get("payload") or {}
-        meta = payload.get("metadata") or {}
-        h = meta.get("hash")
-        if h:
-            idx[h] = e
-    return idx
-
-
-def _get_launches(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-    return [e for e in events if e.get("event_type") == "launch"]
-
-
-def _resolve_kernel_source(
-    launch: Dict[str, Any], comp_idx: Dict[str, Dict[str, Any]]
-) -> str:
-    # In new format, launch has top-level compilation_metadata, not payload.*
-    comp_meta = (
-        launch.get("compilation_metadata")
-        or launch.get("payload", {}).get("compilation_metadata")
-        or {}
-    )
-    h = comp_meta.get("hash")
-    if not h:
-        return ""
-    comp = comp_idx.get(h, {})
-    payload = comp.get("payload") or {}
-    py = payload.get("python_source") or {}
-    return py.get("code", "")
-
-
-def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
-    packed = {}
-    for k, v in args.items():
-        t = v.get("type") if isinstance(v, dict) else None
-        if t == "tensor":
-            packed[k] = {
-                "type": "tensor",
-                "shape": v.get("shape") if isinstance(v, dict) else None,
-                "dtype": v.get("dtype") if isinstance(v, dict) else None,
-                "device": v.get("device") if isinstance(v, dict) else None,
-                "stride": v.get("stride") if isinstance(v, dict) else None,
-                "is_contiguous": (
-                    v.get("is_contiguous") if isinstance(v, dict) else None
-                ),
-                "numel": v.get("numel") if isinstance(v, dict) else None,
-            }
-        else:
-            # scalar / NoneType etc
-            if isinstance(v, dict):
-                packed[k] = {
-                    "type": v.get("type"),
-                    "value": v.get("value", v.get("repr")),
-                }
-            else:
-                packed[k] = {
-                    "type": None,
-                    "value": v,
-                }
-    return packed
-
-
-# Sentinel and helper to normalize extracted argument values
-_SKIP = object()
-
-
-def _decode_arg(raw: Any):
-    if not isinstance(raw, dict):
-        return raw
-    t = raw.get("type")
-    if t == "tensor":
-        return _SKIP
-    if t == "NoneType":
-        return None
-    return raw.get("value", raw.get("repr"))
-
-
-def build_context_bundle(ndjson_path: str, launch_index: int = 0) -> Dict[str, Any]:
-    events = list(_iter_events(ndjson_path))
-    launches = _get_launches(events)
-    if not launches:
-        raise RuntimeError("No launch events found in NDJSON.")
-    if launch_index < 0 or launch_index >= len(launches):
-        raise IndexError(
-            f"launch_index out of range: {launch_index} (total {len(launches)})"
-        )
-    launch = launches[launch_index]
-    comp_idx = _index_compilations(events)
-    kernel_source = _resolve_kernel_source(launch, comp_idx)
-    # find '@triton.jit' and slice the string
-    jit_marker = "@triton.jit"
-    jit_pos = kernel_source.find(jit_marker)
-    if jit_pos != -1:
-        kernel_source = kernel_source[jit_pos:]
-
-    # flatten launch fields (support both formats)
-    grid = launch.get("grid") or (launch.get("payload", {})).get("grid")
-    comp_meta = (
-        launch.get("compilation_metadata")
-        or (launch.get("payload", {})).get("compilation_metadata")
-        or {}
-    )
-    extracted_args = (
-        launch.get("extracted_args")
-        or (launch.get("payload", {})).get("extracted_args")
-        or {}
-    )
-
-    # compile metadata subset we care about
-    compile_block = {
-        "num_warps": comp_meta.get("num_warps"),
-        "num_stages": comp_meta.get("num_stages"),
-        "arch": comp_meta.get("arch"),
-        "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
-        "triton_version": comp_meta.get("triton_version"),
-        "hash": comp_meta.get("hash"),
-    }
-
-    # kwargs: include constexpr + explicit scalars used for launch (skip tensor args)
-    kwargs = {}
-    for k, v in extracted_args.items():
-        val = _decode_arg(v)
-        if val is _SKIP:
-            continue
-        kwargs[k] = val
-
-    # tensor args: only tensors
-    tensor_args = {
-        k: v
-        for k, v in extracted_args.items()
-        if isinstance(v, dict) and v.get("type") == "tensor"
-    }
-
-    bundle = {
-        "kernel_source": kernel_source,
-        "compile": compile_block,
-        "launch": {
-            "grid": grid,
-            "kwargs": kwargs,
-        },
-        "args": _pack_args(extracted_args),
-        "tensor_args": _pack_args(tensor_args),
-    }
-    return bundle
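For orientation, the bundle returned by the removed `build_context_bundle` had roughly the shape below; the keys come from the deleted source, while every concrete value here is hypothetical:

```python
bundle = {
    "kernel_source": "@triton.jit\ndef add_kernel(...): ...",  # sliced at @triton.jit
    "compile": {
        "num_warps": 4,
        "num_stages": 3,
        "arch": "sm90",
        "backend": "cuda",
        "triton_version": "3.4.0",
        "hash": "0f3a",
    },
    # kwargs carry scalars/constexprs only; _decode_arg skips tensor args
    "launch": {"grid": [256, 1, 1], "kwargs": {"n_elements": 1024, "BLOCK_SIZE": 128}},
    "args": {},         # every arg, normalized by _pack_args
    "tensor_args": {},  # the tensor-typed subset
}
```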
tritonparse/reproducer/orchestrator.py DELETED
@@ -1,65 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict
-
-from .ingestion.ndjson import build_context_bundle
-from .param_generator import generate_allocation_snippet, generate_kwargs_dict
-from .prompts.loader import render_prompt
-from .providers.base import LLMProvider
-from .runtime.executor import run_python
-
-
-def _excerpt(s: str, n: int = 160):
-    lines = s.splitlines()
-    return "\n".join(lines[:n])
-
-
-def generate_from_ndjson(
-    ndjson_path: str,
-    provider: LLMProvider,
-    *,
-    launch_index=0,
-    out_py="repro.py",
-    execute=False,
-    retries: int = 0,
-    **gen_kwargs,
-) -> Dict[str, Any]:
-    bundle = build_context_bundle(ndjson_path, launch_index=launch_index)
-    # Augment bundle with pre-generated parameter allocation code to reduce LLM burden
-    allocation_snippet = generate_allocation_snippet(bundle)
-    kwargs_dict = generate_kwargs_dict(bundle)
-    context = {
-        **bundle,
-        "allocation_snippet": allocation_snippet,
-        "kwargs_dict": kwargs_dict,
-    }
-    system_prompt = render_prompt("system.txt", context)
-    user_prompt = render_prompt("generate_one_shot.txt", context)
-
-    code = provider.generate_code(system_prompt, user_prompt, **gen_kwargs)
-    Path(out_py).write_text(code, encoding="utf-8")
-
-    if not execute:
-        return {"path": out_py}
-
-    # Execute and optionally repair
-    rc, out, err = run_python(out_py)
-    attempt = 0
-    while rc != 0 and attempt < retries:
-        attempt += 1
-        # Build repair prompt
-        repair_ctx = {
-            "prev_code_excerpt": _excerpt(code, 200),
-            "error_text": err[-4000:] if err else "(no stderr)",
-        }
-        repair_prompt = render_prompt("repair_loop.txt", repair_ctx)
-        code = provider.generate_code(system_prompt, repair_prompt, **gen_kwargs)
-        Path(out_py).write_text(code, encoding="utf-8")
-        rc, out, err = run_python(out_py)
-
-    return {
-        "path": out_py,
-        "returncode": rc,
-        "stdout": out,
-        "stderr": err,
-        "retries_used": attempt,
-    }
tritonparse/reproducer/param_generator.py DELETED
@@ -1,142 +0,0 @@
-"""Parameter generator: produce deterministic allocation code from a bundle.
-
-This module reduces LLM burden by emitting Python code that:
-- selects a device
-- seeds RNG
-- allocates tensors with the exact shape/dtype/device/stride
-- prepares scalar/constexpr kwargs
-
-The generated code is intended to be inserted into the final repro script.
-"""
-
-import json
-from typing import Any, Dict, List, Optional
-
-
-def _torch_dtype_expr(dtype: str) -> str:
-    mapping = {
-        "float16": "torch.float16",
-        "bfloat16": "torch.bfloat16",
-        "float32": "torch.float32",
-        "float": "torch.float32",
-        "float64": "torch.float64",
-        "half": "torch.float16",
-        "bf16": "torch.bfloat16",
-        "fp16": "torch.float16",
-        "fp32": "torch.float32",
-        "fp64": "torch.float64",
-        "int8": "torch.int8",
-        "int16": "torch.int16",
-        "int32": "torch.int32",
-        "int64": "torch.int64",
-        "long": "torch.int64",
-        "bool": "torch.bool",
-    }
-    return mapping.get(str(dtype).lower(), "torch.float32")
-
-
-def _compute_storage_numel(shape: List[int], stride: Optional[List[int]]) -> int:
-    if not shape:
-        return 1
-    if not stride:
-        # contiguous default
-        numel = 1
-        for s in shape:
-            numel *= int(s)
-        return numel
-    # minimal storage size (in elements) to support the given logical shape/stride
-    max_index = 0
-    for dim, (sz, st) in enumerate(zip(shape, stride)):
-        if sz <= 0:
-            continue
-        max_index = max(max_index, (int(sz) - 1) * int(st))
-    return int(max_index) + 1
-
-
-def _emit_tensor_alloc(name: str, spec: Dict[str, Any]) -> str:
-    shape = spec.get("shape") or []
-    dtype = _torch_dtype_expr(spec.get("dtype"))
-    device = spec.get("device") or "cuda:0"
-    stride = spec.get("stride")
-
-    # ensure ints
-    shape = [int(s) for s in shape]
-    if stride is not None:
-        stride_list = [int(x) for x in stride]
-    else:
-        stride_list = None
-
-    lines: List[str] = []
-    # allocate backing storage
-    storage_numel = _compute_storage_numel(shape, stride_list)
-    lines.append(
-        f"# {name}: shape={shape}, dtype={dtype}, device={device}, stride={stride_list}"
-    )
-    lines.append(
-        f"_storage_{name} = torch.empty(({storage_numel},), dtype={dtype}, device=device)"
-    )
-    if stride_list:
-        # Create an as_strided view over the 1D storage
-        sizes_expr = str(tuple(shape))
-        strides_expr = str(tuple(stride_list))
-        lines.append(
-            f"{name} = _storage_{name}.as_strided(size={sizes_expr}, stride={strides_expr})"
-        )
-    else:
-        # contiguous allocation
-        size_expr = str(tuple(shape))
-        lines.append(f"{name} = torch.empty({size_expr}, dtype={dtype}, device=device)")
-    return "\n".join(lines)
-
-
-def _emit_scalar(name: str, spec: Dict[str, Any]) -> str:
-    value = spec.get("value")
-    # Preserve JSON-serializable value as-is
-    return f"{name} = {json.dumps(value)}"
-
-
-def generate_allocation_snippet(bundle: Dict[str, Any]) -> str:
-    """Generate a self-contained code snippet that:
-    - imports torch
-    - sets device
-    - seeds RNG
-    - allocates tensors and defines scalars for all args
-    Returns Python source as a string.
-    """
-    tensor_args: Dict[str, Any] = bundle.get("tensor_args", {}) or {}
-    args_all: Dict[str, Any] = bundle.get("args", {}) or {}
-
-    # Pick device from any tensor arg, fallback to cuda:0
-    device = "cuda:0"
-    for spec in tensor_args.values():
-        dev = spec.get("device")
-        if dev:
-            device = str(dev)
-            break
-
-    lines: List[str] = []
-    lines.append("import torch")
-    lines.append(f"device = '{device}'")
-    lines.append("torch.manual_seed(0)")
-    lines.append("if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)")
-    lines.append("")
-
-    # Emit tensors first for names with type==tensor in args_all
-    for name, spec in args_all.items():
-        if isinstance(spec, dict) and spec.get("type") == "tensor":
-            lines.append(_emit_tensor_alloc(name, spec))
-            lines.append("")
-
-    # Emit non-tensor scalars next
-    for name, spec in args_all.items():
-        if not isinstance(spec, dict) or spec.get("type") == "tensor":
-            continue
-        lines.append(_emit_scalar(name, spec))
-    return "\n".join(lines)
-
-
-def generate_kwargs_dict(bundle: Dict[str, Any]) -> Dict[str, Any]:
-    """Return a kwargs dict derived from bundle['launch']['kwargs'] suitable for kernel call."""
-    launch = bundle.get("launch", {}) or {}
-    kwargs = launch.get("kwargs", {}) or {}
-    return kwargs
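To make the generator's contract concrete: for a single contiguous tensor argument, `generate_allocation_snippet` emitted code along these lines (a sketch; the spec values are hypothetical):

```python
spec = {"type": "tensor", "shape": [1024], "dtype": "float16",
        "device": "cuda:0", "stride": None}
bundle = {"args": {"x": spec}, "tensor_args": {"x": spec}, "launch": {"kwargs": {}}}

# generate_allocation_snippet(bundle) then returned source equivalent to:
expected = """\
import torch
device = 'cuda:0'
torch.manual_seed(0)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)

# x: shape=[1024], dtype=torch.float16, device=cuda:0, stride=None
_storage_x = torch.empty((1024,), dtype=torch.float16, device=device)
x = torch.empty((1024,), dtype=torch.float16, device=device)
"""
```

Note that a `_storage_x` backing buffer is emitted even in the contiguous case, where the subsequent `torch.empty` rebinds `x` to a fresh allocation; the `as_strided` view over that storage is only generated when an explicit stride is present.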
tritonparse/reproducer/prompts/__init__.py DELETED
@@ -1 +0,0 @@
-__all__ = []
tritonparse/reproducer/prompts/loader.py DELETED
@@ -1,18 +0,0 @@
-import json
-from pathlib import Path
-from typing import Any, Dict
-
-PROMPTS_DIR = Path(__file__).parent
-
-
-def render_prompt(name: str, context: Dict[str, Any]) -> str:
-    text = (PROMPTS_DIR / name).read_text(encoding="utf-8")
-    # very simple {{key}} replacement for top-level keys; JSON for dicts
-    for k, v in context.items():
-        token = "{{ " + k + " }}"
-        if token in text:
-            if isinstance(v, (dict, list)):
-                text = text.replace(token, json.dumps(v, ensure_ascii=False, indent=2))
-            else:
-                text = text.replace(token, str(v))
-    return text
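The loader's templating contract was deliberately simple. The following standalone snippet re-implements the same `{{ key }}` substitution to show it (the template text and context here are made up):

```python
import json

text = "Kernel:\n{{ kernel_source }}\n\nLaunch:\n{{ launch }}"
context = {"kernel_source": "@triton.jit ...", "launch": {"grid": [8, 1, 1]}}
for k, v in context.items():
    token = "{{ " + k + " }}"
    if token in text:
        # dicts/lists are pretty-printed as JSON, everything else goes through str()
        repl = json.dumps(v, ensure_ascii=False, indent=2) if isinstance(v, (dict, list)) else str(v)
        text = text.replace(token, repl)
print(text)
```

Tokens required the exact `{{ key }}` spacing, and only top-level context keys were substituted.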
tritonparse/reproducer/providers/__init__.py DELETED
@@ -1 +0,0 @@
-__all__ = []
tritonparse/reproducer/providers/base.py DELETED
@@ -1,14 +0,0 @@
-from typing import Any, Dict, List, Optional, Protocol
-
-
-class LLMProvider(Protocol):
-    def generate_code(
-        self,
-        system_prompt: str,
-        user_prompt: str,
-        *,
-        temperature: float = 0.2,
-        max_tokens: int = 8192,
-        stop: Optional[List[str]] = None,
-        extra: Optional[Dict[str, Any]] = None,
-    ) -> str: ...
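Because `LLMProvider` was a `typing.Protocol`, any object with a matching `generate_code` method satisfied it structurally, with no inheritance required. A hypothetical test double:

```python
class EchoProvider:
    # Structurally satisfies LLMProvider; useful as a stub in tests.
    def generate_code(self, system_prompt, user_prompt, *,
                      temperature=0.2, max_tokens=8192, stop=None, extra=None):
        # Return a fixed, trivially runnable script instead of calling a model.
        return "print('repro placeholder')"
```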
tritonparse/reproducer/providers/gemini.py DELETED
@@ -1,47 +0,0 @@
-import os
-import re
-from typing import Any, Dict, List, Optional
-
-from google.genai import Client
-
-
-def _extract_python_block(s: str) -> str:
-    m = re.search(r"""```python\s+(.*?)```""", s, flags=re.S)
-    return m.group(1).strip() if m else ""
-
-
-class GeminiProvider:
-    def __init__(
-        self, project: str, location: str = "us-central1", model: str = "gemini-2.5-pro"
-    ):
-        # Expect GOOGLE_APPLICATION_CREDENTIALS to be set
-        if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
-            raise EnvironmentError("GOOGLE_APPLICATION_CREDENTIALS not set.")
-        self.client = Client(vertexai=True, project=project, location=location)
-        self.model = model
-
-    def generate_code(
-        self,
-        system_prompt: str,
-        user_prompt: str,
-        *,
-        temperature: float = 0.2,
-        max_tokens: int = 8192,
-        stop: Optional[List[str]] = None,
-        extra: Optional[Dict[str, Any]] = None,
-    ) -> str:
-        # Gemini doesn't have a 'system' role in this SDK, prepend system to user
-        full_prompt = f"{system_prompt.strip()}\n\n---\n\n{user_prompt.strip()}"
-        resp = self.client.models.generate_content(
-            model=self.model,
-            contents=full_prompt,
-            config={
-                "temperature": temperature,
-                "max_output_tokens": max_tokens,
-            },
-        )
-        text = getattr(resp, "text", "") or ""
-        code = _extract_python_block(text) or text
-        if not code.strip():
-            raise RuntimeError(f"Empty response from Gemini. Raw: {text[:2000]}")
-        return code
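The fenced-block extraction above is easy to check in isolation; this snippet uses the same regex as the deleted `_extract_python_block` (the reply text is made up):

```python
import re

reply = "Sure, here is the script:\n```python\nprint('hi')\n```\nGood luck!"
m = re.search(r"```python\s+(.*?)```", reply, flags=re.S)
print(m.group(1).strip() if m else "")  # -> print('hi')
```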
tritonparse/reproducer/runtime/__init__.py DELETED
@@ -1 +0,0 @@
-__all__ = []
tritonparse/reproducer/runtime/executor.py DELETED
@@ -1,13 +0,0 @@
-import subprocess
-import sys
-
-
-def run_python(path: str, timeout: int = 60):
-    p = subprocess.Popen(
-        [sys.executable, path],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        text=True,
-    )
-    out, err = p.communicate(timeout=timeout)
-    return p.returncode, out, err
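A quick sketch of how the orchestrator consumed this helper (valid against the 0.1.1 package only, and the script path is hypothetical):

```python
# The module is gone in 0.2.0; this import only works against 0.1.1.
from tritonparse.reproducer.runtime.executor import run_python

rc, out, err = run_python("repro.py", timeout=120)  # hypothetical script path
if rc != 0:
    print("repro failed:", err)
```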
tritonparse-0.1.1.dist-info/METADATA DELETED
@@ -1,10 +0,0 @@
-Metadata-Version: 2.4
-Name: tritonparse
-Version: 0.1.1
-Project-URL: Homepage, https://github.com/meta-pytorch/tritonparse
-Requires-Python: >=3.10
-License-File: LICENSE
-Requires-Dist: triton>3.3.1
-Provides-Extra: test
-Requires-Dist: coverage>=7.0.0; extra == "test"
-Dynamic: license-file
tritonparse-0.1.1.dist-info/RECORD DELETED
@@ -1,40 +0,0 @@
-tritonparse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tritonparse/common.py,sha256=VWKFZJG7msIExMH0QNYCO8SHKqUBFtdRuLaCy_K8DFI,13725
-tritonparse/event_diff.py,sha256=yOD6uNxLJroatfx2nEGr-erw24ObOrHU9P6V5pzr8do,4907
-tritonparse/extract_source_mappings.py,sha256=Z6UxFj2cCE5NCWLQTYPKqUpLfbYhqP8xgCl5mvud9KI,1451
-tritonparse/ir_parser.py,sha256=1j1tP9jpUN7wH3e01bKUkUPgTMlNXUdp8LKRCC-WTro,9324
-tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
-tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
-tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
-tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
-tritonparse/structured_logging.py,sha256=qWGkAr2oZT8Adxj-EfQqQdX2l-jq9xKi7WhBecYq2bg,39600
-tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
-tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
-tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
-tritonparse/reproducer/__init__.py,sha256=VcCpYVUUmclWotkQmPLlDu5iFOUE4N-4FzcbzXwIow0,773
-tritonparse/reproducer/__main__.py,sha256=ydLMbWx7SFlvAb1erObvKcJ-uxhHNShGRHRrZO6-5ww,2266
-tritonparse/reproducer/cli.py,sha256=nnMgdT4tQzBWjPowOy2a_5QRsnsTMEAA3uegYpLEyRE,1165
-tritonparse/reproducer/config.py,sha256=-hmE5ZqEtYo2WKjXbwMi6k6XzzfQZAbL50UURPvcF3A,478
-tritonparse/reproducer/factory.py,sha256=sFcIjIayfHAqPqMVT8Rnsz9tpMmQXBzoOlKprS1P_1g,341
-tritonparse/reproducer/orchestrator.py,sha256=iD7zZZHE4FU3nNOwNV9SUY2WUcpv_Amg0SvnRxrseEQ,2045
-tritonparse/reproducer/param_generator.py,sha256=m-C_Z1TLd1ZX49EpsWELVfB6tkwOfi-ZHma7wXwz2g4,4654
-tritonparse/reproducer/ingestion/__init__.py,sha256=2AQHxWlUl5JXM4a8F033wzxVnjCVPBEf-4H99kep-OA,99
-tritonparse/reproducer/ingestion/ndjson.py,sha256=_E_dXXjxu438OYomQ1zFFk3jV9Wr1jNoXHiP2gJG7_4,5172
-tritonparse/reproducer/prompts/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
-tritonparse/reproducer/prompts/loader.py,sha256=n6Of98eEXNz9mI7ZH073X5FihNZD7tI-ehfjN_4yEl0,610
-tritonparse/reproducer/providers/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
-tritonparse/reproducer/providers/base.py,sha256=DgP_4AdrEf48kstOfBJFvK3pndcHH0vRUGjp6k1bdsY,362
-tritonparse/reproducer/providers/gemini.py,sha256=VlOCdTGRTQdr3c2HMclKFIk-133puGSAjhK_6m6Zj9g,1609
-tritonparse/reproducer/runtime/__init__.py,sha256=da1PTClDMl-IBkrSvq6JC1lnS-K_BASzCvxVhNxN5Ls,13
-tritonparse/reproducer/runtime/executor.py,sha256=AqBFnoEqURoMGDdLC2G3WpHIP3Y4wWGJHEZrjS-NQFM,304
-tritonparse/reproducer/utils/io.py,sha256=95NF9QCGawl-5p5c5yCQHynVBNKS_B_7nIrqnRvAt-E,200
-tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
-tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
-tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
-tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
-tritonparse-0.1.1.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
-tritonparse-0.1.1.dist-info/METADATA,sha256=qETEInGJRT7fzf-Rl8cAf6QnEa5eJgiYo4rbwBA63yc,287
-tritonparse-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tritonparse-0.1.1.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
-tritonparse-0.1.1.dist-info/RECORD,,
{tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/WHEEL: file without changes
{tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/licenses/LICENSE: file without changes
{tritonparse-0.1.1.dist-info → tritonparse-0.2.0.dist-info}/top_level.txt: file without changes