tritonparse 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/common.py +409 -0
  3. tritonparse/event_diff.py +120 -0
  4. tritonparse/extract_source_mappings.py +49 -0
  5. tritonparse/ir_parser.py +220 -0
  6. tritonparse/mapper.py +100 -0
  7. tritonparse/reproducer/__init__.py +21 -0
  8. tritonparse/reproducer/__main__.py +81 -0
  9. tritonparse/reproducer/cli.py +37 -0
  10. tritonparse/reproducer/config.py +15 -0
  11. tritonparse/reproducer/factory.py +16 -0
  12. tritonparse/reproducer/ingestion/__init__.py +6 -0
  13. tritonparse/reproducer/ingestion/ndjson.py +165 -0
  14. tritonparse/reproducer/orchestrator.py +65 -0
  15. tritonparse/reproducer/param_generator.py +142 -0
  16. tritonparse/reproducer/prompts/__init__.py +1 -0
  17. tritonparse/reproducer/prompts/loader.py +18 -0
  18. tritonparse/reproducer/providers/__init__.py +1 -0
  19. tritonparse/reproducer/providers/base.py +14 -0
  20. tritonparse/reproducer/providers/gemini.py +47 -0
  21. tritonparse/reproducer/runtime/__init__.py +1 -0
  22. tritonparse/reproducer/runtime/executor.py +13 -0
  23. tritonparse/reproducer/utils/io.py +6 -0
  24. tritonparse/shared_vars.py +9 -0
  25. tritonparse/source_type.py +56 -0
  26. tritonparse/sourcemap_utils.py +72 -0
  27. tritonparse/structured_logging.py +1046 -0
  28. tritonparse/tools/__init__.py +0 -0
  29. tritonparse/tools/decompress_bin_ndjson.py +118 -0
  30. tritonparse/tools/format_fix.py +149 -0
  31. tritonparse/tools/load_tensor.py +58 -0
  32. tritonparse/tools/prettify_ndjson.py +315 -0
  33. tritonparse/tp_logger.py +9 -0
  34. tritonparse/trace_processor.py +331 -0
  35. tritonparse/utils.py +156 -0
  36. tritonparse-0.1.1.dist-info/METADATA +10 -0
  37. tritonparse-0.1.1.dist-info/RECORD +40 -0
  38. tritonparse-0.1.1.dist-info/WHEEL +5 -0
  39. tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
  40. tritonparse-0.1.1.dist-info/top_level.txt +1 -0
tritonparse/reproducer/orchestrator.py
@@ -0,0 +1,65 @@
+ from pathlib import Path
+ from typing import Any, Dict
+
+ from .ingestion.ndjson import build_context_bundle
+ from .param_generator import generate_allocation_snippet, generate_kwargs_dict
+ from .prompts.loader import render_prompt
+ from .providers.base import LLMProvider
+ from .runtime.executor import run_python
+
+
+ def _excerpt(s: str, n: int = 160):
+     lines = s.splitlines()
+     return "\n".join(lines[:n])
+
+
+ def generate_from_ndjson(
+     ndjson_path: str,
+     provider: LLMProvider,
+     *,
+     launch_index=0,
+     out_py="repro.py",
+     execute=False,
+     retries: int = 0,
+     **gen_kwargs,
+ ) -> Dict[str, Any]:
+     bundle = build_context_bundle(ndjson_path, launch_index=launch_index)
+     # Augment bundle with pre-generated parameter allocation code to reduce LLM burden
+     allocation_snippet = generate_allocation_snippet(bundle)
+     kwargs_dict = generate_kwargs_dict(bundle)
+     context = {
+         **bundle,
+         "allocation_snippet": allocation_snippet,
+         "kwargs_dict": kwargs_dict,
+     }
+     system_prompt = render_prompt("system.txt", context)
+     user_prompt = render_prompt("generate_one_shot.txt", context)
+
+     code = provider.generate_code(system_prompt, user_prompt, **gen_kwargs)
+     Path(out_py).write_text(code, encoding="utf-8")
+
+     if not execute:
+         return {"path": out_py}
+
+     # Execute and optionally repair
+     rc, out, err = run_python(out_py)
+     attempt = 0
+     while rc != 0 and attempt < retries:
+         attempt += 1
+         # Build repair prompt
+         repair_ctx = {
+             "prev_code_excerpt": _excerpt(code, 200),
+             "error_text": err[-4000:] if err else "(no stderr)",
+         }
+         repair_prompt = render_prompt("repair_loop.txt", repair_ctx)
+         code = provider.generate_code(system_prompt, repair_prompt, **gen_kwargs)
+         Path(out_py).write_text(code, encoding="utf-8")
+         rc, out, err = run_python(out_py)
+
+     return {
+         "path": out_py,
+         "returncode": rc,
+         "stdout": out,
+         "stderr": err,
+         "retries_used": attempt,
+     }
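For orientation, a minimal usage sketch of this entry point (the project id and trace path below are placeholders, not values shipped with the package):

    from tritonparse.reproducer.orchestrator import generate_from_ndjson
    from tritonparse.reproducer.providers.gemini import GeminiProvider

    provider = GeminiProvider(project="my-gcp-project")  # placeholder project id
    result = generate_from_ndjson(
        "triton_trace.ndjson",  # placeholder path to a tritonparse NDJSON trace
        provider,
        launch_index=0,
        out_py="repro.py",
        execute=True,
        retries=2,
    )
    print(result["returncode"], result["retries_used"])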
tritonparse/reproducer/param_generator.py
@@ -0,0 +1,142 @@
+ """Parameter generator: produce deterministic allocation code from a bundle.
+
+ This module reduces LLM burden by emitting Python code that:
+ - selects a device
+ - seeds RNG
+ - allocates tensors with the exact shape/dtype/device/stride
+ - prepares scalar/constexpr kwargs
+
+ The generated code is intended to be inserted into the final repro script.
+ """
+
+ from typing import Any, Dict, List, Optional
+
+
+ def _torch_dtype_expr(dtype: str) -> str:
+     mapping = {
+         "float16": "torch.float16",
+         "bfloat16": "torch.bfloat16",
+         "float32": "torch.float32",
+         "float": "torch.float32",
+         "float64": "torch.float64",
+         "half": "torch.float16",
+         "bf16": "torch.bfloat16",
+         "fp16": "torch.float16",
+         "fp32": "torch.float32",
+         "fp64": "torch.float64",
+         "int8": "torch.int8",
+         "int16": "torch.int16",
+         "int32": "torch.int32",
+         "int64": "torch.int64",
+         "long": "torch.int64",
+         "bool": "torch.bool",
+     }
+     return mapping.get(str(dtype).lower(), "torch.float32")
+
+
+ def _compute_storage_numel(shape: List[int], stride: Optional[List[int]]) -> int:
+     if not shape:
+         return 1
+     if not stride:
+         # contiguous default
+         numel = 1
+         for s in shape:
+             numel *= int(s)
+         return numel
+     # minimal storage size (in elements) to support the given logical shape/stride:
+     # the largest reachable flat index is the sum of (size - 1) * stride over all dims
+     max_index = 0
+     for sz, st in zip(shape, stride):
+         if sz <= 0:
+             continue
+         max_index += (int(sz) - 1) * int(st)
+     return int(max_index) + 1
+
+
+ def _emit_tensor_alloc(name: str, spec: Dict[str, Any]) -> str:
+     shape = spec.get("shape") or []
+     dtype = _torch_dtype_expr(spec.get("dtype"))
+     device = spec.get("device") or "cuda:0"
+     stride = spec.get("stride")
+
+     # ensure ints
+     shape = [int(s) for s in shape]
+     if stride is not None:
+         stride_list = [int(x) for x in stride]
+     else:
+         stride_list = None
+
+     lines: List[str] = []
+     # allocate backing storage
+     storage_numel = _compute_storage_numel(shape, stride_list)
+     lines.append(
+         f"# {name}: shape={shape}, dtype={dtype}, device={device}, stride={stride_list}"
+     )
+     lines.append(
+         f"_storage_{name} = torch.empty(({storage_numel},), dtype={dtype}, device=device)"
+     )
+     if stride_list:
+         # Create an as_strided view over the 1D storage
+         sizes_expr = str(tuple(shape))
+         strides_expr = str(tuple(stride_list))
+         lines.append(
+             f"{name} = _storage_{name}.as_strided(size={sizes_expr}, stride={strides_expr})"
+         )
+     else:
+         # contiguous allocation
+         size_expr = str(tuple(shape))
+         lines.append(f"{name} = torch.empty({size_expr}, dtype={dtype}, device=device)")
+     return "\n".join(lines)
+
+
+ def _emit_scalar(name: str, spec: Dict[str, Any]) -> str:
+     value = spec.get("value")
+     # Emit a Python literal via repr(); json.dumps would render booleans and
+     # None as 'true'/'null', which is not valid Python source
+     return f"{name} = {value!r}"
+
+
+ def generate_allocation_snippet(bundle: Dict[str, Any]) -> str:
+     """Generate a self-contained code snippet that:
+     - imports torch
+     - sets device
+     - seeds RNG
+     - allocates tensors and defines scalars for all args
+     Returns Python source as a string.
+     """
+     tensor_args: Dict[str, Any] = bundle.get("tensor_args", {}) or {}
+     args_all: Dict[str, Any] = bundle.get("args", {}) or {}
+
+     # Pick device from any tensor arg, fall back to cuda:0
+     device = "cuda:0"
+     for spec in tensor_args.values():
+         dev = spec.get("device")
+         if dev:
+             device = str(dev)
+             break
+
+     lines: List[str] = []
+     lines.append("import torch")
+     lines.append(f"device = '{device}'")
+     lines.append("torch.manual_seed(0)")
+     lines.append("if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)")
+     lines.append("")
+
+     # Emit tensors first for names with type == "tensor" in args_all
+     for name, spec in args_all.items():
+         if isinstance(spec, dict) and spec.get("type") == "tensor":
+             lines.append(_emit_tensor_alloc(name, spec))
+     lines.append("")
+
+     # Emit non-tensor scalars next
+     for name, spec in args_all.items():
+         if not isinstance(spec, dict) or spec.get("type") == "tensor":
+             continue
+         lines.append(_emit_scalar(name, spec))
+     return "\n".join(lines)
+
+
+ def generate_kwargs_dict(bundle: Dict[str, Any]) -> Dict[str, Any]:
+     """Return a kwargs dict derived from bundle['launch']['kwargs'] suitable for the kernel call."""
+     launch = bundle.get("launch", {}) or {}
+     kwargs = launch.get("kwargs", {}) or {}
+     return kwargs
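As a sketch of the emitted source, a hypothetical bundle with one strided fp16 tensor and one int scalar (field values here are illustrative, using the corrected sum-of-extents storage size):

    bundle = {
        "args": {
            "x": {"type": "tensor", "shape": [4, 8], "dtype": "fp16",
                  "device": "cuda:0", "stride": [8, 1]},
            "n": {"type": "int", "value": 32},
        },
        "tensor_args": {"x": {"device": "cuda:0"}},
    }
    print(generate_allocation_snippet(bundle))
    # import torch
    # device = 'cuda:0'
    # torch.manual_seed(0)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)
    #
    # # x: shape=[4, 8], dtype=torch.float16, device=cuda:0, stride=[8, 1]
    # _storage_x = torch.empty((32,), dtype=torch.float16, device=device)
    # x = _storage_x.as_strided(size=(4, 8), stride=(8, 1))
    #
    # n = 32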
tritonparse/reproducer/prompts/__init__.py
@@ -0,0 +1 @@
+ __all__ = []
tritonparse/reproducer/prompts/loader.py
@@ -0,0 +1,18 @@
+ import json
+ from pathlib import Path
+ from typing import Any, Dict
+
+ PROMPTS_DIR = Path(__file__).parent
+
+
+ def render_prompt(name: str, context: Dict[str, Any]) -> str:
+     text = (PROMPTS_DIR / name).read_text(encoding="utf-8")
+     # very simple "{{ key }}" replacement (note the spaces) for top-level keys;
+     # dict/list values are rendered as JSON
+     for k, v in context.items():
+         token = "{{ " + k + " }}"
+         if token in text:
+             if isinstance(v, (dict, list)):
+                 text = text.replace(token, json.dumps(v, ensure_ascii=False, indent=2))
+             else:
+                 text = text.replace(token, str(v))
+     return text
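A quick substitution sketch (the _demo.txt template is hypothetical; the prompt files the orchestrator actually loads are system.txt, generate_one_shot.txt, and repair_loop.txt):

    from tritonparse.reproducer.prompts.loader import PROMPTS_DIR, render_prompt

    # Write a throwaway template next to the shipped prompts
    (PROMPTS_DIR / "_demo.txt").write_text(
        "Kernel: {{ name }}\nLaunch: {{ launch }}", encoding="utf-8"
    )
    print(render_prompt("_demo.txt", {"name": "add_kernel", "launch": {"grid": [8]}}))
    # Kernel: add_kernel
    # Launch: {
    #   "grid": [
    #     8
    #   ]
    # }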
tritonparse/reproducer/providers/__init__.py
@@ -0,0 +1 @@
+ __all__ = []
tritonparse/reproducer/providers/base.py
@@ -0,0 +1,14 @@
+ from typing import Any, Dict, List, Optional, Protocol
+
+
+ class LLMProvider(Protocol):
+     def generate_code(
+         self,
+         system_prompt: str,
+         user_prompt: str,
+         *,
+         temperature: float = 0.2,
+         max_tokens: int = 8192,
+         stop: Optional[List[str]] = None,
+         extra: Optional[Dict[str, Any]] = None,
+     ) -> str: ...
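Because LLMProvider is a typing.Protocol, any class with a matching generate_code signature conforms structurally; a minimal offline stub (EchoProvider is hypothetical, not part of the package) might look like:

    from typing import Any, Dict, List, Optional

    from tritonparse.reproducer.providers.base import LLMProvider

    class EchoProvider:
        """Offline stand-in that satisfies LLMProvider without calling an LLM."""

        def generate_code(
            self,
            system_prompt: str,
            user_prompt: str,
            *,
            temperature: float = 0.2,
            max_tokens: int = 8192,
            stop: Optional[List[str]] = None,
            extra: Optional[Dict[str, Any]] = None,
        ) -> str:
            # Return a canned script instead of querying a model
            return "print('stub repro')"

    provider: LLMProvider = EchoProvider()  # no inheritance required for Protocols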
tritonparse/reproducer/providers/gemini.py
@@ -0,0 +1,47 @@
+ import os
+ import re
+ from typing import Any, Dict, List, Optional
+
+ from google.genai import Client
+
+
+ def _extract_python_block(s: str) -> str:
+     m = re.search(r"""```python\s+(.*?)```""", s, flags=re.S)
+     return m.group(1).strip() if m else ""
+
+
+ class GeminiProvider:
+     def __init__(
+         self, project: str, location: str = "us-central1", model: str = "gemini-2.5-pro"
+     ):
+         # Expect GOOGLE_APPLICATION_CREDENTIALS to be set
+         if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
+             raise EnvironmentError("GOOGLE_APPLICATION_CREDENTIALS not set.")
+         self.client = Client(vertexai=True, project=project, location=location)
+         self.model = model
+
+     def generate_code(
+         self,
+         system_prompt: str,
+         user_prompt: str,
+         *,
+         temperature: float = 0.2,
+         max_tokens: int = 8192,
+         stop: Optional[List[str]] = None,
+         extra: Optional[Dict[str, Any]] = None,
+     ) -> str:
+         # No separate 'system' role is used here, so prepend the system
+         # prompt to the user prompt
+         full_prompt = f"{system_prompt.strip()}\n\n---\n\n{user_prompt.strip()}"
+         resp = self.client.models.generate_content(
+             model=self.model,
+             contents=full_prompt,
+             config={
+                 "temperature": temperature,
+                 "max_output_tokens": max_tokens,
+             },
+         )
+         text = getattr(resp, "text", "") or ""
+         code = _extract_python_block(text) or text
+         if not code.strip():
+             raise RuntimeError(f"Empty response from Gemini. Raw: {text[:2000]}")
+         return code
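The fence-extraction helper can be sanity-checked in isolation; for example:

    text = "Sure:\n```python\nimport torch\nprint(torch.rand(2, 2))\n```\nDone."
    assert _extract_python_block(text) == "import torch\nprint(torch.rand(2, 2))"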
tritonparse/reproducer/runtime/__init__.py
@@ -0,0 +1 @@
+ __all__ = []
tritonparse/reproducer/runtime/executor.py
@@ -0,0 +1,13 @@
+ import subprocess
+ import sys
+
+
+ def run_python(path: str, timeout: int = 60):
+     p = subprocess.Popen(
+         [sys.executable, path],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+         text=True,
+     )
+     try:
+         out, err = p.communicate(timeout=timeout)
+     except subprocess.TimeoutExpired:
+         # Kill the child so a hung repro script doesn't leak a process
+         p.kill()
+         out, err = p.communicate()
+     return p.returncode, out, err
tritonparse/reproducer/utils/io.py
@@ -0,0 +1,6 @@
+ from pathlib import Path
+
+
+ def write_text(path: str, content: str, *, encoding="utf-8"):
+     Path(path).parent.mkdir(parents=True, exist_ok=True)
+     Path(path).write_text(content, encoding=encoding)
tritonparse/shared_vars.py
@@ -0,0 +1,9 @@
+ # We'd like to separate the structured logging module and the tritonparse module as much as possible, so the shared variables live here.
+ import os
+
+ # Compilation information is stored to /logs/DEFAULT_TRACE_FILE_PREFIX by default,
+ # unless other flags disable it or set another store. Include USER to avoid permission issues on shared servers.
+ DEFAULT_TRACE_FILE_PREFIX = (
+     f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
+ )
+ DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER = "dedicated_log_triton_trace_"
tritonparse/source_type.py
@@ -0,0 +1,56 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from enum import Enum
+ from pathlib import Path
+ from typing import Tuple
+
+
+ class SourceType(str, Enum):
+     """Enumeration of supported source types for OSS only."""
+
+     LOCAL = "local"
+     LOCAL_FILE = "local_file"
+
+     @classmethod
+     def _missing_(cls, value: object) -> "SourceType":
+         """
+         Handle unknown source types by raising a ValueError.
+
+         Args:
+             value: The unknown value that was attempted to be used as a SourceType
+
+         Returns:
+             Never returns; always raises ValueError
+         """
+         valid_types = [e.value for e in cls]
+         raise ValueError(
+             f"Invalid source type '{value}'. Valid types are: {', '.join(valid_types)}"
+         )
+
+
+ class Source:
+     """Represents a source of logs to parse."""
+
+     def __init__(self, source_str: str, verbose: bool = False):
+         """
+         Initialize a Source object by parsing the source string.
+
+         Args:
+             source_str: String representing the source
+             verbose: Whether to print verbose information
+         """
+         self.source_str = source_str
+         self.verbose = verbose
+         self.type, self.value = self._parse_source()
+
+     def _parse_source(self) -> Tuple[SourceType, str]:
+         # Check if it's a local path
+         path = Path(self.source_str)
+         if path.is_dir():
+             return SourceType.LOCAL, str(path.absolute())
+         elif path.is_file():
+             return SourceType.LOCAL_FILE, str(path.absolute())
+         else:
+             raise ValueError(
+                 f"Source '{self.source_str}' is not a valid directory or file"
+             )
tritonparse/sourcemap_utils.py
@@ -0,0 +1,72 @@
+ from typing import Any, Dict, List
+
+
+ def get_file_extension(filename: str) -> str:
+     """
+     Get the file extension from a given filename or return the filename itself if it has no extension.
+
+     Args:
+         filename (str): The filename or file extension.
+
+     Returns:
+         str: The file extension or the filename itself if no extension is present.
+     """
+     # Split the filename by '.' and return the last part if it exists
+     parts = filename.split(".")
+     return parts[-1] if len(parts) > 1 else filename
+
+
+ def _flatten_dict(
+     d: Dict[str, Any], parent_key: str = "", sep: str = "."
+ ) -> Dict[str, Any]:
+     """
+     Flattens a nested dictionary.
+     """
+     items = []
+     for k, v in d.items():
+         new_key = parent_key + sep + k if parent_key else k
+         if isinstance(v, dict):
+             items.extend(_flatten_dict(v, new_key, sep=sep).items())
+         else:
+             items.append((new_key, v))
+     return dict(items)
+
+
+ def _unflatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
+     """
+     Unflattens a dictionary with delimited keys.
+     """
+     result = {}
+     for key, value in d.items():
+         parts = key.split(sep)
+         d_ref = result
+         for part in parts[:-1]:
+             if part not in d_ref:
+                 d_ref[part] = {}
+             d_ref = d_ref[part]
+         d_ref[parts[-1]] = value
+     return result
+
+
+ def _to_ranges(indices: List[int]) -> List[Dict[str, int]]:
+     """
+     Converts a sorted list of indices into a list of continuous ranges.
+     e.g., [0, 1, 2, 5, 6, 8] -> [{'start': 0, 'end': 2}, {'start': 5, 'end': 6}, {'start': 8, 'end': 8}]
+     """
+     if not indices:
+         return []
+
+     indices = sorted(indices)
+     ranges = []
+     start = indices[0]
+     end = indices[0]
+
+     for i in range(1, len(indices)):
+         if indices[i] == end + 1:
+             end = indices[i]
+         else:
+             ranges.append({"start": start, "end": end})
+             start = end = indices[i]
+
+     ranges.append({"start": start, "end": end})
+     return ranges
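A quick round-trip sketch of the helpers above:

    nested = {"a": {"b": 1, "c": {"d": 2}}}
    flat = _flatten_dict(nested)   # {'a.b': 1, 'a.c.d': 2}
    assert _unflatten_dict(flat) == nested
    assert _to_ranges([0, 1, 2, 5, 6, 8]) == [
        {"start": 0, "end": 2},
        {"start": 5, "end": 6},
        {"start": 8, "end": 8},
    ]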