torch_memory_saver-0.0.8-cp39-abi3-manylinux2014_x86_64.whl → torch_memory_saver-0.0.9rc2-cp39-abi3-manylinux2014_x86_64.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
torch_memory_saver/__init__.py
@@ -1,135 +1,5 @@
-import ctypes
-import logging
-import os
-from contextlib import contextmanager
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
+from .entrypoint import TorchMemorySaver
+from .hooks.mode_preload import configure_subprocess
 
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class TorchMemorySaver:
-    def __init__(self):
-        self._mem_pool = None
-
-    @contextmanager
-    def region(self, tag: str = "default"):
-        """Context manager for memory saving with optional tag"""
-        if _global_info.binary_info.enabled:
-            self._ensure_mem_pool()
-            with torch.cuda.use_mem_pool(self._mem_pool):
-                _global_info.binary_info.cdll.tms_set_current_tag(tag.encode('utf-8'))
-                _global_info.binary_info.cdll.tms_region_enter()
-                try:
-                    yield
-                finally:
-                    _global_info.binary_info.cdll.tms_set_current_tag(b"default")
-                    _global_info.binary_info.cdll.tms_region_leave()
-        else:
-            yield
-
-    def pause(self, tag: Optional[str] = None):
-        """Pause memory for specific tag or all memory if tag is None"""
-        if _global_info.binary_info.enabled:
-            tag_bytes = tag.encode('utf-8') if tag else None
-            _global_info.binary_info.cdll.tms_pause(tag_bytes)
-
-    def resume(self, tag: Optional[str] = None):
-        """Resume memory for specific tag or all memory if tag is None"""
-        if _global_info.binary_info.enabled:
-            tag_bytes = tag.encode('utf-8') if tag else None
-            _global_info.binary_info.cdll.tms_resume(tag_bytes)
-
-    @property
-    def enabled(self):
-        return _global_info.binary_info.enabled
-
-    def _ensure_mem_pool(self):
-        if self._mem_pool is None:
-            self._mem_pool = torch.cuda.MemPool()
-
-@dataclass
-class _BinaryInfo:
-    cdll: Optional[ctypes.CDLL]
-
-    @property
-    def enabled(self):
-        return self.cdll is not None
-
-    @staticmethod
-    def _setup_function_signatures(cdll):
-        """Define function signatures for the C library"""
-        cdll.tms_region_enter.argtypes = []
-        cdll.tms_region_leave.argtypes = []
-        cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
-        cdll.tms_pause.argtypes = [ctypes.c_char_p]
-        cdll.tms_resume.argtypes = [ctypes.c_char_p]
-
-    @staticmethod
-    def compute():
-        env_ld_preload = os.environ.get('LD_PRELOAD', '')
-        if 'torch_memory_saver' in env_ld_preload:
-            try:
-                cdll = ctypes.CDLL(env_ld_preload)
-                _BinaryInfo._setup_function_signatures(cdll)
-                return _BinaryInfo(cdll=cdll)
-            except OSError as e:
-                logger.error(f'Failed to load CDLL from {env_ld_preload}: {e}')
-                return _BinaryInfo(cdll=None)
-        else:
-            print(
-                f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD. '
-                f'You can use configure_subprocess() utility, '
-                f'or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py. '
-                f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
-            )
-            return _BinaryInfo(cdll=None)
-
-
-class _GlobalInfo:
-    def __init__(self):
-        self._binary_info: Optional[_BinaryInfo] = None
-
-    @property
-    def binary_info(self):
-        if self._binary_info is None:
-            self._binary_info = _BinaryInfo.compute()
-        return self._binary_info
-
-
-_global_info = _GlobalInfo()
-
-# Global singleton instance
+# Global singleton
 torch_memory_saver = TorchMemorySaver()
-
-def get_binary_path():
-    dir_package = Path(__file__).parent
-    candidates = [
-        p
-        for d in [dir_package, dir_package.parent]
-        for p in d.glob('torch_memory_saver_cpp.*.so')
-    ]
-    assert len(candidates) == 1, f'Expected exactly one torch_memory_saver_cpp library, found: {candidates}'
-    return candidates[0]
-
-
-@contextmanager
-def configure_subprocess():
-    with change_env('LD_PRELOAD', str(get_binary_path())):
-        yield
-
-
-@contextmanager
-def change_env(key: str, value: str):
-    old_value = os.environ.get(key, '')
-    os.environ[key] = value
-    logger.debug(f'change_env set key={key} value={value}')
-    try:
-        yield
-    finally:
-        assert os.environ[key] == value
-        os.environ[key] = old_value
-        logger.debug(f'change_env restore key={key} value={old_value}')
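
For orientation (not part of the diff itself): 0.0.9rc2 splits the old single-module implementation into `entrypoint`, `binary_wrapper`, `utils`, and a `hooks/` package, while the root `__init__.py` keeps re-exporting `TorchMemorySaver`, the `torch_memory_saver` singleton, and `configure_subprocess`. A minimal usage sketch against that unchanged public surface, assuming a CUDA device and (for the default preload hook mode) a process launched with the appropriate `LD_PRELOAD`:

```python
import torch
from torch_memory_saver import torch_memory_saver

# Allocate inside a region so the buffers become pauseable later.
with torch_memory_saver.region(tag="kv_cache"):
    buf = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")

torch_memory_saver.pause("kv_cache")   # release the physical GPU memory behind `buf`
torch_memory_saver.resume("kv_cache")  # map memory back so `buf` is usable again
```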

torch_memory_saver/binary_wrapper.py
@@ -0,0 +1,31 @@
+import ctypes
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BinaryWrapper:
+    def __init__(self, path_binary: str):
+        try:
+            self.cdll = ctypes.CDLL(path_binary)
+        except OSError as e:
+            logger.error(f"Failed to load CDLL from {path_binary}: {e}")
+            raise
+
+        _setup_function_signatures(self.cdll)
+
+    def set_config(self, *, tag: str, interesting_region: bool, enable_cpu_backup: bool):
+        self.cdll.tms_set_current_tag(tag.encode("utf-8"))
+        self.cdll.tms_set_interesting_region(interesting_region)
+        self.cdll.tms_set_enable_cpu_backup(enable_cpu_backup)
+
+
+def _setup_function_signatures(cdll):
+    """Define function signatures for the C library"""
+    cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
+    cdll.tms_set_interesting_region.argtypes = [ctypes.c_bool]
+    cdll.tms_get_interesting_region.restype = ctypes.c_bool
+    cdll.tms_set_enable_cpu_backup.argtypes = [ctypes.c_bool]
+    cdll.tms_get_enable_cpu_backup.restype = ctypes.c_bool
+    cdll.tms_pause.argtypes = [ctypes.c_char_p]
+    cdll.tms_resume.argtypes = [ctypes.c_char_p]
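
One detail worth noting in `binary_wrapper.py`: the new boolean entry points get explicit `argtypes`/`restype` declarations because ctypes otherwise assumes C `int` for arguments and return values. A small, generic illustration of that behaviour using a libc function on Linux (this snippet is illustrative only and unrelated to the package's own symbols):

```python
import ctypes

libc = ctypes.CDLL(None)  # load the current process's symbols (Linux)

libc.isalpha.restype = ctypes.c_int
print(libc.isalpha(ord("a")))   # some nonzero int (an internal bitmask on glibc)

libc.isalpha.restype = ctypes.c_bool
print(libc.isalpha(ord("a")))   # True -- the same result, now converted to bool
```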

torch_memory_saver/entrypoint.py
@@ -0,0 +1,142 @@
+import ctypes
+import logging
+import os
+from contextlib import contextmanager
+from typing import Optional
+import torch
+
+from .binary_wrapper import BinaryWrapper
+from .hooks.base import HookUtilBase, HookMode
+
+logger = logging.getLogger(__name__)
+
+_TAG_DEFAULT = "default"
+
+
+class TorchMemorySaver:
+    def __init__(self):
+        self._impl_ctor_kwargs = {}
+        self._impl: Optional[_TorchMemorySaverImpl] = None
+
+    @contextmanager
+    def region(self, tag: str = _TAG_DEFAULT, enable_cpu_backup: bool = False):
+        """Context manager for memory saving with optional tag"""
+        self._ensure_initialized()
+        with self._impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
+            yield
+
+    @contextmanager
+    def cuda_graph(
+        self,
+        cuda_graph, pool=None, stream=None, capture_error_mode='global',
+        tag: str = _TAG_DEFAULT, enable_cpu_backup: bool = False,
+    ):
+        """Similar to `torch.cuda.graph`, but ensures memory in it to be pauseable."""
+        self._ensure_initialized()
+        with self._impl.cuda_graph(
+            cuda_graph=cuda_graph,
+            pool=pool, stream=stream, capture_error_mode=capture_error_mode,
+            tag=tag, enable_cpu_backup=enable_cpu_backup,
+        ):
+            yield
+
+    @contextmanager
+    def disable(self):
+        self._ensure_initialized()
+        with self._impl.disable():
+            yield
+
+    def pause(self, tag: Optional[str] = None):
+        """Pause memory for specific tag or all memory if tag is None"""
+        self._ensure_initialized()
+        self._impl.pause(tag=tag)
+
+    def resume(self, tag: Optional[str] = None):
+        """Resume memory for specific tag or all memory if tag is None"""
+        self._ensure_initialized()
+        self._impl.resume(tag=tag)
+
+    # for compatibility
+    @property
+    def enabled(self):
+        return True
+
+    @property
+    def hook_mode(self):
+        raise AttributeError
+
+    @hook_mode.setter
+    def hook_mode(self, hook_mode: HookMode):
+        assert self._impl_ctor_kwargs is not None, "Cannot configure after initialization"
+        self._impl_ctor_kwargs["hook_mode"] = hook_mode
+
+    def _ensure_initialized(self):
+        if self._impl is not None:
+            return
+        self._impl = _TorchMemorySaverImpl(**self._impl_ctor_kwargs)
+        del self._impl_ctor_kwargs
+
+
+class _TorchMemorySaverImpl:
+    def __init__(self, hook_mode: HookMode = "preload"):
+        self._hook_mode = hook_mode
+        self._hook_util = HookUtilBase.create(hook_mode=hook_mode)
+        self._binary_wrapper = BinaryWrapper(path_binary=self._hook_util.get_path_binary())
+        self._primary_mem_pool = torch.cuda.MemPool(allocator=self._hook_util.get_allocator())
+        _sanity_checks()
+
+    @contextmanager
+    def region(self, tag: str, enable_cpu_backup: bool):
+        with torch.cuda.use_mem_pool(self._primary_mem_pool):
+            with self._with_region_config(tag=tag, enable_cpu_backup=enable_cpu_backup):
+                yield
+
+    @contextmanager
+    def cuda_graph(self, cuda_graph, pool, stream, capture_error_mode, tag: str, enable_cpu_backup: bool):
+        assert self._hook_mode == "preload", "Only hook_mode=preload supports pauseable CUDA Graph currently"
+        with torch.cuda.graph(cuda_graph, pool=pool, stream=stream, capture_error_mode=capture_error_mode):
+            with self._with_region_config(tag=tag, enable_cpu_backup=enable_cpu_backup):
+                yield
+
+    @contextmanager
+    def _with_region_config(self, tag: str, enable_cpu_backup: bool):
+        assert not self._binary_wrapper.cdll.tms_get_interesting_region()
+        original_enable_cpu_backup = self._binary_wrapper.cdll.tms_get_enable_cpu_backup()
+
+        self._binary_wrapper.set_config(tag=tag, interesting_region=True, enable_cpu_backup=enable_cpu_backup)
+        try:
+            yield
+        finally:
+            assert self._binary_wrapper.cdll.tms_get_interesting_region()
+            self._binary_wrapper.set_config(tag=_TAG_DEFAULT, interesting_region=False, enable_cpu_backup=original_enable_cpu_backup)
+
+    @contextmanager
+    def disable(self, dispose_mem_pool_after_use: bool = True):
+        assert dispose_mem_pool_after_use, "Only dispose_mem_pool_after_use=true is supported now"
+        assert self._binary_wrapper.cdll.tms_get_interesting_region(), "disable() should be called only when tms is active"
+
+        self._binary_wrapper.cdll.tms_set_interesting_region(False)
+        try:
+            # We can either reuse the pool or delete it immediately, and we implement the latter currently since Slime uses it.
+            # About why we need a pool: https://github.com/fzyzcjy/torch_memory_saver/pull/20#issuecomment-3047099047
+            pool = torch.cuda.MemPool()
+            with torch.cuda.use_mem_pool(pool):
+                yield
+            del pool
+        finally:
+            self._binary_wrapper.cdll.tms_set_interesting_region(True)
+
+    def pause(self, tag: Optional[str]):
+        tag_bytes = tag.encode("utf-8") if tag else None
+        self._binary_wrapper.cdll.tms_pause(tag_bytes)
+
+    def resume(self, tag: Optional[str]):
+        tag_bytes = tag.encode("utf-8") if tag else None
+        self._binary_wrapper.cdll.tms_resume(tag_bytes)
+
+
+def _sanity_checks():
+    if "expandable_segments:True" in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""):
+        raise RuntimeError(
+            "TorchMemorySaver is disabled for the current process because expandable_segments is not supported yet."
+        )
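
Reading `entrypoint.py` as a whole: `TorchMemorySaver` now defers all work to a lazily constructed `_TorchMemorySaverImpl`, regions are tagged, and `enable_cpu_backup` (judging by its name and the matching C getters/setters, a request to keep a host-side copy of paused memory) can be set per region. A hedged usage sketch, assuming a CUDA device and a correctly configured hook mode:

```python
import torch
from torch_memory_saver import torch_memory_saver

# Tag-scoped regions; enable_cpu_backup is forwarded to the C library per region.
with torch_memory_saver.region(tag="weights", enable_cpu_backup=True):
    weights = torch.randn(1024, 1024, device="cuda")

with torch_memory_saver.region(tag="kv_cache"):
    kv_cache = torch.empty(256 << 20, dtype=torch.uint8, device="cuda")

torch_memory_saver.pause("kv_cache")  # pause only the kv_cache allocations
torch_memory_saver.pause("weights")
torch_memory_saver.resume()           # tag=None resumes all paused memory
```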
File without changes

torch_memory_saver/hooks/base.py
@@ -0,0 +1,21 @@
+from abc import ABC
+from typing import Literal
+
+HookMode = Literal["preload", "torch"]
+
+
+class HookUtilBase(ABC):
+    @staticmethod
+    def create(hook_mode: HookMode) -> "HookUtilBase":
+        from torch_memory_saver.hooks.mode_preload import HookUtilModePreload
+        from torch_memory_saver.hooks.mode_torch import HookUtilModeTorch
+        return {
+            "preload": HookUtilModePreload,
+            "torch": HookUtilModeTorch,
+        }[hook_mode]()
+
+    def get_path_binary(self):
+        raise NotImplementedError
+
+    def get_allocator(self):
+        return None
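
`HookUtilBase.create` is the dispatch point between the two hook modes, and the `hook_mode` setter on the singleton (see `entrypoint.py` above) stores the choice until the implementation is first constructed. A short sketch of selecting the torch mode, which loads the bundled allocator library directly instead of relying on `LD_PRELOAD`:

```python
from torch_memory_saver import torch_memory_saver

# Must happen before the first region()/pause()/resume() call,
# per the "Cannot configure after initialization" assertion.
torch_memory_saver.hook_mode = "torch"

with torch_memory_saver.region(tag="demo"):
    ...  # allocations here go through the CUDAPluggableAllocator-backed pool
```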

torch_memory_saver/hooks/mode_preload.py
@@ -0,0 +1,26 @@
+import logging
+import os
+from contextlib import contextmanager
+from torch_memory_saver.hooks.base import HookUtilBase
+from torch_memory_saver.utils import get_binary_path_from_package, change_env
+
+logger = logging.getLogger(__name__)
+
+
+class HookUtilModePreload(HookUtilBase):
+    def get_path_binary(self):
+        env_ld_preload = os.environ.get("LD_PRELOAD", "")
+        assert "torch_memory_saver" in env_ld_preload, (
+            f"TorchMemorySaver observes invalid LD_PRELOAD. "
+            f"You can use configure_subprocess() utility, "
+            f"or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py. "
+            f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
+        )
+        return env_ld_preload
+
+
+@contextmanager
+def configure_subprocess():
+    """Configure environment variables for subprocesses. Only needed for hook_mode=preload."""
+    with change_env("LD_PRELOAD", str(get_binary_path_from_package("torch_memory_saver_hook_mode_preload"))):
+        yield
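
`configure_subprocess()` temporarily points `LD_PRELOAD` at the bundled `torch_memory_saver_hook_mode_preload` library (via `change_env`), so any child process started inside the context inherits the hook. A sketch, where `worker.py` is a hypothetical script that uses `torch_memory_saver` in preload mode:

```python
import subprocess
from torch_memory_saver import configure_subprocess

with configure_subprocess():
    proc = subprocess.Popen(["python", "worker.py"])  # child inherits LD_PRELOAD
proc.wait()
```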

torch_memory_saver/hooks/mode_torch.py
@@ -0,0 +1,19 @@
+import logging
+
+from torch_memory_saver.hooks.base import HookUtilBase
+from torch_memory_saver.utils import get_binary_path_from_package
+from torch.cuda.memory import CUDAPluggableAllocator
+
+logger = logging.getLogger(__name__)
+
+
+class HookUtilModeTorch(HookUtilBase):
+    def __init__(self):
+        self.allocator = CUDAPluggableAllocator(self.get_path_binary(), "tms_torch_malloc", "tms_torch_free")
+        logger.debug(f"HookUtilModeTorch {self.allocator=} {self.get_path_binary()=}")
+
+    def get_path_binary(self):
+        return str(get_binary_path_from_package("torch_memory_saver_hook_mode_torch"))
+
+    def get_allocator(self):
+        return self.allocator.allocator()

torch_memory_saver/testing_utils.py
@@ -0,0 +1,10 @@
+"""Not to be used by end users, but only for tests of the package itself."""
+
+import torch
+
+
+def get_and_print_gpu_memory(message, gpu_id=0):
+    """Print GPU memory usage with optional message"""
+    mem = torch.cuda.device_memory_used(gpu_id)
+    print(f"GPU {gpu_id} memory: {mem / 1024 ** 3:.2f} GB ({message})")
+    return mem
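
The helper above reads `torch.cuda.device_memory_used`, so the package's own tests can watch physical memory drop and come back around pause/resume. A sketch of that kind of check (illustrative only; absolute numbers depend on the GPU and allocator state):

```python
import torch
from torch_memory_saver import torch_memory_saver
from torch_memory_saver.testing_utils import get_and_print_gpu_memory

with torch_memory_saver.region(tag="blob"):
    blob = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")  # ~1 GiB

before = get_and_print_gpu_memory("after allocation")
torch_memory_saver.pause("blob")
assert get_and_print_gpu_memory("after pause") < before
torch_memory_saver.resume("blob")
get_and_print_gpu_memory("after resume")
```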

torch_memory_saver/utils.py
@@ -0,0 +1,27 @@
+import logging
+import os
+from contextlib import contextmanager
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+
+def get_binary_path_from_package(stem: str):
+    dir_package = Path(__file__).parent
+    candidates = [p for d in [dir_package, dir_package.parent] for p in d.glob(f"{stem}.*.so")]
+    assert len(candidates) == 1, f"Expected exactly one torch_memory_saver_cpp library, found: {candidates}"
+    return candidates[0]
+
+
+# private utils, not to be used by end users
+@contextmanager
+def change_env(key: str, value: str):
+    old_value = os.environ.get(key, "")
+    os.environ[key] = value
+    logger.debug(f"change_env set key={key} value={value}")
+    try:
+        yield
+    finally:
+        assert os.environ[key] == value
+        os.environ[key] = old_value
+        logger.debug(f"change_env restore key={key} value={old_value}")

torch_memory_saver-0.0.8.dist-info/METADATA → torch_memory_saver-0.0.9rc2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: torch_memory_saver
-Version: 0.0.8
+Version: 0.0.9rc2
 Requires-Python: >=3.9
 License-File: LICENSE
 

torch_memory_saver-0.0.9rc2.dist-info/RECORD
@@ -0,0 +1,16 @@
+torch_memory_saver_hook_mode_preload.abi3.so,sha256=dQ6PvuqKNXTsbAPSWxP86YkQL82Vo-ngJ9WqB4Zbiss,777208
+torch_memory_saver_hook_mode_torch.abi3.so,sha256=-NE71rUpuYC3Sh8EqaMWbdMbm9C4h6MxeC30gNoqGys,781032
+torch_memory_saver/__init__.py,sha256=9iU_QlTe6OxMR5_OtSRUmvr6ltzk149GjojYvG74sag,154
+torch_memory_saver/binary_wrapper.py,sha256=MeQlPHIuFycamcWp3kOXjVZMiEK8HONuSx4l92J4k_Q,1133
+torch_memory_saver/entrypoint.py,sha256=aFkgqnWRI8vF8EeAL4FvIY33dNVtIbUMk1eM3_xH-fs,5538
+torch_memory_saver/testing_utils.py,sha256=vd9jhMgBLbeEy3vdvbuCjjtO-lRSX-RVB_Dg-wSHVQM,332
+torch_memory_saver/utils.py,sha256=LhtiocZTpMyDEjSexXaGglQtOJeJB7AaH5s43PZX5yo,856
+torch_memory_saver/hooks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torch_memory_saver/hooks/base.py,sha256=f8Rv_XxNupU80dKWUgE-Ea5pu1qaoXcnzltZrDy90hY,579
+torch_memory_saver/hooks/mode_preload.py,sha256=ELaVloI7T-rjssxn6lujaknNfujxFBxA2oc0SOsiUfk,1041
+torch_memory_saver/hooks/mode_torch.py,sha256=yxGyA8AYrKX7hr3Bawr_MH2AMwSgXxLN76GZaAbQLGU,681
+torch_memory_saver-0.0.9rc2.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
+torch_memory_saver-0.0.9rc2.dist-info/METADATA,sha256=OFtZGuPLP8Qdyej7cEBiuWUmu-oU7qQrdZLA0wwJq3o,111
+torch_memory_saver-0.0.9rc2.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
+torch_memory_saver-0.0.9rc2.dist-info/top_level.txt,sha256=Fdob5gbD3sjPAe3kNfDokaN1sL43cMvwKRLKuR8oitw,91
+torch_memory_saver-0.0.9rc2.dist-info/RECORD,,

torch_memory_saver-0.0.9rc2.dist-info/top_level.txt
@@ -0,0 +1,3 @@
+torch_memory_saver
+torch_memory_saver_hook_mode_preload
+torch_memory_saver_hook_mode_torch

torch_memory_saver-0.0.8.dist-info/RECORD
@@ -1,7 +0,0 @@
-torch_memory_saver_cpp.abi3.so,sha256=kI54XLX1E_R-PgmfwrZoM2-emKOWQXRoabNEMZS0dCE,391672
-torch_memory_saver/__init__.py,sha256=MfMSSGSNhP7xBo4jq6RPug6svnMVPtmua8-l0yWlkTg,4403
-torch_memory_saver-0.0.8.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
-torch_memory_saver-0.0.8.dist-info/METADATA,sha256=tUjtugtoDTFWeIct5kTvWRbGUUTlJlG85h_zSClzUdA,108
-torch_memory_saver-0.0.8.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
-torch_memory_saver-0.0.8.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
-torch_memory_saver-0.0.8.dist-info/RECORD,,

torch_memory_saver-0.0.8.dist-info/top_level.txt
@@ -1,2 +0,0 @@
-torch_memory_saver
-torch_memory_saver_cpp
Binary file