glasstrace 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glasstrace-0.2.0/.github/workflows/ci.yml +33 -0
- glasstrace-0.2.0/.gitignore +223 -0
- glasstrace-0.2.0/LICENSE +21 -0
- glasstrace-0.2.0/PKG-INFO +77 -0
- glasstrace-0.2.0/README.md +50 -0
- glasstrace-0.2.0/examples/basic.py +32 -0
- glasstrace-0.2.0/figures/01_cpu_baseline.png +0 -0
- glasstrace-0.2.0/figures/02_coldstart_gpu.png +0 -0
- glasstrace-0.2.0/figures/03_gpu_warmed.png +0 -0
- glasstrace-0.2.0/figures/04_prefill_decode_split.png +0 -0
- glasstrace-0.2.0/figures/05_smollm2_profile.png +0 -0
- glasstrace-0.2.0/figures/beforeAfterSplit.pdf +0 -0
- glasstrace-0.2.0/figures/benchmark_graphic.png +0 -0
- glasstrace-0.2.0/figures/blog/blog_post.md +0 -0
- glasstrace-0.2.0/glasstrace/__init__.py +7 -0
- glasstrace-0.2.0/glasstrace/hooks.py +190 -0
- glasstrace-0.2.0/glasstrace/profiler.py +61 -0
- glasstrace-0.2.0/glasstrace/report.py +133 -0
- glasstrace-0.2.0/pyproject.toml +48 -0
- glasstrace-0.2.0/scripts/benchmark_plot.py +80 -0
- glasstrace-0.2.0/tests/test_smoke.py +103 -0
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
strategy:
|
|
13
|
+
matrix:
|
|
14
|
+
python-version: ["3.11", "3.12"]
|
|
15
|
+
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v5
|
|
18
|
+
|
|
19
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
20
|
+
uses: actions/setup-python@v6
|
|
21
|
+
with:
|
|
22
|
+
python-version: ${{ matrix.python-version }}
|
|
23
|
+
|
|
24
|
+
- name: Install package
|
|
25
|
+
run: |
|
|
26
|
+
python -m pip install --upgrade pip
|
|
27
|
+
pip install -e ".[dev]"
|
|
28
|
+
|
|
29
|
+
- name: Lint
|
|
30
|
+
run: ruff check glasstrace tests
|
|
31
|
+
|
|
32
|
+
- name: Test
|
|
33
|
+
run: pytest tests/ -v
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[codz]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py.cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
# Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# UV
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# uv.lock
|
|
102
|
+
|
|
103
|
+
# poetry
|
|
104
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
105
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
106
|
+
# commonly ignored for libraries.
|
|
107
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
108
|
+
# poetry.lock
|
|
109
|
+
# poetry.toml
|
|
110
|
+
|
|
111
|
+
# pdm
|
|
112
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
113
|
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
|
114
|
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
|
115
|
+
# pdm.lock
|
|
116
|
+
# pdm.toml
|
|
117
|
+
.pdm-python
|
|
118
|
+
.pdm-build/
|
|
119
|
+
|
|
120
|
+
# pixi
|
|
121
|
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
|
122
|
+
# pixi.lock
|
|
123
|
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
|
124
|
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
|
125
|
+
.pixi
|
|
126
|
+
|
|
127
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
128
|
+
__pypackages__/
|
|
129
|
+
|
|
130
|
+
# Celery stuff
|
|
131
|
+
celerybeat-schedule
|
|
132
|
+
celerybeat.pid
|
|
133
|
+
|
|
134
|
+
# Redis
|
|
135
|
+
*.rdb
|
|
136
|
+
*.aof
|
|
137
|
+
*.pid
|
|
138
|
+
|
|
139
|
+
# RabbitMQ
|
|
140
|
+
mnesia/
|
|
141
|
+
rabbitmq/
|
|
142
|
+
rabbitmq-data/
|
|
143
|
+
|
|
144
|
+
# ActiveMQ
|
|
145
|
+
activemq-data/
|
|
146
|
+
|
|
147
|
+
# SageMath parsed files
|
|
148
|
+
*.sage.py
|
|
149
|
+
|
|
150
|
+
# Environments
|
|
151
|
+
.env
|
|
152
|
+
.envrc
|
|
153
|
+
.venv
|
|
154
|
+
env/
|
|
155
|
+
venv/
|
|
156
|
+
ENV/
|
|
157
|
+
env.bak/
|
|
158
|
+
venv.bak/
|
|
159
|
+
|
|
160
|
+
# Spyder project settings
|
|
161
|
+
.spyderproject
|
|
162
|
+
.spyproject
|
|
163
|
+
|
|
164
|
+
# Rope project settings
|
|
165
|
+
.ropeproject
|
|
166
|
+
|
|
167
|
+
# mkdocs documentation
|
|
168
|
+
/site
|
|
169
|
+
|
|
170
|
+
# mypy
|
|
171
|
+
.mypy_cache/
|
|
172
|
+
.dmypy.json
|
|
173
|
+
dmypy.json
|
|
174
|
+
|
|
175
|
+
# Pyre type checker
|
|
176
|
+
.pyre/
|
|
177
|
+
|
|
178
|
+
# pytype static type analyzer
|
|
179
|
+
.pytype/
|
|
180
|
+
|
|
181
|
+
# Cython debug symbols
|
|
182
|
+
cython_debug/
|
|
183
|
+
|
|
184
|
+
# PyCharm
|
|
185
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
186
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
187
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
188
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
189
|
+
# .idea/
|
|
190
|
+
|
|
191
|
+
# Abstra
|
|
192
|
+
# Abstra is an AI-powered process automation framework.
|
|
193
|
+
# Ignore directories containing user credentials, local state, and settings.
|
|
194
|
+
# Learn more at https://abstra.io/docs
|
|
195
|
+
.abstra/
|
|
196
|
+
|
|
197
|
+
# Visual Studio Code
|
|
198
|
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
|
199
|
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
|
200
|
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
|
201
|
+
# you could uncomment the following to ignore the entire vscode folder
|
|
202
|
+
# .vscode/
|
|
203
|
+
# Temporary file for partial code execution
|
|
204
|
+
tempCodeRunnerFile.py
|
|
205
|
+
|
|
206
|
+
# Ruff stuff:
|
|
207
|
+
.ruff_cache/
|
|
208
|
+
|
|
209
|
+
# PyPI configuration file
|
|
210
|
+
.pypirc
|
|
211
|
+
|
|
212
|
+
# Marimo
|
|
213
|
+
marimo/_static/
|
|
214
|
+
marimo/_lsp/
|
|
215
|
+
__marimo__/
|
|
216
|
+
|
|
217
|
+
# Streamlit
|
|
218
|
+
.streamlit/secrets.toml
|
|
219
|
+
|
|
220
|
+
# Tooling caches
|
|
221
|
+
__pycache__/
|
|
222
|
+
.pytest_cache/
|
|
223
|
+
.ruff_cache/
|
glasstrace-0.2.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Manu Jawahar
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: glasstrace
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Per-layer latency and memory profiler for transformer inference.
|
|
5
|
+
Project-URL: Homepage, https://github.com/manu-j3400/glasstrace
|
|
6
|
+
Project-URL: Repository, https://github.com/manu-j3400/glasstrace
|
|
7
|
+
Project-URL: Issues, https://github.com/manu-j3400/glasstrace/issues
|
|
8
|
+
Author-email: Manu <therealmanujawahar@gmail.com>
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: inference,llm,profiler,pytorch,transformers
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.11
|
|
20
|
+
Requires-Dist: tabulate>=0.9
|
|
21
|
+
Requires-Dist: torch>=2.0
|
|
22
|
+
Requires-Dist: transformers>=4.40
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
25
|
+
Requires-Dist: ruff>=0.5; extra == 'dev'
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# glasstrace
|
|
29
|
+
[![CI](https://github.com/manu-j3400/glasstrace/actions/workflows/ci.yml/badge.svg)](https://github.com/manu-j3400/glasstrace/actions/workflows/ci.yml)
|
|
30
|
+
|
|
31
|
+
> Per-layer latency and memory profiler for transformer inference.
|
|
32
|
+
|
|
33
|
+
`glasstrace` shows you where time actually goes inside your LLM. It decomposes inference cost by layer, operation, and inference phase (prefill vs decode).
|
|
34
|
+
|
|
35
|
+
## Why
|
|
36
|
+
|
|
37
|
+
When you call `model.generate()`, you get a number: total latency. That's not enough to make anything faster. `glasstrace` turns the black box into a measured picture: which layers are slow, where memory pressure lives, and what changes when you tweak batch size or sequence length.
|
|
38
|
+
|
|
39
|
+
## Install
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pip install git+https://github.com/manu-j3400/glasstrace.git
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
PyPI release coming with v1.0.
|
|
46
|
+
|
|
47
|
+
## Usage
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
import glasstrace
|
|
51
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
52
|
+
|
|
53
|
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
|
54
|
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
|
55
|
+
inputs = tokenizer("Hello, world!", return_tensors="pt")
|
|
56
|
+
|
|
57
|
+
with glasstrace.profile(model) as p:
|
|
58
|
+
out = model.generate(**inputs, max_new_tokens=50)
|
|
59
|
+
|
|
60
|
+
print(p.report())
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Status
|
|
64
|
+
|
|
65
|
+
**v0.2.0 — alpha.** Works on Qwen 2.5 0.5B and Llama 3.2 1B with CUDA. Tracks `nn.Linear` and `nn.LayerNorm` modules, splits timing into prefill and decode phases, and (on CUDA) reports memory growth during decode. HTML reports and broader model coverage are still to come.
|
|
66
|
+
|
|
67
|
+
## Roadmap
|
|
68
|
+
|
|
69
|
+
- [x] v0.1 — Per-module CUDA timing, text-table report
|
|
70
|
+
- [x] v0.2 — Prefill vs decode split, memory tracking, HTML report
|
|
71
|
+
- [ ] v0.3 — Multi-model tested coverage, CLI
|
|
72
|
+
- [ ] v0.4 — Comparative analyses across Llama, Qwen, Phi (blog post)
|
|
73
|
+
- [ ] v1.0 — PyPI release, docs, demo video
|
|
74
|
+
|
|
75
|
+
## License
|
|
76
|
+
|
|
77
|
+
MIT
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# glasstrace
|
|
2
|
+
[![CI](https://github.com/manu-j3400/glasstrace/actions/workflows/ci.yml/badge.svg)](https://github.com/manu-j3400/glasstrace/actions/workflows/ci.yml)
|
|
3
|
+
|
|
4
|
+
> Per-layer latency and memory profiler for transformer inference.
|
|
5
|
+
|
|
6
|
+
`glasstrace` shows you where time actually goes inside your LLM. It decomposes inference cost by layer, operation, and inference phase (prefill vs decode).
|
|
7
|
+
|
|
8
|
+
## Why
|
|
9
|
+
|
|
10
|
+
When you call `model.generate()`, you get a number: total latency. That's not enough to make anything faster. `glasstrace` turns the black box into a measured picture: which layers are slow, where memory pressure lives, and what changes when you tweak batch size or sequence length.
|
|
11
|
+
|
|
12
|
+
## Install
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
pip install git+https://github.com/manu-j3400/glasstrace.git
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
PyPI release coming with v1.0.
|
|
19
|
+
|
|
20
|
+
## Usage
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import glasstrace
|
|
24
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
25
|
+
|
|
26
|
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
|
27
|
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
|
28
|
+
inputs = tokenizer("Hello, world!", return_tensors="pt")
|
|
29
|
+
|
|
30
|
+
with glasstrace.profile(model) as p:
|
|
31
|
+
out = model.generate(**inputs, max_new_tokens=50)
|
|
32
|
+
|
|
33
|
+
print(p.report())
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Status
|
|
37
|
+
|
|
38
|
+
**v0.2.0 — alpha.** Works on Qwen 2.5 0.5B and Llama 3.2 1B with CUDA. Tracks `nn.Linear` and `nn.LayerNorm` modules, splits timing into prefill and decode phases, and (on CUDA) reports memory growth during decode. HTML reports and broader model coverage are still to come.
|
|
39
|
+
|
|
40
|
+
## Roadmap
|
|
41
|
+
|
|
42
|
+
- [x] v0.1 — Per-module CUDA timing, text-table report
|
|
43
|
+
- [x] v0.2 — Prefill vs decode split, memory tracking, HTML report
|
|
44
|
+
- [ ] v0.3 — Multi-model tested coverage, CLI
|
|
45
|
+
- [ ] v0.4 — Comparative analyses across Llama, Qwen, Phi (blog post)
|
|
46
|
+
- [ ] v1.0 — PyPI release, docs, demo video
|
|
47
|
+
|
|
48
|
+
## License
|
|
49
|
+
|
|
50
|
+
MIT
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Smallest possible example: profile a model's forward pass."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
5
|
+
|
|
6
|
+
import glasstrace
|
|
7
|
+
|
|
8
|
+
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def main() -> None:
    """Load a small causal LM, warm it up, then profile a short greedy generation."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading {MODEL_NAME} on {device}...")

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    inputs = tokenizer("Hello, world!", return_tensors="pt").to(device)

    def warmup():
        # Short untimed generation; profile() runs this before attaching hooks.
        model.generate(**inputs, max_new_tokens=5, do_sample=False)

    print("Profiling forward pass...")
    # Entering profile() first and no_grad() second matches a nested pair of with-blocks.
    with glasstrace.profile(model, warmup=warmup) as p, torch.no_grad():
        model.generate(**inputs, max_new_tokens=20, do_sample=False)

    print(p.report())


if __name__ == "__main__":
    main()
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
File without changes
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Forward hooks that record per-module timing and shape info during inference."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Phase(str, Enum):
    """Inference phase that a recorded module call is attributed to.

    Because the enum also subclasses ``str``, each member compares equal to its
    plain string value (e.g. ``Phase.DECODE == "decode"``).
    """

    PREFILL = "prefill"  # ModuleTracer._detect_phase: sequence dimension > 1
    DECODE = "decode"  # ModuleTracer._detect_phase: sequence dimension == 1
    UNKNOWN = "unknown"  # no usable input shape to infer the phase from
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ModuleEvent:
    """One recorded forward call through a single hooked module.

    Fields, in constructor order:
        module_path: dotted name from ``named_modules()``,
            e.g. "model.layers.0.self_attn.q_proj".
        module_type: class name of the module, e.g. "Linear".
        input_shape: shape of the first positional tensor input, or None.
        output_shape: shape of the output tensor (or of its first element), or None.
        duration_ms: measured duration of the forward call in milliseconds.
        device: torch device type string, e.g. "cuda", "mps" or "cpu".
        phase: inferred inference phase; ``Phase.UNKNOWN`` when not set.
    """

    module_path: str
    module_type: str
    input_shape: tuple | None
    output_shape: tuple | None
    duration_ms: float
    device: str
    phase: Phase = field(default=Phase.UNKNOWN)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class ModuleTracer:
    """Registers forward hooks on a model and collects per-module timing.

    Timing uses CUDA events when a hooked call's device type is "cuda" and
    ``time.perf_counter`` wall-clock time otherwise. The non-CUDA path does no
    device synchronization, so CPU/MPS numbers are approximate but fine for
    development.

    Attributes:
        target_types: module classes that receive timing hooks.
        events: one ``ModuleEvent`` per completed hooked call, appended by the
            post-hook in completion order.
        memory_samples: dicts with keys "pass", "phase" and "memory_bytes";
            only filled when CUDA is available.
    """

    target_types: tuple[type, ...] = (nn.Linear, nn.LayerNorm)
    events: list[ModuleEvent] = field(default_factory=list)
    memory_samples: list[dict] = field(default_factory=list)
    # Removable handles returned by register_forward_*hook; released by detach().
    _handles: list[Any] = field(default_factory=list)
    # id(module) -> timing state started by the pre-hook, consumed by the post-hook.
    _pending: dict[int, dict[str, Any]] = field(default_factory=dict)
    # Number of forward passes seen by the memory sampler.
    _pass_count: int = 0

    def attach(self, model: nn.Module) -> None:
        """Register timing hooks on every submodule whose type is in ``target_types``.

        Each match gets a forward pre-hook (starts timing) and a forward hook
        (stops timing and appends a ``ModuleEvent``). The CUDA memory sampler is
        attached afterwards. All handles are kept so ``detach`` can remove them.
        """
        for name, module in model.named_modules():
            if not isinstance(module, self.target_types):
                continue
            type_name = type(module).__name__
            pre_handle = module.register_forward_pre_hook(self._make_pre_hook(name, type_name))
            post_handle = module.register_forward_hook(self._make_post_hook(name, type_name))
            self._handles.extend([pre_handle, post_handle])

        self._attach_memory_sampler(model)

    def _attach_memory_sampler(self, model: nn.Module) -> None:
        """Record ``torch.cuda.memory_allocated()`` at the start of each forward pass.

        A pre-hook on the first ``nn.Linear`` yielded by ``model.modules()`` acts
        as the per-pass tick. This assumes that module runs exactly once per model
        forward. Does nothing when CUDA is unavailable or the model contains no
        ``nn.Linear``.
        """
        if not torch.cuda.is_available():
            return

        first_linear = next((m for m in model.modules() if isinstance(m, nn.Linear)), None)
        if first_linear is None:
            return

        def memory_hook(mod: nn.Module, inputs: tuple) -> None:
            mem_bytes = torch.cuda.memory_allocated()
            phase = self._detect_phase(self._shape_of(inputs[0]) if inputs else None)
            self.memory_samples.append(
                {"pass": self._pass_count, "phase": phase.value, "memory_bytes": mem_bytes}
            )
            self._pass_count += 1

        self._handles.append(first_linear.register_forward_pre_hook(memory_hook))

    def detach(self) -> None:
        """Remove all registered hooks and drop any unfinished timing state."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        self._pending.clear()

    def _make_pre_hook(self, module_path: str, module_type: str):
        """Build a forward pre-hook that starts timing one call of a module."""

        def pre_hook(module: nn.Module, inputs: tuple) -> None:
            device = self._device_of(inputs, module)
            timing: dict[str, Any] = {
                "module_path": module_path,
                "module_type": module_type,
                "input_shape": self._shape_of(inputs[0]) if inputs else None,
                "device": device,
            }

            if device == "cuda":
                # Event.record() with no argument records on the current CUDA stream;
                # the end event is recorded by the post-hook.
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                timing["cuda_start"] = start
                timing["cuda_end"] = end
            else:
                timing["wall_start"] = time.perf_counter()

            self._pending[id(module)] = timing

        return pre_hook

    def _make_post_hook(self, module_path: str, module_type: str):
        """Build a forward hook that stops timing and appends a ``ModuleEvent``."""

        def post_hook(module: nn.Module, inputs: tuple, output: Any) -> None:
            timing = self._pending.pop(id(module), None)
            if timing is None:
                # No matching pre-hook state (e.g. detach() ran mid-call).
                return

            output_shape = self._shape_of(output)

            if timing["device"] == "cuda":
                timing["cuda_end"].record()
                # elapsed_time() needs both events completed; synchronize waits for that.
                torch.cuda.synchronize()
                duration_ms = timing["cuda_start"].elapsed_time(timing["cuda_end"])
            else:
                duration_ms = (time.perf_counter() - timing["wall_start"]) * 1000.0

            # Detect phase from input sequence dimension
            phase = self._detect_phase(timing["input_shape"])

            self.events.append(
                ModuleEvent(
                    module_path=timing["module_path"],
                    module_type=timing["module_type"],
                    input_shape=timing["input_shape"],
                    output_shape=output_shape,
                    duration_ms=duration_ms,
                    device=timing["device"],
                    phase=phase,
                )
            )

        return post_hook

    @staticmethod
    def _shape_of(x: Any) -> tuple | None:
        """Return the shape of a tensor.

        For a non-empty list/tuple whose first item is a tensor, return that
        item's shape. Otherwise return None.
        """
        if isinstance(x, torch.Tensor):
            return tuple(x.shape)
        if isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], torch.Tensor):
            return tuple(x[0].shape)
        return None

    @staticmethod
    def _device_of(inputs: tuple, module: nn.Module) -> str:
        """Return the device type of the call.

        Prefers the first positional input's device. Falls back to the module's
        first parameter, then to "cpu".
        """
        if inputs and isinstance(inputs[0], torch.Tensor):
            return inputs[0].device.type
        first_param = next(module.parameters(), None)
        return first_param.device.type if first_param is not None else "cpu"

    @staticmethod
    def _detect_phase(input_shape: tuple | None) -> Phase:
        """Infer prefill vs decode from the sequence dimension of the input.

        For decoder-only transformers: seq_len > 1 means prefill (processing
        the full prompt). seq_len == 1 means decode (one new token per pass).
        """
        if input_shape is None:
            return Phase.UNKNOWN
        # Shape is assumed to be (batch, seq_len, hidden_dim) for most transformer layers.
        # NOTE(review): dim 1 is also read for 2-D inputs, where it is likely the
        # feature size rather than seq_len; such calls end up labelled PREFILL.
        if len(input_shape) >= 2:
            seq_len = input_shape[1]
            if seq_len == 1:
                return Phase.DECODE
            if seq_len > 1:
                return Phase.PREFILL
        return Phase.UNKNOWN
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""The user-facing profile() context manager."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Callable
|
|
8
|
+
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from glasstrace.hooks import ModuleEvent, ModuleTracer
|
|
12
|
+
from glasstrace.report import format_report
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ProfileResult:
    """Container yielded by ``profile()``: recorded module events plus memory samples.

    ``profile()`` builds it from the tracer's own lists, so both keep filling
    while the with-block runs.
    """

    events: list[ModuleEvent] = field(default_factory=list)
    memory_samples: list[dict] = field(default_factory=list)

    def report(self, top_n: int = 20) -> str:
        """Render the recorded data as text via ``format_report``.

        Args:
            top_n: maximum number of module rows listed per phase section.
        """
        return format_report(self.events, self.memory_samples, top_n=top_n)

    def __len__(self) -> int:
        """Return the number of recorded module events."""
        event_count = len(self.events)
        return event_count
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@contextmanager
def profile(model: nn.Module, warmup: Callable[[], None] | None = None):
    """Profile a model's forward passes within a with-block.

    Hooks are attached on entry and always removed on exit, including when the
    block raises or when attaching itself fails partway.

    Args:
        model: the model to instrument.
        warmup: optional zero-arg callable run once (under ``torch.no_grad()``)
            before any hooks are attached, so nothing it does is recorded.
            Strongly recommended on CUDA to avoid cold-start timing artifacts.

    Yields:
        A ``ProfileResult`` whose lists fill in while the block runs.

    Example:
        def warmup():
            model.generate(**inputs, max_new_tokens=5)

        with glasstrace.profile(model, warmup=warmup) as p:
            model.generate(**inputs, max_new_tokens=50)
        print(p.report())
    """
    if warmup is not None:
        import torch  # only torch.nn is imported at module level

        with torch.no_grad():
            warmup()
        if torch.cuda.is_available():
            # Wait for warmup work to finish before hooks are attached.
            torch.cuda.synchronize()

    tracer = ModuleTracer()
    try:
        # attach() sits inside the try so that hooks registered before a failure
        # are still removed by detach().
        tracer.attach(model)
        yield ProfileResult(events=tracer.events, memory_samples=tracer.memory_samples)
    finally:
        tracer.detach()
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""Text-table report generation from ModuleEvent lists."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import Iterable
|
|
7
|
+
|
|
8
|
+
from tabulate import tabulate
|
|
9
|
+
|
|
10
|
+
from glasstrace.hooks import ModuleEvent, Phase
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _aggregate(events: list[ModuleEvent]) -> list[dict]:
|
|
14
|
+
"""Aggregate events by module path: sum times, count calls."""
|
|
15
|
+
agg: dict[str, dict] = defaultdict(
|
|
16
|
+
lambda: {"calls": 0, "total_ms": 0.0, "module_type": "", "device": ""}
|
|
17
|
+
)
|
|
18
|
+
for e in events:
|
|
19
|
+
a = agg[e.module_path]
|
|
20
|
+
a["calls"] += 1
|
|
21
|
+
a["total_ms"] += e.duration_ms
|
|
22
|
+
a["module_type"] = e.module_type
|
|
23
|
+
a["device"] = e.device
|
|
24
|
+
return [
|
|
25
|
+
{"path": path, **vals}
|
|
26
|
+
for path, vals in sorted(
|
|
27
|
+
agg.items(), key=lambda x: x[1]["total_ms"], reverse=True
|
|
28
|
+
)
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _section_table(rows: list[dict], total_ms: float, extra_col: str | None = None) -> str:
|
|
33
|
+
"""Format a list of aggregated module rows as a text table."""
|
|
34
|
+
if not rows:
|
|
35
|
+
return " (no events)\n"
|
|
36
|
+
|
|
37
|
+
table_rows = []
|
|
38
|
+
for r in rows:
|
|
39
|
+
row = {
|
|
40
|
+
"Module": r["path"],
|
|
41
|
+
"Type": r["module_type"],
|
|
42
|
+
"Calls": r["calls"],
|
|
43
|
+
"Total ms": f"{r['total_ms']:.2f}",
|
|
44
|
+
"Per-call ms": f"{r['total_ms'] / r['calls']:.2f}",
|
|
45
|
+
"% of phase": f"{r['total_ms'] / total_ms * 100:.1f}" if total_ms > 0 else "—",
|
|
46
|
+
}
|
|
47
|
+
table_rows.append(row)
|
|
48
|
+
|
|
49
|
+
return tabulate(table_rows, headers="keys", tablefmt="simple") + "\n"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _count_passes(phase_events: list[ModuleEvent]) -> int:
    """Return how many times the first module seen in ``phase_events`` was called.

    Assumes every hooked module fires once per forward pass, so this equals the
    number of forward passes in that phase. Returns 0 for an empty list.
    """
    if not phase_events:
        return 0
    first_path = phase_events[0].module_path
    return sum(1 for e in phase_events if e.module_path == first_path)


def format_report(
    events: Iterable[ModuleEvent],
    memory_samples: list[dict] | None = None,
    top_n: int = 20,
) -> str:
    """Produce a text report with one section per inference phase.

    Args:
        events: recorded ``ModuleEvent`` objects (any iterable; consumed once).
        memory_samples: optional dicts with "phase" and "memory_bytes" keys. They
            are used for the decode memory-growth line.
        top_n: maximum number of modules listed per section.

    Returns:
        The report text: a summary header, then prefill and decode sections, and
        an "unclassified" section when any events had an unknown phase. If
        ``events`` is empty, a short hint is returned instead.
    """
    events = list(events)
    if not events:
        return (
            "glasstrace: no events recorded.\n"
            "(Was the model actually run inside the profile() block?)"
        )

    device = events[0].device

    prefill = [e for e in events if e.phase == Phase.PREFILL]
    decode = [e for e in events if e.phase == Phase.DECODE]
    unknown = [e for e in events if e.phase == Phase.UNKNOWN]

    prefill_ms = sum(e.duration_ms for e in prefill)
    decode_ms = sum(e.duration_ms for e in decode)
    total_ms = sum(e.duration_ms for e in events)

    # Counted explicitly rather than via the path's truthiness: the root module's
    # path is "" (falsy), which must still yield an integer pass count.
    prefill_passes = _count_passes(prefill)
    decode_passes = _count_passes(decode)
    per_token_ms = decode_ms / decode_passes if decode_passes > 0 else 0.0

    # Reported as KV-cache growth: max minus min CUDA memory across decode samples.
    mem_summary = ""
    if memory_samples:
        decode_samples = [s for s in memory_samples if s["phase"] == "decode"]
        if decode_samples:
            min_mem = min(s["memory_bytes"] for s in decode_samples)
            max_mem = max(s["memory_bytes"] for s in decode_samples)
            kv_growth_mb = (max_mem - min_mem) / (1024 ** 2)
            mem_summary = f" kv-cache growth during decode: {kv_growth_mb:.1f} MB\n"

    header = (
        f"\nglasstrace report\n"
        f" modules profiled: {len({e.module_path for e in events})}\n"
        f" total events: {len(events)}\n"
        f" total measured time: {total_ms:.2f} ms\n"
        f" device: {device}\n"
        + mem_summary
    )

    # Prefill section
    prefill_word = "pass" if prefill_passes == 1 else "passes"
    prefill_header = (
        f"\n── prefill ({prefill_passes} {prefill_word}, {prefill_ms:.1f} ms total) "
        + "─" * 40 + "\n"
    )
    prefill_table = _section_table(_aggregate(prefill)[:top_n], prefill_ms)

    # Decode section
    decode_word = "pass" if decode_passes == 1 else "passes"
    decode_header = (
        f"\n── decode ({decode_passes} {decode_word}, {decode_ms:.1f} ms total"
        + (f", {per_token_ms:.1f} ms/token avg" if per_token_ms > 0 else "")
        + ") " + "─" * 20 + "\n"
    )
    decode_table = _section_table(_aggregate(decode)[:top_n], decode_ms)

    # Unknown section (should be empty for standard transformer runs); it gets a
    # table like the other sections so the offending modules are visible.
    unknown_section = ""
    if unknown:
        unknown_ms = sum(e.duration_ms for e in unknown)
        unknown_section = (
            f"\n── unclassified ({len(unknown)} events, {unknown_ms:.1f} ms) ──\n"
            + _section_table(_aggregate(unknown)[:top_n], unknown_ms)
        )

    return (
        header
        + prefill_header
        + prefill_table
        + decode_header
        + decode_table
        + unknown_section
    )
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "glasstrace"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "Per-layer latency and memory profiler for transformer inference."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [{ name = "Manu", email = "therealmanujawahar@gmail.com" }]
|
|
13
|
+
keywords = ["pytorch", "transformers", "profiler", "inference", "llm"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 3 - Alpha",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"License :: OSI Approved :: MIT License",
|
|
18
|
+
"Programming Language :: Python :: 3",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
22
|
+
]
|
|
23
|
+
dependencies = [
|
|
24
|
+
"torch>=2.0",
|
|
25
|
+
"transformers>=4.40",
|
|
26
|
+
"tabulate>=0.9",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[project.optional-dependencies]
|
|
30
|
+
dev = [
|
|
31
|
+
"pytest>=8.0",
|
|
32
|
+
"ruff>=0.5",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[project.urls]
|
|
36
|
+
Homepage = "https://github.com/manu-j3400/glasstrace"
|
|
37
|
+
Repository = "https://github.com/manu-j3400/glasstrace"
|
|
38
|
+
Issues = "https://github.com/manu-j3400/glasstrace/issues"
|
|
39
|
+
|
|
40
|
+
[tool.hatch.build.targets.wheel]
|
|
41
|
+
packages = ["glasstrace"]
|
|
42
|
+
|
|
43
|
+
[tool.ruff]
|
|
44
|
+
line-length = 100
|
|
45
|
+
target-version = "py311"
|
|
46
|
+
|
|
47
|
+
[tool.ruff.lint]
|
|
48
|
+
select = ["E", "F", "I", "W"]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Generate glasstrace benchmark graphic across 4 models."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
|
|
5
|
+
|
|
6
|
+
# Benchmark results, one entry per model. These are parallel lists: index i of
# every list below refers to models[i], so keep them aligned when editing.
models = [
    "Qwen2.5\n0.5B",
    "Qwen2.5\n1.5B",
    "Qwen2.5\n3B",
    "SmolLM2\n1.7B",
]
# The three Qwen 2.5 models share one colour; SmolLM2 is highlighted in orange.
colors = ["#4C9BE8"] * 3 + ["#E8824C"]

ms_per_token = [17.4, 26.6, 43.9, 23.0]  # plotted as "ms / token"
kv_growth = [0.21, 0.49, 0.63, 3.38]  # plotted as "MB (20 tokens)"
lm_head_pct = [11.9, 10.7, 5.9, 3.7]  # plotted as "% of decode time"

# One row of three panels on a dark, GitHub-style background.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(14, 5))
fig.patch.set_facecolor("#0D1117")

# Apply the dark theme to every panel: background, tick colour, spine outline.
for panel in axes:
    panel.set_facecolor("#161B22")
    panel.tick_params(colors="white")
    for edge in panel.spines.values():
        edge.set_edgecolor("#30363D")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def bar_chart(ax, values, title, ylabel, fmt=".1f", labels=None, bar_colors=None):
    """Draw one dark-themed bar panel with a value annotation above each bar.

    Args:
        ax: Matplotlib Axes to draw into.
        values: Bar heights, one per category label.
        title: Panel title.
        ylabel: Y-axis label text.
        fmt: Format spec used for both the y-axis tick labels and the
            per-bar value annotations (e.g. ``".1f"``, ``".2f"``).
        labels: X-axis category labels. Defaults to the module-level
            ``models`` list, so existing calls behave exactly as before.
        bar_colors: Bar fill colours. Defaults to the module-level ``colors``.
    """
    if labels is None:
        labels = models
    if bar_colors is None:
        bar_colors = colors
    # Materialise once: ``values`` is consumed by ax.bar, max() and zip(),
    # which would silently exhaust a generator after the first pass.
    values = list(values)

    bars = ax.bar(labels, values, color=bar_colors, width=0.5, zorder=3)
    ax.set_title(title, color="white", fontsize=12, pad=12, fontweight="bold")
    ax.set_ylabel(ylabel, color="#8B949E", fontsize=10)
    ax.tick_params(axis="x", colors="white", labelsize=9)
    ax.tick_params(axis="y", colors="#8B949E", labelsize=9)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x:{fmt}}"))
    ax.grid(axis="y", color="#30363D", linewidth=0.8, zorder=0)
    ax.set_axisbelow(True)

    # Lift each annotation by 2% of the tallest bar. Computed once rather than
    # per bar; ``default=0`` keeps max() from raising on an empty series.
    label_offset = max(values, default=0) * 0.02
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + label_offset,
            f"{val:{fmt}}",
            ha="center",
            va="bottom",
            color="white",
            fontsize=9,
            fontweight="bold",
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
bar_chart(axes[0], ms_per_token, "Decode Speed", "ms / token")
bar_chart(axes[1], kv_growth, "KV-Cache Growth", "MB (20 tokens)")
bar_chart(axes[2], lm_head_pct, "lm_head Share of Decode", "% of decode time")

# Shared legend: blue bars are the Qwen 2.5 family, orange is SmolLM2.
blue = mpatches.Patch(color="#4C9BE8", label="Qwen 2.5 family")
orange = mpatches.Patch(color="#E8824C", label="SmolLM2 1.7B")
fig.legend(
    handles=[blue, orange],
    loc="lower center",
    ncol=2,
    frameon=False,
    labelcolor="white",
    fontsize=10,
    bbox_to_anchor=(0.5, -0.02),
)

fig.suptitle(
    "glasstrace benchmark — 4 models on T4 GPU (fp16, 20 decode tokens)",
    color="white",
    fontsize=13,
    fontweight="bold",
    y=0.98,
)

# Keep the bottom and top 5% of the figure free for the legend and suptitle.
plt.tight_layout(rect=[0, 0.05, 1, 0.95])

# Resolve the output against the repository root (this script lives in
# scripts/, the figure belongs in figures/) rather than the current working
# directory, and create figures/ if needed. Otherwise, running the script from
# anywhere but the repo root makes savefig raise FileNotFoundError.
repo_root = Path(__file__).resolve().parent.parent
output_path = repo_root / "figures" / "benchmark_graphic.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(
    output_path,
    dpi=180,
    bbox_inches="tight",
    pad_inches=0.3,
    facecolor="#0D1117",
)
plt.show()
print(f"Saved to {output_path}")
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Smoke tests — verify the package imports and the profiler runs end-to-end."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
import glasstrace
|
|
7
|
+
from glasstrace.hooks import Phase
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_import():
    """The package imports and exposes a callable ``profile`` entry point."""
    # ``glasstrace is not None`` is always true once the import has succeeded,
    # so check what the import alone does not guarantee: that ``profile``
    # exists and is callable (the other tests call it as ``profile(model)``).
    assert hasattr(glasstrace, "profile")
    assert callable(glasstrace.profile)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_version():
    """``glasstrace.__version__`` is a non-empty string."""
    version = glasstrace.__version__
    assert isinstance(version, str)
    assert version, "expected a non-empty version string"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_profile_tiny_model():
    """Each submodule run inside the profiling context produces timed events."""

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 16)
            self.norm = nn.LayerNorm(16)
            self.fc2 = nn.Linear(16, 4)

        def forward(self, x):
            hidden = self.fc1(x)
            hidden = self.norm(hidden)
            return self.fc2(hidden)

    model = Tiny()
    inputs = torch.randn(2, 8)

    with glasstrace.profile(model) as p:
        model(inputs)

    assert len(p) >= 3
    recorded_paths = {event.module_path for event in p.events}
    # Every named child of Tiny must show up among the recorded module paths.
    assert {"fc1", "norm", "fc2"} <= recorded_paths
    assert all(event.duration_ms >= 0 for event in p.events)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_report_format():
    """Rendering a small profile yields a text report with the expected headings."""
    model = nn.Sequential(
        nn.Linear(4, 8),
        nn.LayerNorm(8),
        nn.Linear(8, 2),
    )
    sample = torch.randn(1, 4)

    with glasstrace.profile(model) as p:
        model(sample)

    rendered = p.report()
    assert isinstance(rendered, str)
    for needle in ("glasstrace report", "Module"):
        assert needle in rendered
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_phase_detection():
    """Shapes are mapped to a phase according to their sequence dimension."""
    from glasstrace.hooks import ModuleTracer

    tracer = ModuleTracer()
    expectations = [
        ((2, 10, 64), Phase.PREFILL),  # seq_len > 1 -> prefill
        ((2, 1, 64), Phase.DECODE),  # seq_len == 1 -> decode
        (None, Phase.UNKNOWN),  # no shape -> unknown
    ]
    for shape, phase in expectations:
        assert tracer._detect_phase(shape) == phase
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_prefill_decode_split():
    """A multi-token pass is tagged prefill; single-token passes are tagged decode."""

    class SeqModel(nn.Module):
        """Minimal sequence model: one linear projection over the last dimension."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x):
            return self.proj(x)

    model = SeqModel()
    decode_steps = 3

    with glasstrace.profile(model) as p:
        # Prefill: one pass over a 5-token prompt (batch=1, seq=5).
        model(torch.randn(1, 5, 8))
        # Decode: single-token steps (batch=1, seq=1 each).
        for _ in range(decode_steps):
            model(torch.randn(1, 1, 8))

    phases = [event.phase for event in p.events]
    assert phases.count(Phase.PREFILL) >= 1, "Expected at least one prefill event"
    assert phases.count(Phase.DECODE) >= decode_steps, "Expected at least three decode events"
|