torch-max-mem 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_max_mem/__init__.py +2 -3
- torch_max_mem/api.py +31 -47
- torch_max_mem/py.typed +0 -0
- torch_max_mem/version.py +4 -6
- torch_max_mem-0.1.4.dist-info/METADATA +345 -0
- torch_max_mem-0.1.4.dist-info/RECORD +8 -0
- torch_max_mem-0.1.4.dist-info/WHEEL +4 -0
- {torch_max_mem-0.1.2.dist-info → torch_max_mem-0.1.4.dist-info/licenses}/LICENSE +0 -0
- torch_max_mem-0.1.2.dist-info/METADATA +0 -215
- torch_max_mem-0.1.2.dist-info/RECORD +0 -8
- torch_max_mem-0.1.2.dist-info/WHEEL +0 -5
- torch_max_mem-0.1.2.dist-info/top_level.txt +0 -1
torch_max_mem/__init__.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
1
|
"""Maximize memory utilization with PyTorch."""
|
2
|
+
|
4
3
|
from .api import MemoryUtilizationMaximizer, maximize_memory_utilization
|
5
4
|
|
6
5
|
__all__ = [
|
7
|
-
"maximize_memory_utilization",
|
8
6
|
"MemoryUtilizationMaximizer",
|
7
|
+
"maximize_memory_utilization",
|
9
8
|
]
|
torch_max_mem/api.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
1
|
"""
|
4
2
|
This module contains the public API.
|
5
3
|
|
@@ -9,6 +7,7 @@ Assume you have a function for batched computation of nearest neighbors using br
|
|
9
7
|
|
10
8
|
import torch
|
11
9
|
|
10
|
+
|
12
11
|
def knn(x, y, batch_size, k: int = 3):
|
13
12
|
return torch.cat(
|
14
13
|
[
|
@@ -26,6 +25,7 @@ out-of-memory error occurs.
|
|
26
25
|
import torch
|
27
26
|
from torch_max_mem import maximize_memory_utilization
|
28
27
|
|
28
|
+
|
29
29
|
@maximize_memory_utilization()
|
30
30
|
def knn(x, y, batch_size, k: int = 3):
|
31
31
|
return torch.cat(
|
@@ -45,6 +45,7 @@ In the code, you can now always pass the largest sensible batch size, e.g.,
|
|
45
45
|
y = torch.rand(200, 100, device="cuda")
|
46
46
|
knn(x, y, batch_size=x.shape[0])
|
47
47
|
"""
|
48
|
+
|
48
49
|
# cf. https://gist.github.com/mberr/c37a8068b38cabc98228db2cbe358043
|
49
50
|
from __future__ import annotations
|
50
51
|
|
@@ -52,17 +53,10 @@ import functools
|
|
52
53
|
import inspect
|
53
54
|
import itertools
|
54
55
|
import logging
|
55
|
-
from collections import
|
56
|
+
from collections.abc import Collection, Iterable, Mapping, MutableMapping, Sequence
|
56
57
|
from typing import (
|
57
58
|
Any,
|
58
59
|
Callable,
|
59
|
-
Collection,
|
60
|
-
Iterable,
|
61
|
-
Mapping,
|
62
|
-
MutableMapping,
|
63
|
-
Optional,
|
64
|
-
Sequence,
|
65
|
-
Tuple,
|
66
60
|
TypeVar,
|
67
61
|
)
|
68
62
|
|
@@ -99,9 +93,7 @@ def upgrade_to_sequence(
|
|
99
93
|
when the (inferred) length of q and parameter_name do not match
|
100
94
|
"""
|
101
95
|
# normalize parameter name
|
102
|
-
parameter_names = (
|
103
|
-
(parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name)
|
104
|
-
)
|
96
|
+
parameter_names = (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name)
|
105
97
|
q = (q,) if isinstance(q, int) else tuple(q)
|
106
98
|
q = q * len(parameter_names) if len(q) == 1 else q
|
107
99
|
if len(q) != len(parameter_names):
|
@@ -128,7 +120,7 @@ def determine_default_max_value(
|
|
128
120
|
:raises ValueError:
|
129
121
|
when the function does not have a parameter of the given name
|
130
122
|
"""
|
131
|
-
if parameter_name not in signature.parameters
|
123
|
+
if parameter_name not in signature.parameters:
|
132
124
|
raise ValueError(f"{func} does not have a parameter {parameter_name}.")
|
133
125
|
_parameter = signature.parameters[parameter_name]
|
134
126
|
if _parameter.annotation != inspect.Parameter.empty and _parameter.annotation not in (
|
@@ -146,10 +138,10 @@ def determine_default_max_value(
|
|
146
138
|
|
147
139
|
def determine_max_value(
|
148
140
|
bound_arguments: inspect.BoundArguments,
|
149
|
-
args: P.args,
|
150
|
-
kwargs: P.kwargs,
|
151
141
|
parameter_name: str,
|
152
142
|
default_max_value: int | Callable[P, int] | None,
|
143
|
+
*args: P.args,
|
144
|
+
**kwargs: P.kwargs,
|
153
145
|
) -> int:
|
154
146
|
"""
|
155
147
|
Either use the provided value, or the default maximum value.
|
@@ -204,12 +196,13 @@ ADDITIONAL_OOM_ERROR_INFIXES = {
|
|
204
196
|
def iter_tensor_devices(*args: Any, **kwargs: Any) -> Iterable[torch.device]:
|
205
197
|
"""Iterate over tensors' devices (may contain duplicates)."""
|
206
198
|
for obj in itertools.chain(args, kwargs.values()):
|
207
|
-
if torch.
|
208
|
-
assert isinstance(obj, torch.Tensor)
|
199
|
+
if isinstance(obj, torch.Tensor):
|
209
200
|
yield obj.device
|
210
201
|
|
211
202
|
|
212
|
-
def create_tensor_checker(
|
203
|
+
def create_tensor_checker(
|
204
|
+
safe_devices: Collection[str] | None = None,
|
205
|
+
) -> Callable[P, None]:
|
213
206
|
"""
|
214
207
|
Create a function that warns when tensors are on any device that is not considered safe.
|
215
208
|
|
@@ -274,16 +267,15 @@ def is_oom_error(error: BaseException) -> bool:
|
|
274
267
|
return True
|
275
268
|
if not isinstance(error, RuntimeError):
|
276
269
|
return False
|
277
|
-
|
278
|
-
|
279
|
-
return any(infix in error.args[0] for infix in ADDITIONAL_OOM_ERROR_INFIXES)
|
270
|
+
message = str(error)
|
271
|
+
return any(infix in message for infix in ADDITIONAL_OOM_ERROR_INFIXES)
|
280
272
|
|
281
273
|
|
282
274
|
def maximize_memory_utilization_decorator(
|
283
275
|
parameter_name: str | Sequence[str] = "batch_size",
|
284
276
|
q: int | Sequence[int] = 32,
|
285
277
|
safe_devices: Collection[str] | None = None,
|
286
|
-
) -> Callable[[Callable[P, R]], Callable[P,
|
278
|
+
) -> Callable[[Callable[P, R]], Callable[P, tuple[R, tuple[int, ...]]]]:
|
287
279
|
"""
|
288
280
|
Create decorators to create methods for memory utilization maximization.
|
289
281
|
|
@@ -301,8 +293,8 @@ def maximize_memory_utilization_decorator(
|
|
301
293
|
parameter_names, qs = upgrade_to_sequence(parameter_name, q)
|
302
294
|
|
303
295
|
def decorator_maximize_memory_utilization(
|
304
|
-
func: Callable[P, R]
|
305
|
-
) -> Callable[P,
|
296
|
+
func: Callable[P, R],
|
297
|
+
) -> Callable[P, tuple[R, tuple[int, ...]]]:
|
306
298
|
"""
|
307
299
|
Decorate a function to maximize memory utilization.
|
308
300
|
|
@@ -320,9 +312,7 @@ def maximize_memory_utilization_decorator(
|
|
320
312
|
}
|
321
313
|
|
322
314
|
@functools.wraps(func)
|
323
|
-
def wrapper_maximize_memory_utilization(
|
324
|
-
*args: P.args, **kwargs: P.kwargs
|
325
|
-
) -> Tuple[R, tuple[int, ...]]:
|
315
|
+
def wrapper_maximize_memory_utilization(*args: P.args, **kwargs: P.kwargs) -> tuple[R, tuple[int, ...]]:
|
326
316
|
"""
|
327
317
|
Wrap a function to maximize memory utilization by successive halving.
|
328
318
|
|
@@ -345,11 +335,11 @@ def maximize_memory_utilization_decorator(
|
|
345
335
|
# determine actual max values
|
346
336
|
max_values = [
|
347
337
|
determine_max_value(
|
348
|
-
bound_arguments
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
338
|
+
bound_arguments,
|
339
|
+
name,
|
340
|
+
default_max_value,
|
341
|
+
*args,
|
342
|
+
**kwargs,
|
353
343
|
)
|
354
344
|
for name, default_max_value in default_max_values.items()
|
355
345
|
]
|
@@ -360,13 +350,11 @@ def maximize_memory_utilization_decorator(
|
|
360
350
|
|
361
351
|
while i < len(max_values):
|
362
352
|
while max_values[i] > 0:
|
363
|
-
p_kwargs =
|
364
|
-
|
365
|
-
|
366
|
-
# note: bound_arguments.kwargs is typed as dict, but (silently) immutable (=ignoring updates)...
|
367
|
-
combined_kwargs: P.kwargs = ChainMap(p_kwargs, bound_arguments.kwargs)
|
353
|
+
p_kwargs = dict(zip(parameter_names, max_values))
|
354
|
+
# note: changes to arguments apply to both, .args and .kwargs
|
355
|
+
bound_arguments.arguments.update(p_kwargs)
|
368
356
|
try:
|
369
|
-
return func(*bound_arguments.args, **
|
357
|
+
return func(*bound_arguments.args, **bound_arguments.kwargs), tuple(max_values)
|
370
358
|
except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
|
371
359
|
# raise errors unrelated to out-of-memory
|
372
360
|
if not is_oom_error(error):
|
@@ -391,12 +379,8 @@ def maximize_memory_utilization_decorator(
|
|
391
379
|
i += 1
|
392
380
|
# log memory summary for each CUDA device before raising memory error
|
393
381
|
for device in {d for d in iter_tensor_devices(*args, **kwargs) if d.type == "cuda"}:
|
394
|
-
logger.debug(
|
395
|
-
|
396
|
-
)
|
397
|
-
raise MemoryError(
|
398
|
-
f"Execution did not even succeed with {parameter_names} all equal to 1."
|
399
|
-
) from last_error
|
382
|
+
logger.debug(f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}")
|
383
|
+
raise MemoryError(f"Execution did not even succeed with {parameter_names} all equal to 1.") from last_error
|
400
384
|
|
401
385
|
return wrapper_maximize_memory_utilization
|
402
386
|
|
@@ -455,7 +439,7 @@ class MemoryUtilizationMaximizer:
|
|
455
439
|
parameter_name: str | Sequence[str] = "batch_size",
|
456
440
|
q: int | Sequence[int] = 32,
|
457
441
|
safe_devices: Collection[str] | None = None,
|
458
|
-
hasher:
|
442
|
+
hasher: Callable[[Mapping[str, Any]], int] | None = None,
|
459
443
|
keys: Collection[str] | str | None = None,
|
460
444
|
) -> None:
|
461
445
|
"""
|
@@ -474,7 +458,7 @@ class MemoryUtilizationMaximizer:
|
|
474
458
|
"""
|
475
459
|
self.parameter_names, self.qs = upgrade_to_sequence(parameter_name=parameter_name, q=q)
|
476
460
|
self.safe_devices = safe_devices
|
477
|
-
self.parameter_value: MutableMapping[int, tuple[int, ...]] =
|
461
|
+
self.parameter_value: MutableMapping[int, tuple[int, ...]] = {}
|
478
462
|
if hasher is None:
|
479
463
|
keys = KeyHasher.normalize_keys(keys)
|
480
464
|
intersection = set(keys).intersection(self.parameter_names)
|
torch_max_mem/py.typed
ADDED
File without changes
|
torch_max_mem/version.py
CHANGED
@@ -1,27 +1,25 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
1
|
"""Version information for :mod:`torch_max_mem`.
|
4
2
|
|
5
3
|
Run with ``python -m torch_max_mem.version``
|
6
4
|
"""
|
7
5
|
|
8
6
|
import os
|
9
|
-
from subprocess import CalledProcessError, check_output
|
7
|
+
from subprocess import CalledProcessError, check_output
|
10
8
|
|
11
9
|
__all__ = [
|
12
10
|
"VERSION",
|
13
|
-
"get_version",
|
14
11
|
"get_git_hash",
|
12
|
+
"get_version",
|
15
13
|
]
|
16
14
|
|
17
|
-
VERSION = "0.1.
|
15
|
+
VERSION = "0.1.4"
|
18
16
|
|
19
17
|
|
20
18
|
def get_git_hash() -> str:
|
21
19
|
"""Get the :mod:`torch_max_mem` git hash."""
|
22
20
|
with open(os.devnull, "w") as devnull:
|
23
21
|
try:
|
24
|
-
ret = check_output(
|
22
|
+
ret = check_output(
|
25
23
|
["git", "rev-parse", "HEAD"],
|
26
24
|
cwd=os.path.dirname(__file__),
|
27
25
|
stderr=devnull,
|
@@ -0,0 +1,345 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: torch-max-mem
|
3
|
+
Version: 0.1.4
|
4
|
+
Summary: Maximize memory utilization with PyTorch.
|
5
|
+
Keywords: snekpack,cookiecutter,torch
|
6
|
+
Author: Max Berrendorf
|
7
|
+
Author-email: Max Berrendorf <max.berrendorf@gmail.com>
|
8
|
+
License-File: LICENSE
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
10
|
+
Classifier: Environment :: Console
|
11
|
+
Classifier: Intended Audience :: Developers
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
13
|
+
Classifier: Operating System :: OS Independent
|
14
|
+
Classifier: Framework :: Pytest
|
15
|
+
Classifier: Framework :: tox
|
16
|
+
Classifier: Framework :: Sphinx
|
17
|
+
Classifier: Programming Language :: Python
|
18
|
+
Classifier: Programming Language :: Python :: 3.9
|
19
|
+
Classifier: Programming Language :: Python :: 3.10
|
20
|
+
Classifier: Programming Language :: Python :: 3.11
|
21
|
+
Classifier: Programming Language :: Python :: 3.12
|
22
|
+
Classifier: Programming Language :: Python :: 3.13
|
23
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
24
|
+
Requires-Dist: torch>=2.0
|
25
|
+
Requires-Dist: typing-extensions
|
26
|
+
Maintainer: Max Berrendorf
|
27
|
+
Maintainer-email: Max Berrendorf <max.berrendorf@gmail.com>
|
28
|
+
Requires-Python: >=3.9
|
29
|
+
Project-URL: Bug Tracker, https://github.com/mberr/torch-max-mem/issues
|
30
|
+
Project-URL: Download, https://github.com/mberr/torch-max-mem/releases
|
31
|
+
Project-URL: Homepage, https://github.com/mberr/torch-max-mem
|
32
|
+
Project-URL: Source Code, https://github.com/mberr/torch-max-mem
|
33
|
+
Description-Content-Type: text/markdown
|
34
|
+
|
35
|
+
<!--
|
36
|
+
<p align="center">
|
37
|
+
<img src="https://github.com/mberr/torch-max-mem/raw/main/docs/source/logo.png" height="150">
|
38
|
+
</p>
|
39
|
+
-->
|
40
|
+
|
41
|
+
<h1 align="center">
|
42
|
+
torch-max-mem
|
43
|
+
</h1>
|
44
|
+
|
45
|
+
<p align="center">
|
46
|
+
<a href="https://github.com/mberr/torch-max-mem/actions/workflows/tests.yml">
|
47
|
+
<img alt="Tests" src="https://github.com/mberr/torch-max-mem/actions/workflows/tests.yml/badge.svg" /></a>
|
48
|
+
<a href="https://pypi.org/project/torch_max_mem">
|
49
|
+
<img alt="PyPI" src="https://img.shields.io/pypi/v/torch_max_mem" /></a>
|
50
|
+
<a href="https://pypi.org/project/torch_max_mem">
|
51
|
+
<img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/torch_max_mem" /></a>
|
52
|
+
<a href="https://github.com/mberr/torch-max-mem/blob/main/LICENSE">
|
53
|
+
<img alt="PyPI - License" src="https://img.shields.io/pypi/l/torch_max_mem" /></a>
|
54
|
+
<a href='https://torch_max_mem.readthedocs.io/en/latest/?badge=latest'>
|
55
|
+
<img src='https://readthedocs.org/projects/torch_max_mem/badge/?version=latest' alt='Documentation Status' /></a>
|
56
|
+
<a href="https://codecov.io/gh/mberr/torch-max-mem/branch/main">
|
57
|
+
<img src="https://codecov.io/gh/mberr/torch-max-mem/branch/main/graph/badge.svg" alt="Codecov status" /></a>
|
58
|
+
<a href="https://github.com/cthoyt/cookiecutter-python-package">
|
59
|
+
<img alt="Cookiecutter template from @cthoyt" src="https://img.shields.io/badge/Cookiecutter-snekpack-blue" /></a>
|
60
|
+
<a href="https://github.com/astral-sh/ruff">
|
61
|
+
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff" style="max-width:100%;"></a>
|
62
|
+
<a href="https://github.com/mberr/torch-max-mem/blob/main/.github/CODE_OF_CONDUCT.md">
|
63
|
+
<img src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg" alt="Contributor Covenant"/></a>
|
64
|
+
<!-- uncomment if you archive on zenodo
|
65
|
+
<a href="https://zenodo.org/badge/latestdoi/XXXXXX">
|
66
|
+
<img src="https://zenodo.org/badge/XXXXXX.svg" alt="DOI"></a>
|
67
|
+
-->
|
68
|
+
</p>
|
69
|
+
|
70
|
+
This package provides decorators for memory utilization maximization with
|
71
|
+
PyTorch and CUDA by starting with a maximum parameter size and applying
|
72
|
+
successive halving until no more out-of-memory exception occurs.
|
73
|
+
|
74
|
+
## 💪 Getting Started
|
75
|
+
|
76
|
+
Assume you have a function for batched computation of nearest neighbors using
|
77
|
+
brute-force distance calculation.
|
78
|
+
|
79
|
+
```python
|
80
|
+
import torch
|
81
|
+
|
82
|
+
def knn(x, y, batch_size, k: int = 3):
|
83
|
+
return torch.cat(
|
84
|
+
[
|
85
|
+
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
|
86
|
+
for start in range(0, x.shape[0], batch_size)
|
87
|
+
],
|
88
|
+
dim=0,
|
89
|
+
)
|
90
|
+
```
|
91
|
+
|
92
|
+
With `torch_max_mem` you can decorate this function to reduce the batch size
|
93
|
+
until no more out-of-memory error occurs.
|
94
|
+
|
95
|
+
```python
|
96
|
+
import torch
|
97
|
+
from torch_max_mem import maximize_memory_utilization
|
98
|
+
|
99
|
+
|
100
|
+
@maximize_memory_utilization()
|
101
|
+
def knn(x, y, batch_size, k: int = 3):
|
102
|
+
return torch.cat(
|
103
|
+
[
|
104
|
+
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
|
105
|
+
for start in range(0, x.shape[0], batch_size)
|
106
|
+
],
|
107
|
+
dim=0,
|
108
|
+
)
|
109
|
+
```
|
110
|
+
|
111
|
+
In the code, you can now always pass the largest sensible batch size, e.g.,
|
112
|
+
|
113
|
+
```python
|
114
|
+
x = torch.rand(100, 100, device="cuda")
|
115
|
+
y = torch.rand(200, 100, device="cuda")
|
116
|
+
knn(x, y, batch_size=x.shape[0])
|
117
|
+
```
|
118
|
+
|
119
|
+
## 🚀 Installation
|
120
|
+
|
121
|
+
The most recent release can be installed from
|
122
|
+
[PyPI](https://pypi.org/project/torch_max_mem/) with uv:
|
123
|
+
|
124
|
+
```console
|
125
|
+
uv pip install torch_max_mem
|
126
|
+
```
|
127
|
+
|
128
|
+
or with pip:
|
129
|
+
|
130
|
+
```console
|
131
|
+
python3 -m pip install torch_max_mem
|
132
|
+
```
|
133
|
+
|
134
|
+
The most recent code and data can be installed directly from GitHub with uv:
|
135
|
+
|
136
|
+
```console
|
137
|
+
uv pip install git+https://github.com/mberr/torch-max-mem.git
|
138
|
+
```
|
139
|
+
|
140
|
+
or with pip:
|
141
|
+
|
142
|
+
```console
|
143
|
+
python3 -m pip install git+https://github.com/mberr/torch-max-mem.git
|
144
|
+
```
|
145
|
+
|
146
|
+
## 👐 Contributing
|
147
|
+
|
148
|
+
Contributions, whether filing an issue, making a pull request, or forking, are
|
149
|
+
appreciated. See
|
150
|
+
[CONTRIBUTING.md](https://github.com/mberr/torch-max-mem/blob/main/.github/CONTRIBUTING.md)
|
151
|
+
for more information on getting involved.
|
152
|
+
|
153
|
+
## 👋 Attribution
|
154
|
+
|
155
|
+
Parts of the logic have been developed with
|
156
|
+
[Laurent Vermue](https://github.com/lvermue) for
|
157
|
+
[PyKEEN](https://github.com/pykeen/pykeen).
|
158
|
+
|
159
|
+
### ⚖️ License
|
160
|
+
|
161
|
+
The code in this package is licensed under the MIT License.
|
162
|
+
|
163
|
+
### 🍪 Cookiecutter
|
164
|
+
|
165
|
+
This package was created with
|
166
|
+
[@audreyfeldroy](https://github.com/audreyfeldroy)'s
|
167
|
+
[cookiecutter](https://github.com/cookiecutter/cookiecutter) package using
|
168
|
+
[@cthoyt](https://github.com/cthoyt)'s
|
169
|
+
[cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack)
|
170
|
+
template.
|
171
|
+
|
172
|
+
## 🛠️ For Developers
|
173
|
+
|
174
|
+
<details>
|
175
|
+
<summary>See developer instructions</summary>
|
176
|
+
|
177
|
+
The final section of the README is for those who want to get involved by making a
|
178
|
+
code contribution.
|
179
|
+
|
180
|
+
### Development Installation
|
181
|
+
|
182
|
+
To install in development mode, use the following:
|
183
|
+
|
184
|
+
```console
|
185
|
+
git clone https://github.com/mberr/torch-max-mem.git
|
186
|
+
cd torch-max-mem
|
187
|
+
uv pip install -e .
|
188
|
+
```
|
189
|
+
|
190
|
+
Alternatively, install using pip:
|
191
|
+
|
192
|
+
```console
|
193
|
+
python3 -m pip install -e .
|
194
|
+
```
|
195
|
+
|
196
|
+
### Updating Package Boilerplate
|
197
|
+
|
198
|
+
This project uses `cruft` to keep boilerplate (i.e., configuration, contribution
|
199
|
+
guidelines, documentation configuration) up-to-date with the upstream
|
200
|
+
cookiecutter package. Install cruft with either `uv tool install cruft` or
|
201
|
+
`python3 -m pip install cruft` then run:
|
202
|
+
|
203
|
+
```console
|
204
|
+
cruft update
|
205
|
+
```
|
206
|
+
|
207
|
+
More info on Cruft's update command is available
|
208
|
+
[here](https://github.com/cruft/cruft?tab=readme-ov-file#updating-a-project).
|
209
|
+
|
210
|
+
### 🥼 Testing
|
211
|
+
|
212
|
+
After cloning the repository and installing `tox` with
|
213
|
+
`uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, the
|
214
|
+
unit tests in the `tests/` folder can be run reproducibly with:
|
215
|
+
|
216
|
+
```console
|
217
|
+
tox -e py
|
218
|
+
```
|
219
|
+
|
220
|
+
Additionally, these tests are automatically re-run with each commit in a
|
221
|
+
[GitHub Action](https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests).
|
222
|
+
|
223
|
+
### 📖 Building the Documentation
|
224
|
+
|
225
|
+
The documentation can be built locally using the following:
|
226
|
+
|
227
|
+
```console
|
228
|
+
git clone https://github.com/mberr/torch-max-mem.git
|
229
|
+
cd torch-max-mem
|
230
|
+
tox -e docs
|
231
|
+
open docs/build/html/index.html
|
232
|
+
```
|
233
|
+
|
234
|
+
The documentation automatically installs the package as well as the `docs` extra
|
235
|
+
specified in the [`pyproject.toml`](pyproject.toml). `sphinx` plugins like
|
236
|
+
`texext` can be added there. Additionally, they need to be added to the
|
237
|
+
`extensions` list in [`docs/source/conf.py`](docs/source/conf.py).
|
238
|
+
|
239
|
+
The documentation can be deployed to [ReadTheDocs](https://readthedocs.io) using
|
240
|
+
[this guide](https://docs.readthedocs.io/en/stable/intro/import-guide.html). The
|
241
|
+
[`.readthedocs.yml`](.readthedocs.yml) YAML file contains all the configuration
|
242
|
+
you'll need. You can also set up continuous integration on GitHub to check not
|
243
|
+
only that Sphinx can build the documentation in an isolated environment (i.e.,
|
244
|
+
with `tox -e docs-test`) but also that
|
245
|
+
[ReadTheDocs can build it too](https://docs.readthedocs.io/en/stable/pull-requests.html).
|
246
|
+
|
247
|
+
#### Configuring ReadTheDocs
|
248
|
+
|
249
|
+
1. Log in to ReadTheDocs with your GitHub account to install the integration at
|
250
|
+
https://readthedocs.org/accounts/login/?next=/dashboard/
|
251
|
+
2. Import your project by navigating to https://readthedocs.org/dashboard/import
|
252
|
+
then clicking the plus icon next to your repository
|
253
|
+
3. You can rename the repository on the next screen using a more stylized name
|
254
|
+
(i.e., with spaces and capital letters)
|
255
|
+
4. Click next, and you're good to go!
|
256
|
+
|
257
|
+
### 📦 Making a Release
|
258
|
+
|
259
|
+
#### Configuring Zenodo
|
260
|
+
|
261
|
+
[Zenodo](https://zenodo.org) is a long-term archival system that assigns a DOI
|
262
|
+
to each release of your package.
|
263
|
+
|
264
|
+
1. Log in to Zenodo via GitHub with this link:
|
265
|
+
https://zenodo.org/oauth/login/github/?next=%2F. This brings you to a page
|
266
|
+
that lists all of your organizations and asks you to approve installing the
|
267
|
+
Zenodo app on GitHub. Click "grant" next to any organizations you want to
|
268
|
+
enable the integration for, then click the big green "approve" button. This
|
269
|
+
step only needs to be done once.
|
270
|
+
2. Navigate to https://zenodo.org/account/settings/github/, which lists all of
|
271
|
+
your GitHub repositories (both in your username and any organizations you
|
272
|
+
enabled). Click the on/off toggle for any relevant repositories. When you
|
273
|
+
make a new repository, you'll have to come back to this page and enable it.
|
274
|
+
|
275
|
+
After these steps, you're ready to go! After you make a "release" on GitHub (steps
|
276
|
+
for this are below), you can navigate to
|
277
|
+
https://zenodo.org/account/settings/github/repository/mberr/torch-max-mem to see
|
278
|
+
the DOI for the release and link to the Zenodo record for it.
|
279
|
+
|
280
|
+
#### Registering with the Python Package Index (PyPI)
|
281
|
+
|
282
|
+
You only have to do the following steps once.
|
283
|
+
|
284
|
+
1. Register for an account on the
|
285
|
+
[Python Package Index (PyPI)](https://pypi.org/account/register)
|
286
|
+
2. Navigate to https://pypi.org/manage/account and make sure you have verified
|
287
|
+
your email address. A verification email might not have been sent by default,
|
288
|
+
so you might have to click the "options" dropdown next to your address to get
|
289
|
+
to the "re-send verification email" button
|
290
|
+
3. 2-Factor authentication is required for PyPI since the end of 2023 (see this
|
291
|
+
[blog post from PyPI](https://blog.pypi.org/posts/2023-05-25-securing-pypi-with-2fa/)).
|
292
|
+
This means you have to first issue account recovery codes, then set up
|
293
|
+
2-factor authentication
|
294
|
+
4. Issue an API token from https://pypi.org/manage/account/token
|
295
|
+
|
296
|
+
#### Configuring your machine's connection to PyPI
|
297
|
+
|
298
|
+
You have to do the following steps once per machine.
|
299
|
+
|
300
|
+
```console
|
301
|
+
uv tool install keyring
|
302
|
+
keyring set https://upload.pypi.org/legacy/ __token__
|
303
|
+
keyring set https://test.pypi.org/legacy/ __token__
|
304
|
+
```
|
305
|
+
|
306
|
+
Note that this deprecates previous workflows using `.pypirc`.
|
307
|
+
|
308
|
+
#### Uploading to PyPI
|
309
|
+
|
310
|
+
After installing the package in development mode and installing `tox` with
|
311
|
+
`uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, run
|
312
|
+
the following from the console:
|
313
|
+
|
314
|
+
```console
|
315
|
+
tox -e finish
|
316
|
+
```
|
317
|
+
|
318
|
+
This script does the following:
|
319
|
+
|
320
|
+
1. Uses [bump-my-version](https://github.com/callowayproject/bump-my-version) to
|
321
|
+
switch the version number in the `pyproject.toml`, `CITATION.cff`,
|
322
|
+
`src/torch_max_mem/version.py`, and
|
323
|
+
[`docs/source/conf.py`](docs/source/conf.py) to not have the `-dev` suffix
|
324
|
+
2. Packages the code in both a tar archive and a wheel using
|
325
|
+
[`uv build`](https://docs.astral.sh/uv/guides/publish/#building-your-package)
|
326
|
+
3. Uploads to PyPI using
|
327
|
+
[`uv publish`](https://docs.astral.sh/uv/guides/publish/#publishing-your-package).
|
328
|
+
4. Push to GitHub. You'll need to make a release going with the commit where the
|
329
|
+
version was bumped.
|
330
|
+
5. Bump the version to the next patch. If you made big changes and want to bump
|
331
|
+
the version by minor, you can use `tox -e bumpversion -- minor` after.
|
332
|
+
|
333
|
+
#### Releasing on GitHub
|
334
|
+
|
335
|
+
1. Navigate to https://github.com/mberr/torch-max-mem/releases/new to draft a
|
336
|
+
new release
|
337
|
+
2. Click the "Choose a Tag" dropdown and select the tag corresponding to the
|
338
|
+
release you just made
|
339
|
+
3. Click the "Generate Release Notes" button to get a quick outline of recent
|
340
|
+
changes. Modify the title and description as you see fit
|
341
|
+
4. Click the big green "Publish Release" button
|
342
|
+
|
343
|
+
This will trigger Zenodo to assign a DOI to your release as well.
|
344
|
+
|
345
|
+
</details>
|
@@ -0,0 +1,8 @@
|
|
1
|
+
torch_max_mem/api.py,sha256=0576dc52db63f99c6bc928a4fb1bca2234391bb7d0276d8cbee8efa9c20e0c67,17786
|
2
|
+
torch_max_mem/py.typed,sha256=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855,0
|
3
|
+
torch_max_mem/version.py,sha256=ee7ce9aee46d693e77190563fe86fd5a5544fda7bb7a61d95904d56b7c1c56b3,977
|
4
|
+
torch_max_mem/__init__.py,sha256=1249af2acdc8d0ec62b94f722864cf85872502e9ccbc1cb97ad41b9025232e08,206
|
5
|
+
torch_max_mem-0.1.4.dist-info/licenses/LICENSE,sha256=5163c26353a6aae35675c7d09095eab6c29149653ae83f942c510eab10e14c82,1071
|
6
|
+
torch_max_mem-0.1.4.dist-info/WHEEL,sha256=a77f4251c0f6db962579c616a6b1d7cf825367be8db8f5d245046ce8ff723666,79
|
7
|
+
torch_max_mem-0.1.4.dist-info/METADATA,sha256=29c64143887a54b549d60c8d1520ebbc674d411c1f4f9f123e0ee5338816a729,12714
|
8
|
+
torch_max_mem-0.1.4.dist-info/RECORD,,
|
File without changes
|
@@ -1,215 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: torch-max-mem
|
3
|
-
Version: 0.1.2
|
4
|
-
Summary: Maximize memory utilization with PyTorch.
|
5
|
-
Home-page: https://github.com/mberr/torch-max-mem
|
6
|
-
Download-URL: https://github.com/mberr/torch-max-mem/releases
|
7
|
-
Author: Max Berrendorf
|
8
|
-
Author-email: max.berrendorf@gmail.com
|
9
|
-
Maintainer: Max Berrendorf
|
10
|
-
Maintainer-email: max.berrendorf@gmail.com
|
11
|
-
License: MIT
|
12
|
-
Project-URL: Bug Tracker, https://github.com/mberr/torch-max-mem/issues
|
13
|
-
Project-URL: Source Code, https://github.com/mberr/torch-max-mem
|
14
|
-
Keywords: snekpack,cookiecutter,torch
|
15
|
-
Classifier: Development Status :: 4 - Beta
|
16
|
-
Classifier: Environment :: Console
|
17
|
-
Classifier: Intended Audience :: Developers
|
18
|
-
Classifier: License :: OSI Approved :: MIT License
|
19
|
-
Classifier: Operating System :: OS Independent
|
20
|
-
Classifier: Framework :: Pytest
|
21
|
-
Classifier: Framework :: tox
|
22
|
-
Classifier: Framework :: Sphinx
|
23
|
-
Classifier: Programming Language :: Python
|
24
|
-
Classifier: Programming Language :: Python :: 3.8
|
25
|
-
Classifier: Programming Language :: Python :: 3.9
|
26
|
-
Classifier: Programming Language :: Python :: 3.10
|
27
|
-
Classifier: Programming Language :: Python :: 3.11
|
28
|
-
Classifier: Programming Language :: Python :: 3 :: Only
|
29
|
-
Requires-Python: >=3.8
|
30
|
-
Description-Content-Type: text/markdown
|
31
|
-
License-File: LICENSE
|
32
|
-
Requires-Dist: torch >=2.0
|
33
|
-
Requires-Dist: typing-extensions
|
34
|
-
Provides-Extra: docs
|
35
|
-
Requires-Dist: sphinx <7 ; extra == 'docs'
|
36
|
-
Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
|
37
|
-
Requires-Dist: sphinx-click ; extra == 'docs'
|
38
|
-
Requires-Dist: sphinx-autodoc-typehints ; extra == 'docs'
|
39
|
-
Requires-Dist: sphinx-automodapi ; extra == 'docs'
|
40
|
-
Provides-Extra: formatting
|
41
|
-
Requires-Dist: black ; extra == 'formatting'
|
42
|
-
Requires-Dist: isort ; extra == 'formatting'
|
43
|
-
Provides-Extra: tests
|
44
|
-
Requires-Dist: numpy ; extra == 'tests'
|
45
|
-
Requires-Dist: pytest ; extra == 'tests'
|
46
|
-
Requires-Dist: coverage ; extra == 'tests'
|
47
|
-
|
48
|
-
<!--
|
49
|
-
<p align="center">
|
50
|
-
<img src="https://github.com/mberr/torch-max-mem/raw/main/docs/source/logo.png" height="150">
|
51
|
-
</p>
|
52
|
-
-->
|
53
|
-
|
54
|
-
<h1 align="center">
|
55
|
-
torch-max-mem
|
56
|
-
</h1>
|
57
|
-
|
58
|
-
<p align="center">
|
59
|
-
<a href="https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests">
|
60
|
-
<img alt="Tests" src="https://github.com/mberr/torch-max-mem/workflows/Tests/badge.svg" />
|
61
|
-
</a>
|
62
|
-
<a href="https://github.com/cthoyt/cookiecutter-python-package">
|
63
|
-
<img alt="Cookiecutter template from @cthoyt" src="https://img.shields.io/badge/Cookiecutter-snekpack-blue" />
|
64
|
-
</a>
|
65
|
-
<a href="https://pypi.org/project/torch_max_mem">
|
66
|
-
<img alt="PyPI" src="https://img.shields.io/pypi/v/torch_max_mem" />
|
67
|
-
</a>
|
68
|
-
<a href="https://pypi.org/project/torch_max_mem">
|
69
|
-
<img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/torch_max_mem" />
|
70
|
-
</a>
|
71
|
-
<a href="https://github.com/mberr/torch-max-mem/blob/main/LICENSE">
|
72
|
-
<img alt="PyPI - License" src="https://img.shields.io/pypi/l/torch_max_mem" />
|
73
|
-
</a>
|
74
|
-
<a href='https://torch_max_mem.readthedocs.io/en/latest/?badge=latest'>
|
75
|
-
<img src='https://readthedocs.org/projects/torch_max_mem/badge/?version=latest' alt='Documentation Status' />
|
76
|
-
</a>
|
77
|
-
<a href='https://github.com/psf/black'>
|
78
|
-
<img src='https://img.shields.io/badge/code%20style-black-000000.svg' alt='Code style: black' />
|
79
|
-
</a>
|
80
|
-
</p>
|
81
|
-
|
82
|
-
This package provides decorators for memory utilization maximization with PyTorch and CUDA by starting with a maximum parameter size and applying successive halving until no more out-of-memory exception occurs.
|
83
|
-
|
84
|
-
## 💪 Getting Started
|
85
|
-
|
86
|
-
Assume you have a function for batched computation of nearest neighbors using brute-force distance calculation.
|
87
|
-
|
88
|
-
```python
|
89
|
-
import torch
|
90
|
-
|
91
|
-
def knn(x, y, batch_size, k: int = 3):
|
92
|
-
return torch.cat(
|
93
|
-
[
|
94
|
-
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
|
95
|
-
for start in range(0, x.shape[0], batch_size)
|
96
|
-
],
|
97
|
-
dim=0,
|
98
|
-
)
|
99
|
-
```
|
100
|
-
|
101
|
-
With `torch_max_mem` you can decorate this function to reduce the batch size until no out-of-memory error occurs.
|
102
|
-
|
103
|
-
```python
|
104
|
-
import torch
|
105
|
-
from torch_max_mem import maximize_memory_utilization
|
106
|
-
|
107
|
-
|
108
|
-
@maximize_memory_utilization()
|
109
|
-
def knn(x, y, batch_size, k: int = 3):
|
110
|
-
return torch.cat(
|
111
|
-
[
|
112
|
-
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
|
113
|
-
for start in range(0, x.shape[0], batch_size)
|
114
|
-
],
|
115
|
-
dim=0,
|
116
|
-
)
|
117
|
-
```
|
118
|
-
|
119
|
-
In the code, you can now always pass the largest sensible batch size, e.g.,
|
120
|
-
|
121
|
-
```python
|
122
|
-
x = torch.rand(100, 100, device="cuda")
|
123
|
-
y = torch.rand(200, 100, device="cuda")
|
124
|
-
knn(x, y, batch_size=x.shape[0])
|
125
|
-
```
|
126
|
-
|
127
|
-
## 🚀 Installation
|
128
|
-
|
129
|
-
The most recent release can be installed from
|
130
|
-
[PyPI](https://pypi.org/project/torch_max_mem/) with:
|
131
|
-
|
132
|
-
```bash
|
133
|
-
$ pip install torch_max_mem
|
134
|
-
```
|
135
|
-
|
136
|
-
The most recent code and data can be installed directly from GitHub with:
|
137
|
-
|
138
|
-
```bash
|
139
|
-
$ pip install git+https://github.com/mberr/torch-max-mem.git
|
140
|
-
```
|
141
|
-
|
142
|
-
To install in development mode, use the following:
|
143
|
-
|
144
|
-
```bash
|
145
|
-
$ git clone https://github.com/mberr/torch-max-mem.git
|
146
|
-
$ cd torch-max-mem
|
147
|
-
$ pip install -e .
|
148
|
-
```
|
149
|
-
|
150
|
-
## 👐 Contributing
|
151
|
-
|
152
|
-
Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See
|
153
|
-
[CONTRIBUTING.md](https://github.com/mberr/torch-max-mem/blob/main/CONTRIBUTING.md) for more information on getting involved.
|
154
|
-
|
155
|
-
## 👋 Attribution
|
156
|
-
|
157
|
-
Parts of the logic have been developed with [Laurent Vermue](https://github.com/lvermue) for [PyKEEN](https://github.com/pykeen/pykeen).
|
158
|
-
|
159
|
-
|
160
|
-
### ⚖️ License
|
161
|
-
|
162
|
-
The code in this package is licensed under the MIT License.
|
163
|
-
|
164
|
-
### 🍪 Cookiecutter
|
165
|
-
|
166
|
-
This package was created with [@audreyfeldroy](https://github.com/audreyfeldroy)'s
|
167
|
-
[cookiecutter](https://github.com/cookiecutter/cookiecutter) package using [@cthoyt](https://github.com/cthoyt)'s
|
168
|
-
[cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) template.
|
169
|
-
|
170
|
-
## 🛠️ For Developers
|
171
|
-
|
172
|
-
<details>
|
173
|
-
<summary>See developer instructions</summary>
|
174
|
-
|
175
|
-
|
176
|
-
This final section of the README is for those who want to get involved by making a code contribution.
|
177
|
-
|
178
|
-
### 🥼 Testing
|
179
|
-
|
180
|
-
After cloning the repository and installing `tox` with `pip install tox`, the unit tests in the `tests/` folder can be
|
181
|
-
run reproducibly with:
|
182
|
-
|
183
|
-
```shell
|
184
|
-
$ tox
|
185
|
-
```
|
186
|
-
|
187
|
-
Additionally, these tests are automatically re-run with each commit in a [GitHub Action](https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests).
|
188
|
-
|
189
|
-
### 📖 Building the Documentation
|
190
|
-
|
191
|
-
```shell
|
192
|
-
$ tox -e docs
|
193
|
-
```
|
194
|
-
|
195
|
-
### 📦 Making a Release
|
196
|
-
|
197
|
-
After installing the package in development mode and installing
|
198
|
-
`tox` with `pip install tox`, the commands for making a new release are contained within the `finish` environment
|
199
|
-
in `tox.ini`. Run the following from the shell:
|
200
|
-
|
201
|
-
```shell
|
202
|
-
$ tox -e finish
|
203
|
-
```
|
204
|
-
|
205
|
-
This script does the following:
|
206
|
-
|
207
|
-
1. Uses [Bump2Version](https://github.com/c4urself/bump2version) to switch the version number in the `setup.cfg` and
|
208
|
-
`src/torch_max_mem/version.py` to not have the `-dev` suffix
|
209
|
-
2. Packages the code in both a tar archive and a wheel
|
210
|
-
3. Uploads to PyPI using `twine`. Be sure to have a `.pypirc` file configured to avoid the need for manual input at this
|
211
|
-
step
|
212
|
-
4. Pushes to GitHub. You'll need to make a release that goes with the commit where the version was bumped.
|
213
|
-
5. Bumps the version to the next patch. If you made big changes and want to bump the version by minor, you can
|
214
|
-
use `tox -e bumpversion minor` after.
|
215
|
-
</details>
|
@@ -1,8 +0,0 @@
|
|
1
|
-
torch_max_mem/__init__.py,sha256=7XoGfOMupwSZ3HWFXhwQB7ysIhPZeERbVRfzCDxaADw,230
|
2
|
-
torch_max_mem/api.py,sha256=bTaDhXQIO8bv7xzL9TU-0IqqYeCRk66C5Ms70ZgcnFI,18228
|
3
|
-
torch_max_mem/version.py,sha256=Bbj1Rx8IDqBJkeF-OjEagjvqA8K8k9XaQPSf5XQmVYU,1035
|
4
|
-
torch_max_mem-0.1.2.dist-info/LICENSE,sha256=UWPCY1OmquNWdcfQkJXqtsKRSWU66D-ULFEOqxDhTII,1071
|
5
|
-
torch_max_mem-0.1.2.dist-info/METADATA,sha256=ENMihj_gxuWWm7FbY5x_nc5NGlOpBOGlQQVqv3Y-Ew4,7401
|
6
|
-
torch_max_mem-0.1.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
|
7
|
-
torch_max_mem-0.1.2.dist-info/top_level.txt,sha256=ztqqyZB7neLi8zWiWdaaKytNuRkZ7SzO7YDUQ4ZJR3U,14
|
8
|
-
torch_max_mem-0.1.2.dist-info/RECORD,,
|
@@ -1 +0,0 @@
|
|
1
|
-
torch_max_mem
|