torch-memory-saver 0.0.6__cp39-abi3-manylinux2014_x86_64.whl → 0.0.8__cp39-abi3-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_memory_saver/__init__.py +37 -17
- {torch_memory_saver-0.0.6.dist-info → torch_memory_saver-0.0.8.dist-info}/METADATA +1 -1
- torch_memory_saver-0.0.8.dist-info/RECORD +7 -0
- torch_memory_saver_cpp.abi3.so +0 -0
- torch_memory_saver-0.0.6.dist-info/RECORD +0 -7
- {torch_memory_saver-0.0.6.dist-info → torch_memory_saver-0.0.8.dist-info}/LICENSE +0 -0
- {torch_memory_saver-0.0.6.dist-info → torch_memory_saver-0.0.8.dist-info}/WHEEL +0 -0
- {torch_memory_saver-0.0.6.dist-info → torch_memory_saver-0.0.8.dist-info}/top_level.txt +0 -0
torch_memory_saver/__init__.py
CHANGED
@@ -14,29 +14,34 @@ logger = logging.getLogger(__name__)
|
|
14
14
|
class TorchMemorySaver:
|
15
15
|
def __init__(self):
|
16
16
|
self._mem_pool = None
|
17
|
-
self._id = _global_info.next_id()
|
18
|
-
assert self._id == 1, 'Only support one single instance yet (multi-instance will be implemented later)'
|
19
17
|
|
20
18
|
@contextmanager
|
21
|
-
def region(self):
|
19
|
+
def region(self, tag: str = "default"):
|
20
|
+
"""Context manager for memory saving with optional tag"""
|
22
21
|
if _global_info.binary_info.enabled:
|
23
22
|
self._ensure_mem_pool()
|
24
23
|
with torch.cuda.use_mem_pool(self._mem_pool):
|
24
|
+
_global_info.binary_info.cdll.tms_set_current_tag(tag.encode('utf-8'))
|
25
25
|
_global_info.binary_info.cdll.tms_region_enter()
|
26
26
|
try:
|
27
27
|
yield
|
28
28
|
finally:
|
29
|
+
_global_info.binary_info.cdll.tms_set_current_tag(b"default")
|
29
30
|
_global_info.binary_info.cdll.tms_region_leave()
|
30
31
|
else:
|
31
32
|
yield
|
32
33
|
|
33
|
-
def pause(self):
|
34
|
+
def pause(self, tag: Optional[str] = None):
|
35
|
+
"""Pause memory for specific tag or all memory if tag is None"""
|
34
36
|
if _global_info.binary_info.enabled:
|
35
|
-
|
37
|
+
tag_bytes = tag.encode('utf-8') if tag else None
|
38
|
+
_global_info.binary_info.cdll.tms_pause(tag_bytes)
|
36
39
|
|
37
|
-
def resume(self):
|
40
|
+
def resume(self, tag: Optional[str] = None):
|
41
|
+
"""Resume memory for specific tag or all memory if tag is None"""
|
38
42
|
if _global_info.binary_info.enabled:
|
39
|
-
|
43
|
+
tag_bytes = tag.encode('utf-8') if tag else None
|
44
|
+
_global_info.binary_info.cdll.tms_resume(tag_bytes)
|
40
45
|
|
41
46
|
@property
|
42
47
|
def enabled(self):
|
@@ -46,7 +51,6 @@ class TorchMemorySaver:
|
|
46
51
|
if self._mem_pool is None:
|
47
52
|
self._mem_pool = torch.cuda.MemPool()
|
48
53
|
|
49
|
-
|
50
54
|
@dataclass
|
51
55
|
class _BinaryInfo:
|
52
56
|
cdll: Optional[ctypes.CDLL]
|
@@ -55,21 +59,39 @@ class _BinaryInfo:
|
|
55
59
|
def enabled(self):
|
56
60
|
return self.cdll is not None
|
57
61
|
|
62
|
+
@staticmethod
|
63
|
+
def _setup_function_signatures(cdll):
|
64
|
+
"""Define function signatures for the C library"""
|
65
|
+
cdll.tms_region_enter.argtypes = []
|
66
|
+
cdll.tms_region_leave.argtypes = []
|
67
|
+
cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
|
68
|
+
cdll.tms_pause.argtypes = [ctypes.c_char_p]
|
69
|
+
cdll.tms_resume.argtypes = [ctypes.c_char_p]
|
70
|
+
|
58
71
|
@staticmethod
|
59
72
|
def compute():
|
60
73
|
env_ld_preload = os.environ.get('LD_PRELOAD', '')
|
61
74
|
if 'torch_memory_saver' in env_ld_preload:
|
62
|
-
|
75
|
+
try:
|
76
|
+
cdll = ctypes.CDLL(env_ld_preload)
|
77
|
+
_BinaryInfo._setup_function_signatures(cdll)
|
78
|
+
return _BinaryInfo(cdll=cdll)
|
79
|
+
except OSError as e:
|
80
|
+
logger.error(f'Failed to load CDLL from {env_ld_preload}: {e}')
|
81
|
+
return _BinaryInfo(cdll=None)
|
63
82
|
else:
|
64
|
-
|
65
|
-
f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD
|
83
|
+
print(
|
84
|
+
f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD. '
|
85
|
+
f'You can use configure_subprocess() utility, '
|
86
|
+
f'or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py. '
|
87
|
+
f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
|
88
|
+
)
|
66
89
|
return _BinaryInfo(cdll=None)
|
67
90
|
|
68
91
|
|
69
92
|
class _GlobalInfo:
|
70
93
|
def __init__(self):
|
71
94
|
self._binary_info: Optional[_BinaryInfo] = None
|
72
|
-
self._last_id = 0
|
73
95
|
|
74
96
|
@property
|
75
97
|
def binary_info(self):
|
@@ -77,13 +99,11 @@ class _GlobalInfo:
|
|
77
99
|
self._binary_info = _BinaryInfo.compute()
|
78
100
|
return self._binary_info
|
79
101
|
|
80
|
-
def next_id(self):
|
81
|
-
self._last_id += 1
|
82
|
-
return self._last_id
|
83
|
-
|
84
102
|
|
85
103
|
_global_info = _GlobalInfo()
|
86
104
|
|
105
|
+
# Global singleton instance
|
106
|
+
torch_memory_saver = TorchMemorySaver()
|
87
107
|
|
88
108
|
def get_binary_path():
|
89
109
|
dir_package = Path(__file__).parent
|
@@ -92,7 +112,7 @@ def get_binary_path():
|
|
92
112
|
for d in [dir_package, dir_package.parent]
|
93
113
|
for p in d.glob('torch_memory_saver_cpp.*.so')
|
94
114
|
]
|
95
|
-
assert len(candidates) == 1, f'{candidates
|
115
|
+
assert len(candidates) == 1, f'Expected exactly one torch_memory_saver_cpp library, found: {candidates}'
|
96
116
|
return candidates[0]
|
97
117
|
|
98
118
|
|
@@ -0,0 +1,7 @@
|
|
1
|
+
torch_memory_saver_cpp.abi3.so,sha256=kI54XLX1E_R-PgmfwrZoM2-emKOWQXRoabNEMZS0dCE,391672
|
2
|
+
torch_memory_saver/__init__.py,sha256=MfMSSGSNhP7xBo4jq6RPug6svnMVPtmua8-l0yWlkTg,4403
|
3
|
+
torch_memory_saver-0.0.8.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
|
4
|
+
torch_memory_saver-0.0.8.dist-info/METADATA,sha256=tUjtugtoDTFWeIct5kTvWRbGUUTlJlG85h_zSClzUdA,108
|
5
|
+
torch_memory_saver-0.0.8.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
|
6
|
+
torch_memory_saver-0.0.8.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
|
7
|
+
torch_memory_saver-0.0.8.dist-info/RECORD,,
|
torch_memory_saver_cpp.abi3.so
CHANGED
Binary file
|
@@ -1,7 +0,0 @@
|
|
1
|
-
torch_memory_saver_cpp.abi3.so,sha256=OCweTnvdmyg5zhUIMJjfH9NW0lYjtcNwqzS9-89cCvQ,315896
|
2
|
-
torch_memory_saver/__init__.py,sha256=B3AXwxxJeUbNFKdrfaGzXvl3vTcgPOf2UjaFVtGCZ68,3072
|
3
|
-
torch_memory_saver-0.0.6.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
|
4
|
-
torch_memory_saver-0.0.6.dist-info/METADATA,sha256=P21LYFkCJHFwaMAkxZBoiQRkhvIQmOBAKbHHoQdQiEI,108
|
5
|
-
torch_memory_saver-0.0.6.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
|
6
|
-
torch_memory_saver-0.0.6.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
|
7
|
-
torch_memory_saver-0.0.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|