torch-max-mem 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torch_max_mem/__init__.py CHANGED
@@ -1,9 +1,8 @@
1
- # -*- coding: utf-8 -*-
2
-
3
1
  """Maximize memory utilization with PyTorch."""
2
+
4
3
  from .api import MemoryUtilizationMaximizer, maximize_memory_utilization
5
4
 
6
5
  __all__ = [
7
- "maximize_memory_utilization",
8
6
  "MemoryUtilizationMaximizer",
7
+ "maximize_memory_utilization",
9
8
  ]
torch_max_mem/api.py CHANGED
@@ -1,5 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
-
3
1
  """
4
2
  This module contains the public API.
5
3
 
@@ -9,6 +7,7 @@ Assume you have a function for batched computation of nearest neighbors using br
9
7
 
10
8
  import torch
11
9
 
10
+
12
11
  def knn(x, y, batch_size, k: int = 3):
13
12
  return torch.cat(
14
13
  [
@@ -26,6 +25,7 @@ out-of-memory error occurs.
26
25
  import torch
27
26
  from torch_max_mem import maximize_memory_utilization
28
27
 
28
+
29
29
  @maximize_memory_utilization()
30
30
  def knn(x, y, batch_size, k: int = 3):
31
31
  return torch.cat(
@@ -45,6 +45,7 @@ In the code, you can now always pass the largest sensible batch size, e.g.,
45
45
  y = torch.rand(200, 100, device="cuda")
46
46
  knn(x, y, batch_size=x.shape[0])
47
47
  """
48
+
48
49
  # cf. https://gist.github.com/mberr/c37a8068b38cabc98228db2cbe358043
49
50
  from __future__ import annotations
50
51
 
@@ -52,16 +53,10 @@ import functools
52
53
  import inspect
53
54
  import itertools
54
55
  import logging
56
+ from collections.abc import Collection, Iterable, Mapping, MutableMapping, Sequence
55
57
  from typing import (
56
58
  Any,
57
59
  Callable,
58
- Collection,
59
- Iterable,
60
- Mapping,
61
- MutableMapping,
62
- Optional,
63
- Sequence,
64
- Tuple,
65
60
  TypeVar,
66
61
  )
67
62
 
