torch-memory-saver 0.0.6__cp39-abi3-manylinux2014_x86_64.whl → 0.0.9rc1__cp39-abi3-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,115 +1,5 @@
-import ctypes
-import logging
-import os
-from contextlib import contextmanager
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
+from .entrypoint import TorchMemorySaver
+from .hooks.mode_preload import configure_subprocess
 
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class TorchMemorySaver:
-    def __init__(self):
-        self._mem_pool = None
-        self._id = _global_info.next_id()
-        assert self._id == 1, 'Only support one single instance yet (multi-instance will be implemented later)'
-
-    @contextmanager
-    def region(self):
-        if _global_info.binary_info.enabled:
-            self._ensure_mem_pool()
-            with torch.cuda.use_mem_pool(self._mem_pool):
-                _global_info.binary_info.cdll.tms_region_enter()
-                try:
-                    yield
-                finally:
-                    _global_info.binary_info.cdll.tms_region_leave()
-        else:
-            yield
-
-    def pause(self):
-        if _global_info.binary_info.enabled:
-            _global_info.binary_info.cdll.tms_pause()
-
-    def resume(self):
-        if _global_info.binary_info.enabled:
-            _global_info.binary_info.cdll.tms_resume()
-
-    @property
-    def enabled(self):
-        return _global_info.binary_info.enabled
-
-    def _ensure_mem_pool(self):
-        if self._mem_pool is None:
-            self._mem_pool = torch.cuda.MemPool()
-
-
-@dataclass
-class _BinaryInfo:
-    cdll: Optional[ctypes.CDLL]
-
-    @property
-    def enabled(self):
-        return self.cdll is not None
-
-    @staticmethod
-    def compute():
-        env_ld_preload = os.environ.get('LD_PRELOAD', '')
-        if 'torch_memory_saver' in env_ld_preload:
-            return _BinaryInfo(cdll=ctypes.CDLL(env_ld_preload))
-        else:
-            logger.warning(
-                f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD="{env_ld_preload}" (process_id={os.getpid()})')
-            return _BinaryInfo(cdll=None)
-
-
-class _GlobalInfo:
-    def __init__(self):
-        self._binary_info: Optional[_BinaryInfo] = None
-        self._last_id = 0
-
-    @property
-    def binary_info(self):
-        if self._binary_info is None:
-            self._binary_info = _BinaryInfo.compute()
-        return self._binary_info
-
-    def next_id(self):
-        self._last_id += 1
-        return self._last_id
-
-
-_global_info = _GlobalInfo()
-
-
-def get_binary_path():
-    dir_package = Path(__file__).parent
-    candidates = [
-        p
-        for d in [dir_package, dir_package.parent]
-        for p in d.glob('torch_memory_saver_cpp.*.so')
-    ]
-    assert len(candidates) == 1, f'{candidates=}'
-    return candidates[0]
-
-
-@contextmanager
-def configure_subprocess():
-    with change_env('LD_PRELOAD', str(get_binary_path())):
-        yield
-
-
-@contextmanager
-def change_env(key: str, value: str):
-    old_value = os.environ.get(key, '')
-    os.environ[key] = value
-    logger.debug(f'change_env set key={key} value={value}')
-    try:
-        yield
-    finally:
-        assert os.environ[key] == value
-        os.environ[key] = old_value
-        logger.debug(f'change_env restore key={key} value={old_value}')
+# Global singleton
+torch_memory_saver = TorchMemorySaver()
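
The old single-file module is replaced by a package, and `__init__.py` now only re-exports the API plus a `torch_memory_saver` global singleton. A minimal usage sketch of that singleton (assuming a CUDA device and a process launched with the preload hook in place, e.g. via `configure_subprocess()` shown further below):

    import torch
    from torch_memory_saver import torch_memory_saver

    # Tensors allocated inside region() become pauseable.
    with torch_memory_saver.region():
        x = torch.zeros(1024, 1024, device="cuda")

    torch_memory_saver.pause()   # release the GPU memory backing x
    torch_memory_saver.resume()  # make x usable again
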
@@ -0,0 +1,31 @@
+import ctypes
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BinaryWrapper:
+    def __init__(self, path_binary: str):
+        try:
+            self.cdll = ctypes.CDLL(path_binary)
+        except OSError as e:
+            logger.error(f"Failed to load CDLL from {path_binary}: {e}")
+            raise
+
+        _setup_function_signatures(self.cdll)
+
+    def set_config(self, *, tag: str, interesting_region: bool, enable_cpu_backup: bool):
+        self.cdll.tms_set_current_tag(tag.encode("utf-8"))
+        self.cdll.tms_set_interesting_region(interesting_region)
+        self.cdll.tms_set_enable_cpu_backup(enable_cpu_backup)
+
+
+def _setup_function_signatures(cdll):
+    """Define function signatures for the C library"""
+    cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
+    cdll.tms_set_interesting_region.argtypes = [ctypes.c_bool]
+    cdll.tms_get_interesting_region.restype = ctypes.c_bool
+    cdll.tms_set_enable_cpu_backup.argtypes = [ctypes.c_bool]
+    cdll.tms_get_enable_cpu_backup.restype = ctypes.c_bool
+    cdll.tms_pause.argtypes = [ctypes.c_char_p]
+    cdll.tms_resume.argtypes = [ctypes.c_char_p]
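
The explicit `argtypes`/`restype` declarations are load-bearing: without them, ctypes passes arguments as C `int` and assumes an `int` return, which would mangle the `c_char_p` tags and `c_bool` results used here. The same pattern against libc, as a standalone illustration (not part of this package; assumes a typical Linux/macOS system):

    import ctypes
    import ctypes.util

    libc = ctypes.CDLL(ctypes.util.find_library("c"))

    # Declaring the signature lets ctypes convert and type-check arguments.
    libc.strlen.argtypes = [ctypes.c_char_p]
    libc.strlen.restype = ctypes.c_size_t

    assert libc.strlen(b"hello") == 5
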
@@ -0,0 +1,142 @@
+import ctypes
+import logging
+import os
+from contextlib import contextmanager
+from typing import Optional
+import torch
+
+from .binary_wrapper import BinaryWrapper
+from .hooks.base import HookUtilBase, HookMode
+
+logger = logging.getLogger(__name__)
+
+_TAG_DEFAULT = "default"
+
+
+class TorchMemorySaver:
+    def __init__(self):
+        self._impl_ctor_kwargs = {}
+        self._impl: Optional[_TorchMemorySaverImpl] = None
+
+    @contextmanager
+    def region(self, tag: str = _TAG_DEFAULT, enable_cpu_backup: bool = False):
+        """Context manager for memory saving with optional tag"""
+        self._ensure_initialized()
+        with self._impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
+            yield
+
+    @contextmanager
+    def cuda_graph(
+        self,
+        cuda_graph, pool=None, stream=None, capture_error_mode='global',
+        tag: str = _TAG_DEFAULT, enable_cpu_backup: bool = False,
+    ):
+        """Similar to `torch.cuda.graph`, but ensures the memory in it is pauseable."""
+        self._ensure_initialized()
+        with self._impl.cuda_graph(
+            cuda_graph=cuda_graph,
+            pool=pool, stream=stream, capture_error_mode=capture_error_mode,
+            tag=tag, enable_cpu_backup=enable_cpu_backup,
+        ):
+            yield
+
+    @contextmanager
+    def disable(self):
+        self._ensure_initialized()
+        with self._impl.disable():
+            yield
+
+    def pause(self, tag: Optional[str] = None):
+        """Pause memory for a specific tag, or all memory if tag is None"""
+        self._ensure_initialized()
+        self._impl.pause(tag=tag)
+
+    def resume(self, tag: Optional[str] = None):
+        """Resume memory for a specific tag, or all memory if tag is None"""
+        self._ensure_initialized()
+        self._impl.resume(tag=tag)
+
+    # for compatibility
+    @property
+    def enabled(self):
+        return True
+
+    @property
+    def hook_mode(self):
+        raise AttributeError
+
+    @hook_mode.setter
+    def hook_mode(self, hook_mode: HookMode):
+        assert self._impl_ctor_kwargs is not None, "Cannot configure after initialization"
+        self._impl_ctor_kwargs["hook_mode"] = hook_mode
+
+    def _ensure_initialized(self):
+        if self._impl is not None:
+            return
+        self._impl = _TorchMemorySaverImpl(**self._impl_ctor_kwargs)
+        del self._impl_ctor_kwargs
+
+
+class _TorchMemorySaverImpl:
+    def __init__(self, hook_mode: HookMode = "preload"):
+        self._hook_mode = hook_mode
+        self._hook_util = HookUtilBase.create(hook_mode=hook_mode)
+        self._binary_wrapper = BinaryWrapper(path_binary=self._hook_util.get_path_binary())
+        self._primary_mem_pool = torch.cuda.MemPool(allocator=self._hook_util.get_allocator())
+        _sanity_checks()
+
+    @contextmanager
+    def region(self, tag: str, enable_cpu_backup: bool):
+        with torch.cuda.use_mem_pool(self._primary_mem_pool):
+            with self._with_region_config(tag=tag, enable_cpu_backup=enable_cpu_backup):
+                yield
+
+    @contextmanager
+    def cuda_graph(self, cuda_graph, pool, stream, capture_error_mode, tag: str, enable_cpu_backup: bool):
+        assert self._hook_mode == "preload", "Only hook_mode=preload supports pauseable CUDA Graph currently"
+        with torch.cuda.graph(cuda_graph, pool=pool, stream=stream, capture_error_mode=capture_error_mode):
+            with self._with_region_config(tag=tag, enable_cpu_backup=enable_cpu_backup):
+                yield
+
+    @contextmanager
+    def _with_region_config(self, tag: str, enable_cpu_backup: bool):
+        assert not self._binary_wrapper.cdll.tms_get_interesting_region()
+        original_enable_cpu_backup = self._binary_wrapper.cdll.tms_get_enable_cpu_backup()
+
+        self._binary_wrapper.set_config(tag=tag, interesting_region=True, enable_cpu_backup=enable_cpu_backup)
+        try:
+            yield
+        finally:
+            assert self._binary_wrapper.cdll.tms_get_interesting_region()
+            self._binary_wrapper.set_config(tag=_TAG_DEFAULT, interesting_region=False, enable_cpu_backup=original_enable_cpu_backup)
+
+    @contextmanager
+    def disable(self, dispose_mem_pool_after_use: bool = True):
+        assert dispose_mem_pool_after_use, "Only dispose_mem_pool_after_use=true is supported now"
+        assert self._binary_wrapper.cdll.tms_get_interesting_region(), "disable() should be called only when tms is active"
+
+        self._binary_wrapper.cdll.tms_set_interesting_region(False)
+        try:
+            # We can either reuse the pool or delete it immediately; we currently implement the latter since Slime uses it.
+            # About why we need a pool: https://github.com/fzyzcjy/torch_memory_saver/pull/20#issuecomment-3047099047
+            pool = torch.cuda.MemPool()
+            with torch.cuda.use_mem_pool(pool):
+                yield
+            del pool
+        finally:
+            self._binary_wrapper.cdll.tms_set_interesting_region(True)
+
+    def pause(self, tag: Optional[str]):
+        tag_bytes = tag.encode("utf-8") if tag else None
+        self._binary_wrapper.cdll.tms_pause(tag_bytes)
+
+    def resume(self, tag: Optional[str]):
+        tag_bytes = tag.encode("utf-8") if tag else None
+        self._binary_wrapper.cdll.tms_resume(tag_bytes)
+
+
+def _sanity_checks():
+    if "expandable_segments:True" in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""):
+        raise RuntimeError(
+            "TorchMemorySaver is disabled for the current process because expandable_segments is not supported yet."
+        )
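
The headline change versus 0.0.6 is that regions are now tagged, so callers can pause and resume disjoint groups of allocations independently, optionally with a CPU backup of the paused contents. A hedged sketch of the intended flow, under the same preload assumptions as above:

    import torch
    from torch_memory_saver import torch_memory_saver

    with torch_memory_saver.region(tag="weights"):
        weights = torch.randn(4096, 4096, device="cuda")

    with torch_memory_saver.region(tag="kv_cache", enable_cpu_backup=True):
        kv_cache = torch.zeros(4096, 4096, device="cuda")

    torch_memory_saver.pause("kv_cache")   # weights stay resident
    torch_memory_saver.resume("kv_cache")
    torch_memory_saver.pause()             # no tag: pause everything
    torch_memory_saver.resume()
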
File without changes
@@ -0,0 +1,21 @@
+from abc import ABC
+from typing import Literal
+
+HookMode = Literal["preload", "torch"]
+
+
+class HookUtilBase(ABC):
+    @staticmethod
+    def create(hook_mode: HookMode) -> "HookUtilBase":
+        from torch_memory_saver.hooks.mode_preload import HookUtilModePreload
+        from torch_memory_saver.hooks.mode_torch import HookUtilModeTorch
+        return {
+            "preload": HookUtilModePreload,
+            "torch": HookUtilModeTorch,
+        }[hook_mode]()
+
+    def get_path_binary(self):
+        raise NotImplementedError
+
+    def get_allocator(self):
+        return None
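
`HookUtilBase.create` is a string-keyed factory over the two hook strategies; the imports live inside the method, presumably to avoid circular imports with the modules that subclass it. The mode is chosen through the `hook_mode` setter on `TorchMemorySaver`, which must run before the lazily constructed implementation exists:

    from torch_memory_saver import torch_memory_saver

    # Valid values are "preload" (the default) and "torch"; the setter
    # asserts that no region()/pause()/resume() call has initialized
    # the implementation yet.
    torch_memory_saver.hook_mode = "torch"
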
@@ -0,0 +1,26 @@
+import logging
+import os
+from contextlib import contextmanager
+from torch_memory_saver.hooks.base import HookUtilBase
+from torch_memory_saver.utils import get_binary_path_from_package, change_env
+
+logger = logging.getLogger(__name__)
+
+
+class HookUtilModePreload(HookUtilBase):
+    def get_path_binary(self):
+        env_ld_preload = os.environ.get("LD_PRELOAD", "")
+        assert "torch_memory_saver" in env_ld_preload, (
+            f"TorchMemorySaver observes invalid LD_PRELOAD. "
+            f"You can use the configure_subprocess() utility, "
+            f"or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py`. "
+            f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
+        )
+        return env_ld_preload
+
+
+@contextmanager
+def configure_subprocess():
+    """Configure environment variables for subprocesses. Only needed for hook_mode=preload."""
+    with change_env("LD_PRELOAD", str(get_binary_path_from_package("torch_memory_saver_hook_mode_preload"))):
+        yield
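
In preload mode the hook library has to be injected by the dynamic loader before Python starts, which is why `get_path_binary` can only validate LD_PRELOAD rather than set it. A parent process therefore wraps child creation in `configure_subprocess()`; a sketch (the `worker.py` name is illustrative):

    import subprocess
    from torch_memory_saver import configure_subprocess

    # Inside the context, LD_PRELOAD points at the bundled .so, and the
    # child inherits it; the previous value is restored on exit.
    with configure_subprocess():
        subprocess.run(["python", "worker.py"], check=True)
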
@@ -0,0 +1,19 @@
+import logging
+
+from torch_memory_saver.hooks.base import HookUtilBase
+from torch_memory_saver.utils import get_binary_path_from_package
+from torch.cuda.memory import CUDAPluggableAllocator
+
+logger = logging.getLogger(__name__)
+
+
+class HookUtilModeTorch(HookUtilBase):
+    def __init__(self):
+        self.allocator = CUDAPluggableAllocator(self.get_path_binary(), "tms_torch_malloc", "tms_torch_free")
+        logger.debug(f"HookUtilModeTorch {self.allocator=} {self.get_path_binary()=}")
+
+    def get_path_binary(self):
+        return str(get_binary_path_from_package("torch_memory_saver_hook_mode_torch"))
+
+    def get_allocator(self):
+        return self.allocator.allocator()
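
Torch mode skips LD_PRELOAD entirely: the two C entry points are wired in through PyTorch's `CUDAPluggableAllocator`, and `_TorchMemorySaverImpl` hands the result to `torch.cuda.MemPool(allocator=...)`. Roughly the same mechanism in isolation (the path is illustrative; the package resolves the real one from its own directory):

    import torch
    from torch.cuda.memory import CUDAPluggableAllocator

    alloc = CUDAPluggableAllocator(
        "/path/to/torch_memory_saver_hook_mode_torch.abi3.so",
        "tms_torch_malloc",
        "tms_torch_free",
    )
    pool = torch.cuda.MemPool(allocator=alloc.allocator())
    with torch.cuda.use_mem_pool(pool):
        x = torch.empty(1024, device="cuda")  # served by the hooked allocator
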
@@ -0,0 +1,10 @@
+"""Not to be used by end users, but only for tests of the package itself."""
+
+import torch
+
+
+def get_and_print_gpu_memory(message, gpu_id=0):
+    """Print GPU memory usage with optional message"""
+    mem = torch.cuda.device_memory_used(gpu_id)
+    print(f"GPU {gpu_id} memory: {mem / 1024 ** 3:.2f} GB ({message})")
+    return mem
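
A typical assertion this helper enables in the package's own tests (a sketch; exact numbers depend on the device and whatever else is resident):

    from torch_memory_saver import torch_memory_saver
    from torch_memory_saver.testing_utils import get_and_print_gpu_memory

    before = get_and_print_gpu_memory("before pause")
    torch_memory_saver.pause()
    after = get_and_print_gpu_memory("after pause")
    assert after < before  # paused regions no longer occupy device memory
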
@@ -0,0 +1,27 @@
+import logging
+import os
+from contextlib import contextmanager
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+
+def get_binary_path_from_package(stem: str):
+    dir_package = Path(__file__).parent
+    candidates = [p for d in [dir_package, dir_package.parent] for p in d.glob(f"{stem}.*.so")]
+    assert len(candidates) == 1, f"Expected exactly one {stem} library, found: {candidates}"
+    return candidates[0]
+
+
+# private utils, not to be used by end users
+@contextmanager
+def change_env(key: str, value: str):
+    old_value = os.environ.get(key, "")
+    os.environ[key] = value
+    logger.debug(f"change_env set key={key} value={value}")
+    try:
+        yield
+    finally:
+        assert os.environ[key] == value
+        os.environ[key] = old_value
+        logger.debug(f"change_env restore key={key} value={old_value}")
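
Note the restore semantics of `change_env`: a variable that was unset beforehand comes back as the empty string rather than being deleted, and the `assert` guards against concurrent modification inside the context. A quick illustration of a private helper, shown only to pin down that behavior (the variable name is made up):

    import os
    from torch_memory_saver.utils import change_env

    with change_env("TMS_DEMO_FLAG", "1"):
        assert os.environ["TMS_DEMO_FLAG"] == "1"

    # Restored to the previous value, "" if it was unset before.
    assert os.environ["TMS_DEMO_FLAG"] == ""
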
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: torch_memory_saver
-Version: 0.0.6
+Version: 0.0.9rc1
 Requires-Python: >=3.9
 License-File: LICENSE
 
@@ -0,0 +1,16 @@
+torch_memory_saver_hook_mode_preload.abi3.so,sha256=dQ6PvuqKNXTsbAPSWxP86YkQL82Vo-ngJ9WqB4Zbiss,777208
+torch_memory_saver_hook_mode_torch.abi3.so,sha256=-NE71rUpuYC3Sh8EqaMWbdMbm9C4h6MxeC30gNoqGys,781032
+torch_memory_saver/__init__.py,sha256=9iU_QlTe6OxMR5_OtSRUmvr6ltzk149GjojYvG74sag,154
+torch_memory_saver/binary_wrapper.py,sha256=MeQlPHIuFycamcWp3kOXjVZMiEK8HONuSx4l92J4k_Q,1133
+torch_memory_saver/entrypoint.py,sha256=aFkgqnWRI8vF8EeAL4FvIY33dNVtIbUMk1eM3_xH-fs,5538
+torch_memory_saver/testing_utils.py,sha256=vd9jhMgBLbeEy3vdvbuCjjtO-lRSX-RVB_Dg-wSHVQM,332
+torch_memory_saver/utils.py,sha256=LhtiocZTpMyDEjSexXaGglQtOJeJB7AaH5s43PZX5yo,856
+torch_memory_saver/hooks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torch_memory_saver/hooks/base.py,sha256=f8Rv_XxNupU80dKWUgE-Ea5pu1qaoXcnzltZrDy90hY,579
+torch_memory_saver/hooks/mode_preload.py,sha256=ELaVloI7T-rjssxn6lujaknNfujxFBxA2oc0SOsiUfk,1041
+torch_memory_saver/hooks/mode_torch.py,sha256=yxGyA8AYrKX7hr3Bawr_MH2AMwSgXxLN76GZaAbQLGU,681
+torch_memory_saver-0.0.9rc1.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
+torch_memory_saver-0.0.9rc1.dist-info/METADATA,sha256=PAMArqZ3_juC25gvir7W_--VffaTK9tO1-422cWOfnA,111
+torch_memory_saver-0.0.9rc1.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
+torch_memory_saver-0.0.9rc1.dist-info/top_level.txt,sha256=Fdob5gbD3sjPAe3kNfDokaN1sL43cMvwKRLKuR8oitw,91
+torch_memory_saver-0.0.9rc1.dist-info/RECORD,,
@@ -0,0 +1,3 @@
+torch_memory_saver
+torch_memory_saver_hook_mode_preload
+torch_memory_saver_hook_mode_torch
@@ -1,7 +0,0 @@
-torch_memory_saver_cpp.abi3.so,sha256=OCweTnvdmyg5zhUIMJjfH9NW0lYjtcNwqzS9-89cCvQ,315896
-torch_memory_saver/__init__.py,sha256=B3AXwxxJeUbNFKdrfaGzXvl3vTcgPOf2UjaFVtGCZ68,3072
-torch_memory_saver-0.0.6.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
-torch_memory_saver-0.0.6.dist-info/METADATA,sha256=P21LYFkCJHFwaMAkxZBoiQRkhvIQmOBAKbHHoQdQiEI,108
-torch_memory_saver-0.0.6.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
-torch_memory_saver-0.0.6.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
-torch_memory_saver-0.0.6.dist-info/RECORD,,
@@ -1,2 +0,0 @@
-torch_memory_saver
-torch_memory_saver_cpp
Binary file