batch-probe 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batch_probe-0.1.0/LICENSE +21 -0
- batch_probe-0.1.0/PKG-INFO +194 -0
- batch_probe-0.1.0/README.md +165 -0
- batch_probe-0.1.0/pyproject.toml +51 -0
- batch_probe-0.1.0/setup.cfg +4 -0
- batch_probe-0.1.0/src/batch_probe.egg-info/PKG-INFO +194 -0
- batch_probe-0.1.0/src/batch_probe.egg-info/SOURCES.txt +15 -0
- batch_probe-0.1.0/src/batch_probe.egg-info/dependency_links.txt +1 -0
- batch_probe-0.1.0/src/batch_probe.egg-info/requires.txt +6 -0
- batch_probe-0.1.0/src/batch_probe.egg-info/top_level.txt +1 -0
- batch_probe-0.1.0/src/torch_probe/__init__.py +10 -0
- batch_probe-0.1.0/src/torch_probe/_cache.py +62 -0
- batch_probe-0.1.0/src/torch_probe/_cleanup.py +18 -0
- batch_probe-0.1.0/src/torch_probe/_probe.py +173 -0
- batch_probe-0.1.0/src/torch_probe/py.typed +0 -0
- batch_probe-0.1.0/tests/test_cache.py +70 -0
- batch_probe-0.1.0/tests/test_probe.py +230 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Andrew H. Bond
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: batch-probe
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Find the maximum batch size that fits in GPU memory. Binary search with OOM recovery.
|
|
5
|
+
Author: Andrew H. Bond
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/ahb-sjsu/batch-probe
|
|
8
|
+
Project-URL: Bug Tracker, https://github.com/ahb-sjsu/batch-probe/issues
|
|
9
|
+
Keywords: pytorch,gpu,memory,batch-size,oom,cuda
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: torch>=1.13.0
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
27
|
+
Requires-Dist: ruff>=0.4.0; extra == "dev"
|
|
28
|
+
Dynamic: license-file
|
|
29
|
+
|
|
30
|
+
# batch-probe
|
|
31
|
+
|
|
32
|
+
Find the maximum batch size that fits in GPU memory.
|
|
33
|
+
|
|
34
|
+
Binary search with OOM recovery, configurable safety headroom, no framework required.
|
|
35
|
+
|
|
36
|
+
## The Problem
|
|
37
|
+
|
|
38
|
+
Every ML practitioner has done this:
|
|
39
|
+
|
|
40
|
+
```
|
|
41
|
+
batch_size = 64 # OOM
|
|
42
|
+
batch_size = 32 # OOM
|
|
43
|
+
batch_size = 16 # OOM
|
|
44
|
+
batch_size = 8 # works... but am I leaving GPU memory on the table?
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
`batch-probe` automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.
|
|
48
|
+
|
|
49
|
+
## Install
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install batch-probe
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Quick Start
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
from torch_probe import probe_batch_size
|
|
59
|
+
|
|
60
|
+
batch_size = probe_batch_size(
|
|
61
|
+
model,
|
|
62
|
+
lambda bs: {
|
|
63
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
64
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
65
|
+
},
|
|
66
|
+
)
|
|
67
|
+
# torch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
That's it. Three lines. Works with any `nn.Module`.
|
|
71
|
+
|
|
72
|
+
## Usage
|
|
73
|
+
|
|
74
|
+
### Encoder models (BERT, RoBERTa, etc.)
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
batch_size = probe_batch_size(
|
|
78
|
+
model,
|
|
79
|
+
lambda bs: {
|
|
80
|
+
"input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
|
|
81
|
+
"attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
|
|
82
|
+
},
|
|
83
|
+
mode="train",
|
|
84
|
+
)
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
### Seq2seq models (T5, BART, etc.)
|
|
88
|
+
|
|
89
|
+
```python
|
|
90
|
+
batch_size = probe_batch_size(
|
|
91
|
+
model,
|
|
92
|
+
lambda bs: {
|
|
93
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
94
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
95
|
+
"labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
96
|
+
},
|
|
97
|
+
mode="train",
|
|
98
|
+
)
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
### Vision models
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
batch_size = probe_batch_size(
|
|
105
|
+
model,
|
|
106
|
+
lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
|
|
107
|
+
mode="infer",
|
|
108
|
+
)
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### Inference-only probing
|
|
112
|
+
|
|
113
|
+
Inference uses ~2-4x less memory than training (no gradients stored):
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
infer_batch = probe_batch_size(model, input_fn, mode="infer")
|
|
117
|
+
train_batch = probe_batch_size(model, input_fn, mode="train")
|
|
118
|
+
# infer_batch >> train_batch
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
### Custom headroom
|
|
122
|
+
|
|
123
|
+
Default is 20% safety margin. Adjust for your risk tolerance:
|
|
124
|
+
|
|
125
|
+
```python
|
|
126
|
+
# Conservative (40% headroom) — for long training runs
|
|
127
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.4)
|
|
128
|
+
|
|
129
|
+
# Aggressive (5% headroom) — squeeze every last sample
|
|
130
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
### Caching
|
|
134
|
+
|
|
135
|
+
Use `cached_probe` to avoid re-probing the same model:
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
from torch_probe import cached_probe, clear_cache
|
|
139
|
+
|
|
140
|
+
batch_size = cached_probe(model, input_fn, mode="train") # probes
|
|
141
|
+
batch_size = cached_probe(model, input_fn, mode="train") # cache hit
|
|
142
|
+
|
|
143
|
+
clear_cache() # reset if model changed
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
## How It Works
|
|
147
|
+
|
|
148
|
+
1. Binary search between `low` (default 1) and `high` (default 4096)
|
|
149
|
+
2. At each midpoint, create dummy tensors via your `input_fn`
|
|
150
|
+
3. Run a forward pass (+ backward pass in train mode)
|
|
151
|
+
4. If OOM: upper bound ← midpoint − 1, clean GPU memory
|
|
152
|
+
5. If success: lower bound ← midpoint + 1
|
|
153
|
+
6. Return `int(max_successful × (1 − headroom))`
|
|
154
|
+
|
|
155
|
+
The OOM recovery uses `gc.collect()` + `torch.cuda.empty_cache()` + `torch.cuda.synchronize()` to fully reclaim memory between iterations.
|
|
156
|
+
|
|
157
|
+
## vs. Alternatives
|
|
158
|
+
|
|
159
|
+
| Feature | batch-probe | Lightning BatchSizeFinder | HF `auto_find_batch_size` |
|
|
160
|
+
|---|---|---|---|
|
|
161
|
+
| Works with raw PyTorch | Yes | No (needs LightningModule) | No (needs HF Trainer) |
|
|
162
|
+
| Algorithm | Binary search | Power-of-2 scaling | Halve on OOM |
|
|
163
|
+
| Configurable headroom | Yes | No | No |
|
|
164
|
+
| Train + infer modes | Yes | Train only | Train only |
|
|
165
|
+
| Dependencies | torch only | pytorch-lightning | accelerate |
|
|
166
|
+
|
|
167
|
+
## API Reference
|
|
168
|
+
|
|
169
|
+
### `probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)`
|
|
170
|
+
|
|
171
|
+
Find the maximum safe batch size.
|
|
172
|
+
|
|
173
|
+
- **model** (`nn.Module`): Your model, already on the target device.
|
|
174
|
+
- **input_fn** (`Callable[[int], dict[str, Tensor]]`): Takes batch size, returns dict of tensors for `model(**inputs)`.
|
|
175
|
+
- **mode** (`"train"` | `"infer"`): Train mode runs forward + backward. Default: `"train"`.
|
|
176
|
+
- **low** (`int`): Minimum batch size. Default: `1`.
|
|
177
|
+
- **high** (`int`): Upper bound for search. Default: `4096`.
|
|
178
|
+
- **headroom** (`float`): Safety margin. Default: `0.2` (20%).
|
|
179
|
+
- **device** (`str | torch.device | None`): Override device. Default: model's device.
|
|
180
|
+
- **verbose** (`bool`): Print progress. Default: `True`.
|
|
181
|
+
|
|
182
|
+
Returns: `int` — safe batch size.
|
|
183
|
+
|
|
184
|
+
### `cached_probe(model, input_fn, *, mode, **kwargs)`
|
|
185
|
+
|
|
186
|
+
Same as `probe_batch_size` but caches results keyed on model class, param count, input shapes, and mode.
|
|
187
|
+
|
|
188
|
+
### `clear_cache()`
|
|
189
|
+
|
|
190
|
+
Clear all cached probe results.
|
|
191
|
+
|
|
192
|
+
## License
|
|
193
|
+
|
|
194
|
+
MIT
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# batch-probe
|
|
2
|
+
|
|
3
|
+
Find the maximum batch size that fits in GPU memory.
|
|
4
|
+
|
|
5
|
+
Binary search with OOM recovery, configurable safety headroom, no framework required.
|
|
6
|
+
|
|
7
|
+
## The Problem
|
|
8
|
+
|
|
9
|
+
Every ML practitioner has done this:
|
|
10
|
+
|
|
11
|
+
```
|
|
12
|
+
batch_size = 64 # OOM
|
|
13
|
+
batch_size = 32 # OOM
|
|
14
|
+
batch_size = 16 # OOM
|
|
15
|
+
batch_size = 8 # works... but am I leaving GPU memory on the table?
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
`batch-probe` automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.
|
|
19
|
+
|
|
20
|
+
## Install
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
pip install batch-probe
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## Quick Start
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
from torch_probe import probe_batch_size
|
|
30
|
+
|
|
31
|
+
batch_size = probe_batch_size(
|
|
32
|
+
model,
|
|
33
|
+
lambda bs: {
|
|
34
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
35
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
36
|
+
},
|
|
37
|
+
)
|
|
38
|
+
# torch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
That's it. Three lines. Works with any `nn.Module`.
|
|
42
|
+
|
|
43
|
+
## Usage
|
|
44
|
+
|
|
45
|
+
### Encoder models (BERT, RoBERTa, etc.)
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
batch_size = probe_batch_size(
|
|
49
|
+
model,
|
|
50
|
+
lambda bs: {
|
|
51
|
+
"input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
|
|
52
|
+
"attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
|
|
53
|
+
},
|
|
54
|
+
mode="train",
|
|
55
|
+
)
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
### Seq2seq models (T5, BART, etc.)
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
batch_size = probe_batch_size(
|
|
62
|
+
model,
|
|
63
|
+
lambda bs: {
|
|
64
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
65
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
66
|
+
"labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
67
|
+
},
|
|
68
|
+
mode="train",
|
|
69
|
+
)
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
### Vision models
|
|
73
|
+
|
|
74
|
+
```python
|
|
75
|
+
batch_size = probe_batch_size(
|
|
76
|
+
model,
|
|
77
|
+
lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
|
|
78
|
+
mode="infer",
|
|
79
|
+
)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
### Inference-only probing
|
|
83
|
+
|
|
84
|
+
Inference uses ~2-4x less memory than training (no gradients stored):
|
|
85
|
+
|
|
86
|
+
```python
|
|
87
|
+
infer_batch = probe_batch_size(model, input_fn, mode="infer")
|
|
88
|
+
train_batch = probe_batch_size(model, input_fn, mode="train")
|
|
89
|
+
# infer_batch >> train_batch
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
### Custom headroom
|
|
93
|
+
|
|
94
|
+
Default is 20% safety margin. Adjust for your risk tolerance:
|
|
95
|
+
|
|
96
|
+
```python
|
|
97
|
+
# Conservative (40% headroom) — for long training runs
|
|
98
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.4)
|
|
99
|
+
|
|
100
|
+
# Aggressive (5% headroom) — squeeze every last sample
|
|
101
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Caching
|
|
105
|
+
|
|
106
|
+
Use `cached_probe` to avoid re-probing the same model:
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
from torch_probe import cached_probe, clear_cache
|
|
110
|
+
|
|
111
|
+
batch_size = cached_probe(model, input_fn, mode="train") # probes
|
|
112
|
+
batch_size = cached_probe(model, input_fn, mode="train") # cache hit
|
|
113
|
+
|
|
114
|
+
clear_cache() # reset if model changed
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
## How It Works
|
|
118
|
+
|
|
119
|
+
1. Binary search between `low` (default 1) and `high` (default 4096)
|
|
120
|
+
2. At each midpoint, create dummy tensors via your `input_fn`
|
|
121
|
+
3. Run a forward pass (+ backward pass in train mode)
|
|
122
|
+
4. If OOM: upper bound ← midpoint − 1, clean GPU memory
|
|
123
|
+
5. If success: lower bound ← midpoint + 1
|
|
124
|
+
6. Return `int(max_successful × (1 − headroom))`
|
|
125
|
+
|
|
126
|
+
The OOM recovery uses `gc.collect()` + `torch.cuda.empty_cache()` + `torch.cuda.synchronize()` to fully reclaim memory between iterations.
|
|
127
|
+
|
|
128
|
+
## vs. Alternatives
|
|
129
|
+
|
|
130
|
+
| Feature | batch-probe | Lightning BatchSizeFinder | HF `auto_find_batch_size` |
|
|
131
|
+
|---|---|---|---|
|
|
132
|
+
| Works with raw PyTorch | Yes | No (needs LightningModule) | No (needs HF Trainer) |
|
|
133
|
+
| Algorithm | Binary search | Power-of-2 scaling | Halve on OOM |
|
|
134
|
+
| Configurable headroom | Yes | No | No |
|
|
135
|
+
| Train + infer modes | Yes | Train only | Train only |
|
|
136
|
+
| Dependencies | torch only | pytorch-lightning | accelerate |
|
|
137
|
+
|
|
138
|
+
## API Reference
|
|
139
|
+
|
|
140
|
+
### `probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)`
|
|
141
|
+
|
|
142
|
+
Find the maximum safe batch size.
|
|
143
|
+
|
|
144
|
+
- **model** (`nn.Module`): Your model, already on the target device.
|
|
145
|
+
- **input_fn** (`Callable[[int], dict[str, Tensor]]`): Takes batch size, returns dict of tensors for `model(**inputs)`.
|
|
146
|
+
- **mode** (`"train"` | `"infer"`): Train mode runs forward + backward. Default: `"train"`.
|
|
147
|
+
- **low** (`int`): Minimum batch size. Default: `1`.
|
|
148
|
+
- **high** (`int`): Upper bound for search. Default: `4096`.
|
|
149
|
+
- **headroom** (`float`): Safety margin. Default: `0.2` (20%).
|
|
150
|
+
- **device** (`str | torch.device | None`): Override device. Default: model's device.
|
|
151
|
+
- **verbose** (`bool`): Print progress. Default: `True`.
|
|
152
|
+
|
|
153
|
+
Returns: `int` — safe batch size.
|
|
154
|
+
|
|
155
|
+
### `cached_probe(model, input_fn, *, mode, **kwargs)`
|
|
156
|
+
|
|
157
|
+
Same as `probe_batch_size` but caches results keyed on model class, param count, input shapes, and mode.
|
|
158
|
+
|
|
159
|
+
### `clear_cache()`
|
|
160
|
+
|
|
161
|
+
Clear all cached probe results.
|
|
162
|
+
|
|
163
|
+
## License
|
|
164
|
+
|
|
165
|
+
MIT
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "batch-probe"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Find the maximum batch size that fits in GPU memory. Binary search with OOM recovery."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Andrew H. Bond"},
|
|
14
|
+
]
|
|
15
|
+
keywords = ["pytorch", "gpu", "memory", "batch-size", "oom", "cuda"]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 4 - Beta",
|
|
18
|
+
"Intended Audience :: Developers",
|
|
19
|
+
"Intended Audience :: Science/Research",
|
|
20
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
21
|
+
"Programming Language :: Python :: 3",
|
|
22
|
+
"Programming Language :: Python :: 3.9",
|
|
23
|
+
"Programming Language :: Python :: 3.10",
|
|
24
|
+
"Programming Language :: Python :: 3.11",
|
|
25
|
+
"Programming Language :: Python :: 3.12",
|
|
26
|
+
"Programming Language :: Python :: 3.13",
|
|
27
|
+
]
|
|
28
|
+
dependencies = [
|
|
29
|
+
"torch>=1.13.0",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[project.optional-dependencies]
|
|
33
|
+
dev = [
|
|
34
|
+
"pytest>=7.0.0",
|
|
35
|
+
"pytest-cov>=4.0.0",
|
|
36
|
+
"ruff>=0.4.0",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
[project.urls]
|
|
40
|
+
"Homepage" = "https://github.com/ahb-sjsu/batch-probe"
|
|
41
|
+
"Bug Tracker" = "https://github.com/ahb-sjsu/batch-probe/issues"
|
|
42
|
+
|
|
43
|
+
[tool.setuptools.packages.find]
|
|
44
|
+
where = ["src"]
|
|
45
|
+
|
|
46
|
+
[tool.pytest.ini_options]
|
|
47
|
+
testpaths = ["tests"]
|
|
48
|
+
|
|
49
|
+
[tool.ruff]
|
|
50
|
+
line-length = 100
|
|
51
|
+
target-version = "py39"
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: batch-probe
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Find the maximum batch size that fits in GPU memory. Binary search with OOM recovery.
|
|
5
|
+
Author: Andrew H. Bond
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/ahb-sjsu/batch-probe
|
|
8
|
+
Project-URL: Bug Tracker, https://github.com/ahb-sjsu/batch-probe/issues
|
|
9
|
+
Keywords: pytorch,gpu,memory,batch-size,oom,cuda
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: torch>=1.13.0
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
27
|
+
Requires-Dist: ruff>=0.4.0; extra == "dev"
|
|
28
|
+
Dynamic: license-file
|
|
29
|
+
|
|
30
|
+
# batch-probe
|
|
31
|
+
|
|
32
|
+
Find the maximum batch size that fits in GPU memory.
|
|
33
|
+
|
|
34
|
+
Binary search with OOM recovery, configurable safety headroom, no framework required.
|
|
35
|
+
|
|
36
|
+
## The Problem
|
|
37
|
+
|
|
38
|
+
Every ML practitioner has done this:
|
|
39
|
+
|
|
40
|
+
```
|
|
41
|
+
batch_size = 64 # OOM
|
|
42
|
+
batch_size = 32 # OOM
|
|
43
|
+
batch_size = 16 # OOM
|
|
44
|
+
batch_size = 8 # works... but am I leaving GPU memory on the table?
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
`batch-probe` automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.
|
|
48
|
+
|
|
49
|
+
## Install
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install batch-probe
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Quick Start
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
from torch_probe import probe_batch_size
|
|
59
|
+
|
|
60
|
+
batch_size = probe_batch_size(
|
|
61
|
+
model,
|
|
62
|
+
lambda bs: {
|
|
63
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
64
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
65
|
+
},
|
|
66
|
+
)
|
|
67
|
+
# torch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
That's it. Three lines. Works with any `nn.Module`.
|
|
71
|
+
|
|
72
|
+
## Usage
|
|
73
|
+
|
|
74
|
+
### Encoder models (BERT, RoBERTa, etc.)
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
batch_size = probe_batch_size(
|
|
78
|
+
model,
|
|
79
|
+
lambda bs: {
|
|
80
|
+
"input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
|
|
81
|
+
"attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
|
|
82
|
+
},
|
|
83
|
+
mode="train",
|
|
84
|
+
)
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
### Seq2seq models (T5, BART, etc.)
|
|
88
|
+
|
|
89
|
+
```python
|
|
90
|
+
batch_size = probe_batch_size(
|
|
91
|
+
model,
|
|
92
|
+
lambda bs: {
|
|
93
|
+
"input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
94
|
+
"attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
|
|
95
|
+
"labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
|
|
96
|
+
},
|
|
97
|
+
mode="train",
|
|
98
|
+
)
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
### Vision models
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
batch_size = probe_batch_size(
|
|
105
|
+
model,
|
|
106
|
+
lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
|
|
107
|
+
mode="infer",
|
|
108
|
+
)
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### Inference-only probing
|
|
112
|
+
|
|
113
|
+
Inference uses ~2-4x less memory than training (no gradients stored):
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
infer_batch = probe_batch_size(model, input_fn, mode="infer")
|
|
117
|
+
train_batch = probe_batch_size(model, input_fn, mode="train")
|
|
118
|
+
# infer_batch >> train_batch
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
### Custom headroom
|
|
122
|
+
|
|
123
|
+
Default is 20% safety margin. Adjust for your risk tolerance:
|
|
124
|
+
|
|
125
|
+
```python
|
|
126
|
+
# Conservative (40% headroom) — for long training runs
|
|
127
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.4)
|
|
128
|
+
|
|
129
|
+
# Aggressive (5% headroom) — squeeze every last sample
|
|
130
|
+
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
### Caching
|
|
134
|
+
|
|
135
|
+
Use `cached_probe` to avoid re-probing the same model:
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
from torch_probe import cached_probe, clear_cache
|
|
139
|
+
|
|
140
|
+
batch_size = cached_probe(model, input_fn, mode="train") # probes
|
|
141
|
+
batch_size = cached_probe(model, input_fn, mode="train") # cache hit
|
|
142
|
+
|
|
143
|
+
clear_cache() # reset if model changed
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
## How It Works
|
|
147
|
+
|
|
148
|
+
1. Binary search between `low` (default 1) and `high` (default 4096)
|
|
149
|
+
2. At each midpoint, create dummy tensors via your `input_fn`
|
|
150
|
+
3. Run a forward pass (+ backward pass in train mode)
|
|
151
|
+
4. If OOM: upper bound ← midpoint − 1, clean GPU memory
|
|
152
|
+
5. If success: lower bound ← midpoint + 1
|
|
153
|
+
6. Return `int(max_successful × (1 − headroom))`
|
|
154
|
+
|
|
155
|
+
The OOM recovery uses `gc.collect()` + `torch.cuda.empty_cache()` + `torch.cuda.synchronize()` to fully reclaim memory between iterations.
|
|
156
|
+
|
|
157
|
+
## vs. Alternatives
|
|
158
|
+
|
|
159
|
+
| Feature | batch-probe | Lightning BatchSizeFinder | HF `auto_find_batch_size` |
|
|
160
|
+
|---|---|---|---|
|
|
161
|
+
| Works with raw PyTorch | Yes | No (needs LightningModule) | No (needs HF Trainer) |
|
|
162
|
+
| Algorithm | Binary search | Power-of-2 scaling | Halve on OOM |
|
|
163
|
+
| Configurable headroom | Yes | No | No |
|
|
164
|
+
| Train + infer modes | Yes | Train only | Train only |
|
|
165
|
+
| Dependencies | torch only | pytorch-lightning | accelerate |
|
|
166
|
+
|
|
167
|
+
## API Reference
|
|
168
|
+
|
|
169
|
+
### `probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)`
|
|
170
|
+
|
|
171
|
+
Find the maximum safe batch size.
|
|
172
|
+
|
|
173
|
+
- **model** (`nn.Module`): Your model, already on the target device.
|
|
174
|
+
- **input_fn** (`Callable[[int], dict[str, Tensor]]`): Takes batch size, returns dict of tensors for `model(**inputs)`.
|
|
175
|
+
- **mode** (`"train"` | `"infer"`): Train mode runs forward + backward. Default: `"train"`.
|
|
176
|
+
- **low** (`int`): Minimum batch size. Default: `1`.
|
|
177
|
+
- **high** (`int`): Upper bound for search. Default: `4096`.
|
|
178
|
+
- **headroom** (`float`): Safety margin. Default: `0.2` (20%).
|
|
179
|
+
- **device** (`str | torch.device | None`): Override device. Default: model's device.
|
|
180
|
+
- **verbose** (`bool`): Print progress. Default: `True`.
|
|
181
|
+
|
|
182
|
+
Returns: `int` — safe batch size.
|
|
183
|
+
|
|
184
|
+
### `cached_probe(model, input_fn, *, mode, **kwargs)`
|
|
185
|
+
|
|
186
|
+
Same as `probe_batch_size` but caches results keyed on model class, param count, input shapes, and mode.
|
|
187
|
+
|
|
188
|
+
### `clear_cache()`
|
|
189
|
+
|
|
190
|
+
Clear all cached probe results.
|
|
191
|
+
|
|
192
|
+
## License
|
|
193
|
+
|
|
194
|
+
MIT
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/batch_probe.egg-info/PKG-INFO
|
|
5
|
+
src/batch_probe.egg-info/SOURCES.txt
|
|
6
|
+
src/batch_probe.egg-info/dependency_links.txt
|
|
7
|
+
src/batch_probe.egg-info/requires.txt
|
|
8
|
+
src/batch_probe.egg-info/top_level.txt
|
|
9
|
+
src/torch_probe/__init__.py
|
|
10
|
+
src/torch_probe/_cache.py
|
|
11
|
+
src/torch_probe/_cleanup.py
|
|
12
|
+
src/torch_probe/_probe.py
|
|
13
|
+
src/torch_probe/py.typed
|
|
14
|
+
tests/test_cache.py
|
|
15
|
+
tests/test_probe.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch_probe
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""torch-probe: Find the maximum batch size that fits in GPU memory."""
|
|
5
|
+
|
|
6
|
+
from torch_probe._cache import cached_probe, clear_cache
|
|
7
|
+
from torch_probe._probe import probe_batch_size
|
|
8
|
+
|
|
9
|
+
__all__ = ["probe_batch_size", "cached_probe", "clear_cache"]
|
|
10
|
+
__version__ = "0.1.0"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""In-memory cache for probe results."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, Callable, Dict, Literal, Optional, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
|
|
13
|
+
from torch_probe._probe import probe_batch_size
|
|
14
|
+
|
|
15
|
+
_cache: dict[str, int] = {}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _make_key(
|
|
19
|
+
model: nn.Module,
|
|
20
|
+
input_fn: Callable[[int], Dict[str, torch.Tensor]],
|
|
21
|
+
mode: str,
|
|
22
|
+
) -> str:
|
|
23
|
+
"""Build a cache key from model identity and input shape."""
|
|
24
|
+
# Model class + param count gives a stable identity
|
|
25
|
+
model_id = f"{model.__class__.__name__}_{sum(p.numel() for p in model.parameters())}"
|
|
26
|
+
|
|
27
|
+
# Probe input shapes at batch=1
|
|
28
|
+
try:
|
|
29
|
+
sample = input_fn(1)
|
|
30
|
+
shapes = "_".join(
|
|
31
|
+
f"{k}:{tuple(v.shape)}:{v.dtype}" for k, v in sorted(sample.items())
|
|
32
|
+
)
|
|
33
|
+
# Clean up sample tensors
|
|
34
|
+
del sample
|
|
35
|
+
except Exception:
|
|
36
|
+
shapes = "unknown"
|
|
37
|
+
|
|
38
|
+
return f"{model_id}__{mode}__{shapes}"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def cached_probe(
    model: nn.Module,
    input_fn: Callable[[int], Dict[str, torch.Tensor]],
    *,
    mode: Literal["train", "infer"] = "train",
    **kwargs: Any,
) -> int:
    """Like :func:`probe_batch_size` but caches results.

    Same arguments as :func:`probe_batch_size`.  Results are keyed on the
    model class, parameter count, input shapes at batch=1, mode, and any
    extra probe options (``headroom``, ``low``, ``high``, ...), so calls
    that differ only in those options are probed independently instead of
    being served a stale cached value.

    Returns:
        The safe batch size, probing at most once per distinct key.
    """
    key = _make_key(model, input_fn, mode)
    # Fold result-affecting options into the key: a probe with
    # headroom=0.05 must not reuse a result cached for headroom=0.4.
    # `verbose` only controls printing, so it is excluded.
    opts = {k: v for k, v in kwargs.items() if k != "verbose"}
    if opts:
        key += "__" + "_".join(f"{k}={opts[k]!r}" for k in sorted(opts))
    if key not in _cache:
        _cache[key] = probe_batch_size(model, input_fn, mode=mode, **kwargs)
    return _cache[key]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def clear_cache() -> None:
    """Forget every stored probe result; subsequent calls re-probe."""
    # Empty the shared module-level cache in place so any aliases see it.
    while _cache:
        _cache.popitem()
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""GPU memory cleanup utilities."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def gpu_cleanup() -> None:
    """Aggressively free GPU memory after an OOM or between probe iterations."""
    # Drop unreferenced Python objects first so their CUDA storages become
    # reclaimable by the caching allocator.
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing more to do on a CPU-only host.
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Core binary-search GPU memory probe."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, Callable, Dict, Literal, Optional, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
|
|
13
|
+
from torch_probe._cleanup import gpu_cleanup
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _extract_loss(outputs: Any) -> torch.Tensor:
|
|
17
|
+
"""Extract a scalar loss from model outputs for the backward pass.
|
|
18
|
+
|
|
19
|
+
Handles:
|
|
20
|
+
- HuggingFace ModelOutput / dataclass with .loss attribute
|
|
21
|
+
- dict with "loss" key
|
|
22
|
+
- plain Tensor
|
|
23
|
+
- tuple (uses first element)
|
|
24
|
+
- dict without "loss" (uses first value)
|
|
25
|
+
"""
|
|
26
|
+
# .loss attribute (HuggingFace ModelOutput, dataclasses)
|
|
27
|
+
if hasattr(outputs, "loss") and outputs.loss is not None:
|
|
28
|
+
return outputs.loss
|
|
29
|
+
|
|
30
|
+
# dict with "loss" key
|
|
31
|
+
if isinstance(outputs, dict):
|
|
32
|
+
if "loss" in outputs:
|
|
33
|
+
return outputs["loss"]
|
|
34
|
+
# Fall back to first tensor value
|
|
35
|
+
for v in outputs.values():
|
|
36
|
+
if isinstance(v, torch.Tensor):
|
|
37
|
+
return v.mean()
|
|
38
|
+
|
|
39
|
+
# plain Tensor
|
|
40
|
+
if isinstance(outputs, torch.Tensor):
|
|
41
|
+
return outputs.mean()
|
|
42
|
+
|
|
43
|
+
# tuple / list
|
|
44
|
+
if isinstance(outputs, (tuple, list)) and len(outputs) > 0:
|
|
45
|
+
first = outputs[0]
|
|
46
|
+
if isinstance(first, torch.Tensor):
|
|
47
|
+
return first.mean()
|
|
48
|
+
|
|
49
|
+
raise TypeError(
|
|
50
|
+
f"Cannot extract a loss from model output of type {type(outputs)}. "
|
|
51
|
+
"Ensure your model returns a tensor, a dict with a 'loss' key, "
|
|
52
|
+
"or an object with a .loss attribute."
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def probe_batch_size(
    model: nn.Module,
    input_fn: Callable[[int], Dict[str, torch.Tensor]],
    *,
    mode: Literal["train", "infer"] = "train",
    low: int = 1,
    high: int = 4096,
    headroom: float = 0.2,
    device: Optional[Union[torch.device, str]] = None,
    verbose: bool = True,
) -> int:
    """Find the maximum batch size that fits in GPU memory.

    Uses binary search with OOM recovery. Tries a forward pass (and backward
    pass in train mode) at each candidate batch size. Returns the largest
    successful size minus a safety margin.

    Args:
        model: Any ``nn.Module``, already on the target device.
        input_fn: A callable that takes a batch size ``int`` and returns a dict
            of tensors to pass as ``**kwargs`` to ``model()``. Tensors must
            already be on the correct device.
        mode: ``"train"`` runs forward + backward (2-4x more memory).
            ``"infer"`` runs forward only under ``torch.no_grad()``.
        low: Minimum batch size to try (and the floor for the return value).
        high: Starting upper bound for binary search.
        headroom: Fraction of headroom to subtract. ``0.2`` (default) means
            the returned batch size is ``int(max_successful * 0.8)``.
        device: Device to check. Defaults to the device of the model's first
            parameter. On CPU the probe still runs but skips CUDA-specific cleanup.
        verbose: Print probe progress.

    Returns:
        Safe batch size (``int``), guaranteed ``>= max(low, 1)``.

    Example::

        from torch_probe import probe_batch_size

        batch_size = probe_batch_size(
            model,
            lambda bs: {
                "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
                "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
            },
        )
    """
    # Resolve device: default to the model's first parameter's device,
    # falling back to CPU for parameter-less modules.
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")
    device = torch.device(device) if isinstance(device, str) else device

    is_cuda = device.type == "cuda"

    # Save model state so we can restore it after probing.
    was_training = model.training
    # Capture the caller's requested floor BEFORE the binary search mutates
    # `low` — the final result must honor the documented ``>= low`` guarantee.
    min_bs = low
    best = low

    if verbose:
        print(f"torch-probe: probing batch size (mode={mode}, range=[{low}, {high}], "
              f"headroom={headroom:.0%})...", end="", flush=True)

    while low <= high:
        mid = (low + high) // 2
        success = False
        inputs = None

        try:
            if is_cuda:
                gpu_cleanup()
            inputs = input_fn(mid)

            if mode == "train":
                model.train()
                outputs = model(**inputs)
                loss = _extract_loss(outputs)
                loss.backward()
                model.zero_grad(set_to_none=True)
            else:
                model.eval()
                with torch.no_grad():
                    model(**inputs)

            success = True

        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            err_msg = str(e).lower()
            if "out of memory" not in err_msg and "cuda" not in err_msg:
                # Not an OOM — restore state and re-raise to the caller.
                model.train(was_training)
                raise
        finally:
            # Always release probe tensors before the next attempt.
            if inputs is not None:
                del inputs
            if is_cuda:
                gpu_cleanup()

        if success:
            best = mid
            low = mid + 1
        else:
            high = mid - 1

    # Restore model state.
    model.train(was_training)

    # Apply headroom, but never go below the user's requested minimum
    # (previously this floored at the literal 1, breaking the documented
    # ``>= low`` guarantee when low > 1), and never below 1.
    safe = max(int(best * (1.0 - headroom)), min_bs, 1)

    if verbose:
        print(f" max={best}, safe={safe}")

    return safe
|
|
File without changes
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Tests for the caching layer."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from torch_probe import cached_probe, clear_cache
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CountingModel(nn.Module):
    """Tiny linear model that records the number of forward() invocations.

    Simulates a CUDA OOM whenever the incoming batch exceeds
    ``oom_threshold``, so cache/probe logic can be exercised on CPU.
    """

    def __init__(self, oom_threshold: int = 8):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.oom_threshold = oom_threshold
        # Incremented on every forward() call; inspected by the tests.
        self.call_count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.call_count = self.call_count + 1
        batch = x.shape[0]
        if batch > self.oom_threshold:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory.")
        return self.linear(x)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TestCachedProbe:
    """Exercises cache hits, per-mode key separation, and invalidation."""

    def setup_method(self):
        # Every test starts from an empty cache.
        clear_cache()

    def test_cache_hit(self):
        net = CountingModel(oom_threshold=8)
        make_inputs = lambda bs: {"x": torch.randn(bs, 10)}

        first = cached_probe(net, make_inputs, mode="infer", high=32, verbose=False)
        count_after_first = net.call_count

        second = cached_probe(net, make_inputs, mode="infer", high=32, verbose=False)
        count_after_second = net.call_count

        assert first == second
        # Served from cache: the second call must add no forward passes.
        assert count_after_second == count_after_first

    def test_different_modes_separate_cache(self):
        net = CountingModel(oom_threshold=8)
        make_inputs = lambda bs: {"x": torch.randn(bs, 10)}

        train_result = cached_probe(net, make_inputs, mode="train", high=32, verbose=False)
        infer_result = cached_probe(net, make_inputs, mode="infer", high=32, verbose=False)

        # Distinct modes map to distinct cache keys, so both probes run.
        # Results may differ because train mode includes a backward pass.
        assert isinstance(train_result, int)
        assert isinstance(infer_result, int)

    def test_clear_cache(self):
        net = CountingModel(oom_threshold=8)
        make_inputs = lambda bs: {"x": torch.randn(bs, 10)}

        cached_probe(net, make_inputs, mode="infer", high=32, verbose=False)
        before_clear = net.call_count

        clear_cache()
        cached_probe(net, make_inputs, mode="infer", high=32, verbose=False)
        after_clear = net.call_count

        # Clearing forced a full re-probe with fresh forward passes.
        assert after_clear > before_clear
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright (c) 2026 Andrew H. Bond
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Tests for the core probe_batch_size function.
|
|
5
|
+
|
|
6
|
+
These tests work without a GPU by using models that simulate OOM
|
|
7
|
+
via a Python-side threshold check.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
|
|
16
|
+
from torch_probe import probe_batch_size
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FakeOOMModel(nn.Module):
    """Linear stub that simulates CUDA OOM above a configurable batch size.

    Lets the binary-search probe be tested on CPU with a deterministic
    "memory" limit.
    """

    def __init__(self, oom_threshold: int = 16):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.oom_threshold = oom_threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Batches at or below the threshold succeed; larger ones "OOM".
        if x.shape[0] <= self.oom_threshold:
            return self.linear(x)
        raise torch.cuda.OutOfMemoryError(
            "CUDA out of memory. Tried to allocate 2.00 GiB"
        )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FakeDictModel(nn.Module):
    """Stub that mimics a HuggingFace model: dict output with a 'loss' key.

    Simulates CUDA OOM above ``oom_threshold`` so train-mode probing can be
    exercised on CPU.
    """

    def __init__(self, oom_threshold: int = 8):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.oom_threshold = oom_threshold

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
        if input_ids.shape[0] > self.oom_threshold:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory.")
        logits = self.linear(input_ids.float())
        return {"loss": logits.mean(), "logits": logits}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _make_input_fn(feature_dim: int = 10):
|
|
51
|
+
"""Create an input_fn for FakeOOMModel."""
|
|
52
|
+
return lambda bs: {"x": torch.randn(bs, feature_dim)}
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _make_dict_input_fn(feature_dim: int = 10):
|
|
56
|
+
"""Create an input_fn for FakeDictModel."""
|
|
57
|
+
return lambda bs: {
|
|
58
|
+
"input_ids": torch.zeros(bs, feature_dim, dtype=torch.long),
|
|
59
|
+
"attention_mask": torch.ones(bs, feature_dim, dtype=torch.long),
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TestProbeBatchSize:
    """End-to-end checks for binary-search probing using CPU-only fakes."""

    def test_finds_max_batch_infer(self):
        net = FakeOOMModel(oom_threshold=32)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=128,
            headroom=0.2,
            verbose=False,
        )
        # Largest passing batch is 32; with 20% headroom: int(32 * 0.8) == 25.
        assert got == 25

    def test_finds_max_batch_train(self):
        net = FakeOOMModel(oom_threshold=32)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="train",
            high=128,
            headroom=0.2,
            verbose=False,
        )
        assert got == 25

    def test_headroom_zero(self):
        net = FakeOOMModel(oom_threshold=32)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=128,
            headroom=0.0,
            verbose=False,
        )
        assert got == 32

    def test_headroom_fifty_percent(self):
        net = FakeOOMModel(oom_threshold=32)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=128,
            headroom=0.5,
            verbose=False,
        )
        assert got == 16

    def test_returns_at_least_one(self):
        # Aggressive headroom plus a tiny threshold must still yield >= 1.
        net = FakeOOMModel(oom_threshold=2)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=64,
            headroom=0.9,
            verbose=False,
        )
        assert got >= 1

    def test_all_oom(self):
        # Even batch=1 "OOMs"; the probe falls back to its floor.
        net = FakeOOMModel(oom_threshold=0)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            low=1,
            high=64,
            headroom=0.2,
            verbose=False,
        )
        assert got >= 1

    def test_dict_output_model(self):
        net = FakeDictModel(oom_threshold=8)
        got = probe_batch_size(
            net,
            _make_dict_input_fn(),
            mode="train",
            high=64,
            headroom=0.2,
            verbose=False,
        )
        assert got == max(1, int(8 * 0.8))  # 6

    def test_cpu_probes_normally(self):
        # CPU path skips CUDA cleanup but still detects the simulated OOM.
        net = FakeOOMModel(oom_threshold=16)
        got = probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=64,
            device="cpu",
            headroom=0.2,
            verbose=False,
        )
        assert got == int(16 * 0.8)  # 12

    def test_non_oom_error_propagates(self):
        class BadModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 2)

            def forward(self, x):
                raise ValueError("Something went wrong")

        with pytest.raises(ValueError, match="Something went wrong"):
            probe_batch_size(
                BadModel(),
                _make_input_fn(),
                mode="infer",
                high=16,
                verbose=False,
            )

    def test_restores_training_mode(self):
        net = FakeOOMModel(oom_threshold=16)

        # eval -> probe in train mode -> eval must be restored.
        net.eval()
        assert not net.training
        probe_batch_size(
            net,
            _make_input_fn(),
            mode="train",
            high=32,
            verbose=False,
        )
        assert not net.training

        # train -> probe in infer mode -> train must be restored.
        net.train()
        assert net.training
        probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=32,
            verbose=False,
        )
        assert net.training

    def test_verbose_output(self, capsys):
        net = FakeOOMModel(oom_threshold=8)
        probe_batch_size(
            net,
            _make_input_fn(),
            mode="infer",
            high=32,
            verbose=True,
        )
        printed = capsys.readouterr().out
        assert "torch-probe" in printed
        assert "max=8" in printed
        assert "safe=" in printed
|