@@ -98,9 +93,7 @@ def upgrade_to_sequence(
98
93
  when the (inferred) length of q and parameter_name do not match
99
94
  """
100
95
  # normalize parameter name
101
- parameter_names = (
102
- (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name)
103
- )
96
+ parameter_names = (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name)
104
97
  q = (q,) if isinstance(q, int) else tuple(q)
105
98
  q = q * len(parameter_names) if len(q) == 1 else q
106
99
  if len(q) != len(parameter_names):
@@ -127,7 +120,7 @@ def determine_default_max_value(
127
120
  :raises ValueError:
128
121
  when the function does not have a parameter of the given name
129
122
  """
130
- if parameter_name not in signature.parameters.keys():
123
+ if parameter_name not in signature.parameters:
131
124
  raise ValueError(f"{func} does not have a parameter {parameter_name}.")
132
125
  _parameter = signature.parameters[parameter_name]
133
126
  if _parameter.annotation != inspect.Parameter.empty and _parameter.annotation not in (
@@ -145,10 +138,10 @@ def determine_default_max_value(
145
138
 
146
139
  def determine_max_value(
147
140
  bound_arguments: inspect.BoundArguments,
148
- args: P.args,
149
- kwargs: P.kwargs,
150
141
  parameter_name: str,
151
142
  default_max_value: int | Callable[P, int] | None,
143
+ *args: P.args,
144
+ **kwargs: P.kwargs,
152
145
  ) -> int:
153
146
  """
154
147
  Either use the provided value, or the default maximum value.
@@ -203,12 +196,13 @@ ADDITIONAL_OOM_ERROR_INFIXES = {
203
196
  def iter_tensor_devices(*args: Any, **kwargs: Any) -> Iterable[torch.device]:
204
197
  """Iterate over tensors' devices (may contain duplicates)."""
205
198
  for obj in itertools.chain(args, kwargs.values()):
206
- if torch.is_tensor(obj):
207
- assert isinstance(obj, torch.Tensor)
199
+ if isinstance(obj, torch.Tensor):
208
200
  yield obj.device
209
201
 
210
202
 
211
- def create_tensor_checker(safe_devices: Collection[str] | None = None) -> Callable[P, None]:
203
+ def create_tensor_checker(
204
+ safe_devices: Collection[str] | None = None,
205
+ ) -> Callable[P, None]:
212
206
  """
213
207
  Create a function that warns when tensors are on any device that is not considered safe.
214
208
 
@@ -273,16 +267,15 @@ def is_oom_error(error: BaseException) -> bool:
273
267
  return True
274
268
  if not isinstance(error, RuntimeError):
275
269
  return False
276
- if not error.args:
277
- return False
278
- return any(infix in error.args[0] for infix in ADDITIONAL_OOM_ERROR_INFIXES)
270
+ message = str(error)
271
+ return any(infix in message for infix in ADDITIONAL_OOM_ERROR_INFIXES)
279
272
 
280
273
 
281
274
  def maximize_memory_utilization_decorator(
282
275
  parameter_name: str | Sequence[str] = "batch_size",
283
276
  q: int | Sequence[int] = 32,
284
277
  safe_devices: Collection[str] | None = None,
285
- ) -> Callable[[Callable[P, R]], Callable[P, Tuple[R, tuple[int, ...]]]]:
278
+ ) -> Callable[[Callable[P, R]], Callable[P, tuple[R, tuple[int, ...]]]]:
286
279
  """
287
280
  Create decorators to create methods for memory utilization maximization.
288
281
 
@@ -300,8 +293,8 @@ def maximize_memory_utilization_decorator(
300
293
  parameter_names, qs = upgrade_to_sequence(parameter_name, q)
301
294
 
302
295
  def decorator_maximize_memory_utilization(
303
- func: Callable[P, R]
304
- ) -> Callable[P, Tuple[R, tuple[int, ...]]]:
296
+ func: Callable[P, R],
297
+ ) -> Callable[P, tuple[R, tuple[int, ...]]]:
305
298
  """
306
299
  Decorate a function to maximize memory utilization.
307
300
 
@@ -319,9 +312,7 @@ def maximize_memory_utilization_decorator(
319
312
  }
320
313
 
321
314
  @functools.wraps(func)
322
- def wrapper_maximize_memory_utilization(
323
- *args: P.args, **kwargs: P.kwargs
324
- ) -> Tuple[R, tuple[int, ...]]:
315
+ def wrapper_maximize_memory_utilization(*args: P.args, **kwargs: P.kwargs) -> tuple[R, tuple[int, ...]]:
325
316
  """
326
317
  Wrap a function to maximize memory utilization by successive halving.
327
318
 
@@ -344,11 +335,11 @@ def maximize_memory_utilization_decorator(
344
335
  # determine actual max values
345
336
  max_values = [
346
337
  determine_max_value(
347
- bound_arguments=bound_arguments,
348
- args=args,
349
- kwargs=kwargs,
350
- parameter_name=name,
351
- default_max_value=default_max_value,
338
+ bound_arguments,
339
+ name,
340
+ default_max_value,
341
+ *args,
342
+ **kwargs,
352
343
  )
353
344
  for name, default_max_value in default_max_values.items()
354
345
  ]
@@ -359,15 +350,11 @@ def maximize_memory_utilization_decorator(
359
350
 
360
351
  while i < len(max_values):
361
352
  while max_values[i] > 0:
362
- p_kwargs = {
363
- name: max_value for name, max_value in zip(parameter_names, max_values)
364
- }
353
+ p_kwargs = dict(zip(parameter_names, max_values))
365
354
  # note: changes to arguments apply to both, .args and .kwargs
366
355
  bound_arguments.arguments.update(p_kwargs)
367
356
  try:
368
- return func(*bound_arguments.args, **bound_arguments.kwargs), tuple(
369
- max_values
370
- )
357
+ return func(*bound_arguments.args, **bound_arguments.kwargs), tuple(max_values)
371
358
  except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
372
359
  # raise errors unrelated to out-of-memory
373
360
  if not is_oom_error(error):
@@ -392,12 +379,8 @@ def maximize_memory_utilization_decorator(
392
379
  i += 1
393
380
  # log memory summary for each CUDA device before raising memory error
394
381
  for device in {d for d in iter_tensor_devices(*args, **kwargs) if d.type == "cuda"}:
395
- logger.debug(
396
- f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}"
397
- )
398
- raise MemoryError(
399
- f"Execution did not even succeed with {parameter_names} all equal to 1."
400
- ) from last_error
382
+ logger.debug(f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}")
383
+ raise MemoryError(f"Execution did not even succeed with {parameter_names} all equal to 1.") from last_error
401
384
 
402
385
  return wrapper_maximize_memory_utilization
403
386
 
@@ -456,7 +439,7 @@ class MemoryUtilizationMaximizer:
456
439
  parameter_name: str | Sequence[str] = "batch_size",
457
440
  q: int | Sequence[int] = 32,
458
441
  safe_devices: Collection[str] | None = None,
459
- hasher: Optional[Callable[[Mapping[str, Any]], int]] = None,
442
+ hasher: Callable[[Mapping[str, Any]], int] | None = None,
460
443
  keys: Collection[str] | str | None = None,
461
444
  ) -> None:
462
445
  """
@@ -475,7 +458,7 @@ class MemoryUtilizationMaximizer:
475
458
  """
476
459
  self.parameter_names, self.qs = upgrade_to_sequence(parameter_name=parameter_name, q=q)
477
460
  self.safe_devices = safe_devices
478
- self.parameter_value: MutableMapping[int, tuple[int, ...]] = dict()
461
+ self.parameter_value: MutableMapping[int, tuple[int, ...]] = {}
479
462
  if hasher is None:
480
463
  keys = KeyHasher.normalize_keys(keys)
481
464
  intersection = set(keys).intersection(self.parameter_names)
torch_max_mem/py.typed ADDED
File without changes
torch_max_mem/version.py CHANGED
@@ -1,27 +1,25 @@
1
- # -*- coding: utf-8 -*-
2
-
3
1
  """Version information for :mod:`torch_max_mem`.
4
2
 
5
3
  Run with ``python -m torch_max_mem.version``
6
4
  """
7
5
 
8
6
  import os
9
- from subprocess import CalledProcessError, check_output # noqa: S404
7
+ from subprocess import CalledProcessError, check_output
10
8
 
11
9
  __all__ = [
12
10
  "VERSION",
13
- "get_version",
14
11
  "get_git_hash",
12
+ "get_version",
15
13
  ]
16
14
 
17
- VERSION = "0.1.3"
15
+ VERSION = "0.1.4"
18
16
 
19
17
 
20
18
  def get_git_hash() -> str:
21
19
  """Get the :mod:`torch_max_mem` git hash."""
22
20
  with open(os.devnull, "w") as devnull:
23
21
  try:
24
- ret = check_output( # noqa: S603,S607
22
+ ret = check_output(
25
23
  ["git", "rev-parse", "HEAD"],
26
24
  cwd=os.path.dirname(__file__),
27
25
  stderr=devnull,
@@ -0,0 +1,345 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-max-mem
3
+ Version: 0.1.4
4
+ Summary: Maximize memory utilization with PyTorch.
5
+ Keywords: snekpack,cookiecutter,torch
6
+ Author: Max Berrendorf
7
+ Author-email: Max Berrendorf <max.berrendorf@gmail.com>
8
+ License-File: LICENSE
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Environment :: Console
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Framework :: Pytest
15
+ Classifier: Framework :: tox
16
+ Classifier: Framework :: Sphinx
17
+ Classifier: Programming Language :: Python
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Programming Language :: Python :: 3.13
23
+ Classifier: Programming Language :: Python :: 3 :: Only
24
+ Requires-Dist: torch>=2.0
25
+ Requires-Dist: typing-extensions
26
+ Maintainer: Max Berrendorf
27
+ Maintainer-email: Max Berrendorf <max.berrendorf@gmail.com>
28
+ Requires-Python: >=3.9
29
+ Project-URL: Bug Tracker, https://github.com/mberr/torch-max-mem/issues
30
+ Project-URL: Download, https://github.com/mberr/torch-max-mem/releases
31
+ Project-URL: Homepage, https://github.com/mberr/torch-max-mem
32
+ Project-URL: Source Code, https://github.com/mberr/torch-max-mem
33
+ Description-Content-Type: text/markdown
34
+
35
+ <!--
36
+ <p align="center">
37
+ <img src="https://github.com/mberr/torch-max-mem/raw/main/docs/source/logo.png" height="150">
38
+ </p>
39
+ -->
40
+
41
+ <h1 align="center">
42
+ torch-max-mem
43
+ </h1>
44
+
45
+ <p align="center">
46
+ <a href="https://github.com/mberr/torch-max-mem/actions/workflows/tests.yml">
47
+ <img alt="Tests" src="https://github.com/mberr/torch-max-mem/actions/workflows/tests.yml/badge.svg" /></a>
48
+ <a href="https://pypi.org/project/torch_max_mem">
49
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/torch_max_mem" /></a>
50
+ <a href="https://pypi.org/project/torch_max_mem">
51
+ <img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/torch_max_mem" /></a>
52
+ <a href="https://github.com/mberr/torch-max-mem/blob/main/LICENSE">
53
+ <img alt="PyPI - License" src="https://img.shields.io/pypi/l/torch_max_mem" /></a>
54
+ <a href='https://torch_max_mem.readthedocs.io/en/latest/?badge=latest'>
55
+ <img src='https://readthedocs.org/projects/torch_max_mem/badge/?version=latest' alt='Documentation Status' /></a>
56
+ <a href="https://codecov.io/gh/mberr/torch-max-mem/branch/main">
57
+ <img src="https://codecov.io/gh/mberr/torch-max-mem/branch/main/graph/badge.svg" alt="Codecov status" /></a>
58
+ <a href="https://github.com/cthoyt/cookiecutter-python-package">
59
+ <img alt="Cookiecutter template from @cthoyt" src="https://img.shields.io/badge/Cookiecutter-snekpack-blue" /></a>
60
+ <a href="https://github.com/astral-sh/ruff">
61
+ <img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff" style="max-width:100%;"></a>
62
+ <a href="https://github.com/mberr/torch-max-mem/blob/main/.github/CODE_OF_CONDUCT.md">
63
+ <img src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg" alt="Contributor Covenant"/></a>
64
+ <!-- uncomment if you archive on zenodo
65
+ <a href="https://zenodo.org/badge/latestdoi/XXXXXX">
66
+ <img src="https://zenodo.org/badge/XXXXXX.svg" alt="DOI"></a>
67
+ -->
68
+ </p>
69
+
70
+ This package provides decorators for memory utilization maximization with
71
+ PyTorch and CUDA by starting with a maximum parameter size and applying
72
+ successive halving until no more out-of-memory exception occurs.
73
+
74
+ ## 💪 Getting Started
75
+
76
+ Assume you have a function for batched computation of nearest neighbors using
77
+ brute-force distance calculation.
78
+
79
+ ```python
80
+ import torch
81
+
82
+ def knn(x, y, batch_size, k: int = 3):
83
+ return torch.cat(
84
+ [
85
+ torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
86
+ for start in range(0, x.shape[0], batch_size)
87
+ ],
88
+ dim=0,
89
+ )
90
+ ```
91
+
92
+ With `torch_max_mem` you can decorate this function to reduce the batch size
93
+ until no more out-of-memory error occurs.
94
+
95
+ ```python
96
+ import torch
97
+ from torch_max_mem import maximize_memory_utilization
98
+
99
+
100
+ @maximize_memory_utilization()
101
+ def knn(x, y, batch_size, k: int = 3):
102
+ return torch.cat(
103
+ [
104
+ torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
105
+ for start in range(0, x.shape[0], batch_size)
106
+ ],
107
+ dim=0,
108
+ )
109
+ ```
110
+
111
+ In the code, you can now always pass the largest sensible batch size, e.g.,
112
+
113
+ ```python
114
+ x = torch.rand(100, 100, device="cuda")
115
+ y = torch.rand(200, 100, device="cuda")
116
+ knn(x, y, batch_size=x.shape[0])
117
+ ```
118
+
119
+ ## 🚀 Installation
120
+
121
+ The most recent release can be installed from
122
+ [PyPI](https://pypi.org/project/torch_max_mem/) with uv:
123
+
124
+ ```console
125
+ uv pip install torch_max_mem
126
+ ```
127
+
128
+ or with pip:
129
+
130
+ ```console
131
+ python3 -m pip install torch_max_mem
132
+ ```
133
+
134
+ The most recent code and data can be installed directly from GitHub with uv:
135
+
136
+ ```console
137
+ uv pip install git+https://github.com/mberr/torch-max-mem.git
138
+ ```
139
+
140
+ or with pip:
141
+
142
+ ```console
143
+ python3 -m pip install git+https://github.com/mberr/torch-max-mem.git
144
+ ```
145
+
146
+ ## 👐 Contributing
147
+
148
+ Contributions, whether filing an issue, making a pull request, or forking, are
149
+ appreciated. See
150
+ [CONTRIBUTING.md](https://github.com/mberr/torch-max-mem/blob/main/.github/CONTRIBUTING.md)
151
+ for more information on getting involved.
152
+
153
+ ## 👋 Attribution
154
+
155
+ Parts of the logic have been developed with
156
+ [Laurent Vermue](https://github.com/lvermue) for
157
+ [PyKEEN](https://github.com/pykeen/pykeen).
158
+
159
+ ### ⚖️ License
160
+
161
+ The code in this package is licensed under the MIT License.
162
+
163
+ ### 🍪 Cookiecutter
164
+
165
+ This package was created with
166
+ [@audreyfeldroy](https://github.com/audreyfeldroy)'s
167
+ [cookiecutter](https://github.com/cookiecutter/cookiecutter) package using
168
+ [@cthoyt](https://github.com/cthoyt)'s
169
+ [cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack)
170
+ template.
171
+
172
+ ## 🛠️ For Developers
173
+
174
+ <details>
175
+ <summary>See developer instructions</summary>
176
+
177
+ This final section of the README is for those who want to get involved by making a
178
+ code contribution.
179
+
180
+ ### Development Installation
181
+
182
+ To install in development mode, use the following:
183
+
184
+ ```console
185
+ git clone https://github.com/mberr/torch-max-mem.git
186
+ cd torch-max-mem
187
+ uv pip install -e .
188
+ ```
189
+
190
+ Alternatively, install using pip:
191
+
192
+ ```console
193
+ python3 -m pip install -e .
194
+ ```
195
+
196
+ ### Updating Package Boilerplate
197
+
198
+ This project uses `cruft` to keep boilerplate (i.e., configuration, contribution
199
+ guidelines, documentation configuration) up-to-date with the upstream
200
+ cookiecutter package. Install cruft with either `uv tool install cruft` or
201
+ `python3 -m pip install cruft` then run:
202
+
203
+ ```console
204
+ cruft update
205
+ ```
206
+
207
+ More info on Cruft's update command is available
208
+ [here](https://github.com/cruft/cruft?tab=readme-ov-file#updating-a-project).
209
+
210
+ ### 🥼 Testing
211
+
212
+ After cloning the repository and installing `tox` with
213
+ `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, the
214
+ unit tests in the `tests/` folder can be run reproducibly with:
215
+
216
+ ```console
217
+ tox -e py
218
+ ```
219
+
220
+ Additionally, these tests are automatically re-run with each commit in a
221
+ [GitHub Action](https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests).
222
+
223
+ ### 📖 Building the Documentation
224
+
225
+ The documentation can be built locally using the following:
226
+
227
+ ```console
228
+ git clone https://github.com/mberr/torch-max-mem.git
229
+ cd torch-max-mem
230
+ tox -e docs
231
+ open docs/build/html/index.html
232
+ ```
233
+
234
+ The documentation automatically installs the package as well as the `docs` extra
235
+ specified in the [`pyproject.toml`](pyproject.toml). `sphinx` plugins like
236
+ `texext` can be added there. Additionally, they need to be added to the
237
+ `extensions` list in [`docs/source/conf.py`](docs/source/conf.py).
238
+
239
+ The documentation can be deployed to [ReadTheDocs](https://readthedocs.io) using
240
+ [this guide](https://docs.readthedocs.io/en/stable/intro/import-guide.html). The
241
+ [`.readthedocs.yml`](.readthedocs.yml) YAML file contains all the configuration
242
+ you'll need. You can also set up continuous integration on GitHub to check not
243
+ only that Sphinx can build the documentation in an isolated environment (i.e.,
244
+ with `tox -e docs-test`) but also that
245
+ [ReadTheDocs can build it too](https://docs.readthedocs.io/en/stable/pull-requests.html).
246
+
247
+ #### Configuring ReadTheDocs
248
+
249
+ 1. Log in to ReadTheDocs with your GitHub account to install the integration at
250
+ https://readthedocs.org/accounts/login/?next=/dashboard/
251
+ 2. Import your project by navigating to https://readthedocs.org/dashboard/import
252
+ then clicking the plus icon next to your repository
253
+ 3. You can rename the repository on the next screen using a more stylized name
254
+ (i.e., with spaces and capital letters)
255
+ 4. Click next, and you're good to go!
256
+
257
+ ### 📦 Making a Release
258
+
259
+ #### Configuring Zenodo
260
+
261
+ [Zenodo](https://zenodo.org) is a long-term archival system that assigns a DOI
262
+ to each release of your package.
263
+
264
+ 1. Log in to Zenodo via GitHub with this link:
265
+ https://zenodo.org/oauth/login/github/?next=%2F. This brings you to a page
266
+ that lists all of your organizations and asks you to approve installing the
267
+ Zenodo app on GitHub. Click "grant" next to any organizations you want to
268
+ enable the integration for, then click the big green "approve" button. This
269
+ step only needs to be done once.
270
+ 2. Navigate to https://zenodo.org/account/settings/github/, which lists all of
271
+ your GitHub repositories (both in your username and any organizations you
272
+ enabled). Click the on/off toggle for any relevant repositories. When you
273
+ make a new repository, you'll have to come back to this page.
274
+
275
+ After these steps, you're ready to go! After you make a "release" on GitHub (steps
276
+ for this are below), you can navigate to
277
+ https://zenodo.org/account/settings/github/repository/mberr/torch-max-mem to see
278
+ the DOI for the release and link to the Zenodo record for it.
279
+
280
+ #### Registering with the Python Package Index (PyPI)
281
+
282
+ You only have to do the following steps once.
283
+
284
+ 1. Register for an account on the
285
+ [Python Package Index (PyPI)](https://pypi.org/account/register)
286
+ 2. Navigate to https://pypi.org/manage/account and make sure you have verified
287
+ your email address. A verification email might not have been sent by default,
288
+ so you might have to click the "options" dropdown next to your address to get
289
+ to the "re-send verification email" button
290
+ 3. 2-Factor authentication has been required for PyPI since the end of 2023 (see this
291
+ [blog post from PyPI](https://blog.pypi.org/posts/2023-05-25-securing-pypi-with-2fa/)).
292
+ This means you have to first issue account recovery codes, then set up
293
+ 2-factor authentication
294
+ 4. Issue an API token from https://pypi.org/manage/account/token
295
+
296
+ #### Configuring your machine's connection to PyPI
297
+
298
+ You have to do the following steps once per machine.
299
+
300
+ ```console
301
+ uv tool install keyring
302
+ keyring set https://upload.pypi.org/legacy/ __token__
303
+ keyring set https://test.pypi.org/legacy/ __token__
304
+ ```
305
+
306
+ Note that this deprecates previous workflows using `.pypirc`.
307
+
308
+ #### Uploading to PyPI
309
+
310
+ After installing the package in development mode and installing `tox` with
311
+ `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, run
312
+ the following from the console:
313
+
314
+ ```console
315
+ tox -e finish
316
+ ```
317
+
318
+ This script does the following:
319
+
320
+ 1. Uses [bump-my-version](https://github.com/callowayproject/bump-my-version) to
321
+ switch the version number in the `pyproject.toml`, `CITATION.cff`,
322
+ `src/torch_max_mem/version.py`, and
323
+ [`docs/source/conf.py`](docs/source/conf.py) to not have the `-dev` suffix
324
+ 2. Packages the code in both a tar archive and a wheel using
325
+ [`uv build`](https://docs.astral.sh/uv/guides/publish/#building-your-package)
326
+ 3. Uploads to PyPI using
327
+ [`uv publish`](https://docs.astral.sh/uv/guides/publish/#publishing-your-package).
328
+ 4. Pushes to GitHub. You'll need to make a release going with the commit where the
329
+ version was bumped.
330
+ 5. Bumps the version to the next patch. If you made big changes and want to bump
331
+ the version by minor, you can use `tox -e bumpversion -- minor` after.
332
+
333
+ #### Releasing on GitHub
334
+
335
+ 1. Navigate to https://github.com/mberr/torch-max-mem/releases/new to draft a
336
+ new release
337
+ 2. Click the "Choose a Tag" dropdown and select the tag corresponding to the
338
+ release you just made
339
+ 3. Click the "Generate Release Notes" button to get a quick outline of recent
340
+ changes. Modify the title and description as you see fit
341
+ 4. Click the big green "Publish Release" button
342
+
343
+ This will trigger Zenodo to assign a DOI to your release as well.
344
+
345
+ </details>
@@ -0,0 +1,8 @@
1
+ torch_max_mem/api.py,sha256=0576dc52db63f99c6bc928a4fb1bca2234391bb7d0276d8cbee8efa9c20e0c67,17786
2
+ torch_max_mem/py.typed,sha256=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855,0
3
+ torch_max_mem/version.py,sha256=ee7ce9aee46d693e77190563fe86fd5a5544fda7bb7a61d95904d56b7c1c56b3,977
4
+ torch_max_mem/__init__.py,sha256=1249af2acdc8d0ec62b94f722864cf85872502e9ccbc1cb97ad41b9025232e08,206
5
+ torch_max_mem-0.1.4.dist-info/licenses/LICENSE,sha256=5163c26353a6aae35675c7d09095eab6c29149653ae83f942c510eab10e14c82,1071
6
+ torch_max_mem-0.1.4.dist-info/WHEEL,sha256=a77f4251c0f6db962579c616a6b1d7cf825367be8db8f5d245046ce8ff723666,79
7
+ torch_max_mem-0.1.4.dist-info/METADATA,sha256=29c64143887a54b549d60c8d1520ebbc674d411c1f4f9f123e0ee5338816a729,12714
8
+ torch_max_mem-0.1.4.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: uv 0.6.14
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -1,215 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: torch-max-mem
3
- Version: 0.1.3
4
- Summary: Maximize memory utilization with PyTorch.
5
- Home-page: https://github.com/mberr/torch-max-mem
6
- Download-URL: https://github.com/mberr/torch-max-mem/releases
7
- Author: Max Berrendorf
8
- Author-email: max.berrendorf@gmail.com
9
- Maintainer: Max Berrendorf
10
- Maintainer-email: max.berrendorf@gmail.com
11
- License: MIT
12
- Project-URL: Bug Tracker, https://github.com/mberr/torch-max-mem/issues
13
- Project-URL: Source Code, https://github.com/mberr/torch-max-mem
14
- Keywords: snekpack,cookiecutter,torch
15
- Classifier: Development Status :: 4 - Beta
16
- Classifier: Environment :: Console
17
- Classifier: Intended Audience :: Developers
18
- Classifier: License :: OSI Approved :: MIT License
19
- Classifier: Operating System :: OS Independent
20
- Classifier: Framework :: Pytest
21
- Classifier: Framework :: tox
22
- Classifier: Framework :: Sphinx
23
- Classifier: Programming Language :: Python
24
- Classifier: Programming Language :: Python :: 3.8
25
- Classifier: Programming Language :: Python :: 3.9
26
- Classifier: Programming Language :: Python :: 3.10
27
- Classifier: Programming Language :: Python :: 3.11
28
- Classifier: Programming Language :: Python :: 3 :: Only
29
- Requires-Python: >=3.8
30
- Description-Content-Type: text/markdown
31
- License-File: LICENSE
32
- Requires-Dist: torch >=2.0
33
- Requires-Dist: typing-extensions
34
- Provides-Extra: docs
35
- Requires-Dist: sphinx <7 ; extra == 'docs'
36
- Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
37
- Requires-Dist: sphinx-click ; extra == 'docs'
38
- Requires-Dist: sphinx-autodoc-typehints ; extra == 'docs'
39
- Requires-Dist: sphinx-automodapi ; extra == 'docs'
40
- Provides-Extra: formatting
41
- Requires-Dist: black ; extra == 'formatting'
42
- Requires-Dist: isort ; extra == 'formatting'
43
- Provides-Extra: tests
44
- Requires-Dist: numpy ; extra == 'tests'
45
- Requires-Dist: pytest ; extra == 'tests'
46
- Requires-Dist: coverage ; extra == 'tests'
47
-
48
- <!--
49
- <p align="center">
50
- <img src="https://github.com/mberr/torch-max-mem/raw/main/docs/source/logo.png" height="150">
51
- </p>
52
- -->
53
-
54
- <h1 align="center">
55
- torch-max-mem
56
- </h1>
57
-
58
- <p align="center">
59
- <a href="https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests">
60
- <img alt="Tests" src="https://github.com/mberr/torch-max-mem/workflows/Tests/badge.svg" />
61
- </a>
62
- <a href="https://github.com/cthoyt/cookiecutter-python-package">
63
- <img alt="Cookiecutter template from @cthoyt" src="https://img.shields.io/badge/Cookiecutter-snekpack-blue" />
64
- </a>
65
- <a href="https://pypi.org/project/torch_max_mem">
66
- <img alt="PyPI" src="https://img.shields.io/pypi/v/torch_max_mem" />
67
- </a>
68
- <a href="https://pypi.org/project/torch_max_mem">
69
- <img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/torch_max_mem" />
70
- </a>
71
- <a href="https://github.com/mberr/torch-max-mem/blob/main/LICENSE">
72
- <img alt="PyPI - License" src="https://img.shields.io/pypi/l/torch_max_mem" />
73
- </a>
74
- <a href='https://torch_max_mem.readthedocs.io/en/latest/?badge=latest'>
75
- <img src='https://readthedocs.org/projects/torch_max_mem/badge/?version=latest' alt='Documentation Status' />
76
- </a>
77
- <a href='https://github.com/psf/black'>
78
- <img src='https://img.shields.io/badge/code%20style-black-000000.svg' alt='Code style: black' />
79
- </a>
80
- </p>
81
-
82
- This package provides decorators for memory utilization maximization with PyTorch and CUDA by starting with a maximum parameter size and applying successive halving until no more out-of-memory exception occurs.
83
-
84
- ## 💪 Getting Started
85
-
86
- Assume you have a function for batched computation of nearest neighbors using brute-force distance calculation.
87
-
88
- ```python
89
- import torch
90
-
91
- def knn(x, y, batch_size, k: int = 3):
92
- return torch.cat(
93
- [
94
- torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
95
- for start in range(0, x.shape[0], batch_size)
96
- ],
97
- dim=0,
98
- )
99
- ```
100
-
101
- With `torch_max_mem` you can decorate this function to reduce the batch size until no more out-of-memory error occurs.
102
-
103
- ```python
104
- import torch
105
- from torch_max_mem import maximize_memory_utilization
106
-
107
-
108
- @maximize_memory_utilization()
109
- def knn(x, y, batch_size, k: int = 3):
110
- return torch.cat(
111
- [
112
- torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
113
- for start in range(0, x.shape[0], batch_size)
114
- ],
115
- dim=0,
116
- )
117
- ```
118
-
119
- In the code, you can now always pass the largest sensible batch size, e.g.,
120
-
121
- ```python
122
- x = torch.rand(100, 100, device="cuda")
123
- y = torch.rand(200, 100, device="cuda")
124
- knn(x, y, batch_size=x.shape[0])
125
- ```
126
-
127
- ## 🚀 Installation
128
-
129
- The most recent release can be installed from
130
- [PyPI](https://pypi.org/project/torch_max_mem/) with:
131
-
132
- ```bash
133
- $ pip install torch_max_mem
134
- ```
135
-
136
- The most recent code and data can be installed directly from GitHub with:
137
-
138
- ```bash
139
- $ pip install git+https://github.com/mberr/torch-max-mem.git
140
- ```
141
-
142
- To install in development mode, use the following:
143
-
144
- ```bash
145
- $ git clone git+https://github.com/mberr/torch-max-mem.git
146
- $ cd torch-max-mem
147
- $ pip install -e .
148
- ```
149
-
150
- ## 👐 Contributing
151
-
152
- Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See
153
- [CONTRIBUTING.md](https://github.com/mberr/torch-max-mem/blob/master/CONTRIBUTING.md) for more information on getting involved.
154
-
155
- ## 👋 Attribution
156
-
157
- Parts of the logic have been developed with [Laurent Vermue](https://github.com/lvermue) for [PyKEEN](https://github.com/pykeen/pykeen).
158
-
159
-
160
- ### ⚖️ License
161
-
162
- The code in this package is licensed under the MIT License.
163
-
164
- ### 🍪 Cookiecutter
165
-
166
- This package was created with [@audreyfeldroy](https://github.com/audreyfeldroy)'s
167
- [cookiecutter](https://github.com/cookiecutter/cookiecutter) package using [@cthoyt](https://github.com/cthoyt)'s
168
- [cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) template.
169
-
170
- ## 🛠️ For Developers
171
-
172
- <details>
173
- <summary>See developer instrutions</summary>
174
-
175
-
176
- The final section of the README is for if you want to get involved by making a code contribution.
177
-
178
- ### 🥼 Testing
179
-
180
- After cloning the repository and installing `tox` with `pip install tox`, the unit tests in the `tests/` folder can be
181
- run reproducibly with:
182
-
183
- ```shell
184
- $ tox
185
- ```
186
-
187
- Additionally, these tests are automatically re-run with each commit in a [GitHub Action](https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests).
188
-
189
- ### 📖 Building the Documentation
190
-
191
- ```shell
192
- $ tox -e docs
193
- ```
194
-
195
- ### 📦 Making a Release
196
-
197
- After installing the package in development mode and installing
198
- `tox` with `pip install tox`, the commands for making a new release are contained within the `finish` environment
199
- in `tox.ini`. Run the following from the shell:
200
-
201
- ```shell
202
- $ tox -e finish
203
- ```
204
-
205
- This script does the following:
206
-
207
- 1. Uses [Bump2Version](https://github.com/c4urself/bump2version) to switch the version number in the `setup.cfg` and
208
- `src/torch_max_mem/version.py` to not have the `-dev` suffix
209
- 2. Packages the code in both a tar archive and a wheel
210
- 3. Uploads to PyPI using `twine`. Be sure to have a `.pypirc` file configured to avoid the need for manual input at this
211
- step
212
- 4. Push to GitHub. You'll need to make a release going with the commit where the version was bumped.
213
- 5. Bump the version to the next patch. If you made big changes and want to bump the version by minor, you can
214
- use `tox -e bumpversion minor` after.
215
- </details>
@@ -1,8 +0,0 @@
1
- torch_max_mem/__init__.py,sha256=7XoGfOMupwSZ3HWFXhwQB7ysIhPZeERbVRfzCDxaADw,230
2
- torch_max_mem/api.py,sha256=_M0bTWArDZZUwXXBZc2nWVZgM_cJqAHdvo0dA8ydFbE,18193
3
- torch_max_mem/version.py,sha256=0oTTHe1bJoC-_dZzxkennR0WP9hs7wJPbVZD8sLd8SY,1035
4
- torch_max_mem-0.1.3.dist-info/LICENSE,sha256=UWPCY1OmquNWdcfQkJXqtsKRSWU66D-ULFEOqxDhTII,1071
5
- torch_max_mem-0.1.3.dist-info/METADATA,sha256=8UXSKQ4VdA1KPTXElSKNakEnYF_31IFHCPmcYVMp9aE,7401
6
- torch_max_mem-0.1.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
7
- torch_max_mem-0.1.3.dist-info/top_level.txt,sha256=ztqqyZB7neLi8zWiWdaaKytNuRkZ7SzO7YDUQ4ZJR3U,14
8
- torch_max_mem-0.1.3.dist-info/RECORD,,
@@ -1,5 +0,0 @@
1
- Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.2)
3
- Root-Is-Purelib: true
4
- Tag: py3-none-any
5
-
@@ -1 +0,0 @@
1
- torch_max_mem