torch-memory-saver 0.0.6__cp39-abi3-manylinux2014_x86_64.whl → 0.0.8__cp39-abi3-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,29 +14,34 @@ logger = logging.getLogger(__name__)
14
14
  class TorchMemorySaver:
15
15
  def __init__(self):
16
16
  self._mem_pool = None
17
- self._id = _global_info.next_id()
18
- assert self._id == 1, 'Only support one single instance yet (multi-instance will be implemented later)'
19
17
 
20
18
  @contextmanager
21
- def region(self):
19
+ def region(self, tag: str = "default"):
20
+ """Context manager for memory saving with optional tag"""
22
21
  if _global_info.binary_info.enabled:
23
22
  self._ensure_mem_pool()
24
23
  with torch.cuda.use_mem_pool(self._mem_pool):
24
+ _global_info.binary_info.cdll.tms_set_current_tag(tag.encode('utf-8'))
25
25
  _global_info.binary_info.cdll.tms_region_enter()
26
26
  try:
27
27
  yield
28
28
  finally:
29
+ _global_info.binary_info.cdll.tms_set_current_tag(b"default")
29
30
  _global_info.binary_info.cdll.tms_region_leave()
30
31
  else:
31
32
  yield
32
33
 
33
- def pause(self):
34
+ def pause(self, tag: Optional[str] = None):
35
+ """Pause memory for specific tag or all memory if tag is None"""
34
36
  if _global_info.binary_info.enabled:
35
- _global_info.binary_info.cdll.tms_pause()
37
+ tag_bytes = tag.encode('utf-8') if tag else None
38
+ _global_info.binary_info.cdll.tms_pause(tag_bytes)
36
39
 
37
- def resume(self):
40
+ def resume(self, tag: Optional[str] = None):
41
+ """Resume memory for specific tag or all memory if tag is None"""
38
42
  if _global_info.binary_info.enabled:
39
- _global_info.binary_info.cdll.tms_resume()
43
+ tag_bytes = tag.encode('utf-8') if tag else None
44
+ _global_info.binary_info.cdll.tms_resume(tag_bytes)
40
45
 
41
46
  @property
42
47
  def enabled(self):
@@ -46,7 +51,6 @@ class TorchMemorySaver:
46
51
  if self._mem_pool is None:
47
52
  self._mem_pool = torch.cuda.MemPool()
48
53
 
49
-
50
54
  @dataclass
51
55
  class _BinaryInfo:
52
56
  cdll: Optional[ctypes.CDLL]
@@ -55,21 +59,39 @@ class _BinaryInfo:
55
59
  def enabled(self):
56
60
  return self.cdll is not None
57
61
 
62
+ @staticmethod
63
+ def _setup_function_signatures(cdll):
64
+ """Define function signatures for the C library"""
65
+ cdll.tms_region_enter.argtypes = []
66
+ cdll.tms_region_leave.argtypes = []
67
+ cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
68
+ cdll.tms_pause.argtypes = [ctypes.c_char_p]
69
+ cdll.tms_resume.argtypes = [ctypes.c_char_p]
70
+
58
71
  @staticmethod
59
72
  def compute():
60
73
  env_ld_preload = os.environ.get('LD_PRELOAD', '')
61
74
  if 'torch_memory_saver' in env_ld_preload:
62
- return _BinaryInfo(cdll=ctypes.CDLL(env_ld_preload))
75
+ try:
76
+ cdll = ctypes.CDLL(env_ld_preload)
77
+ _BinaryInfo._setup_function_signatures(cdll)
78
+ return _BinaryInfo(cdll=cdll)
79
+ except OSError as e:
80
+ logger.error(f'Failed to load CDLL from {env_ld_preload}: {e}')
81
+ return _BinaryInfo(cdll=None)
63
82
  else:
64
- logger.warning(
65
- f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD="{env_ld_preload}" (process_id={os.getpid()})')
83
+ print(
84
+ f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD. '
85
+ f'You can use configure_subprocess() utility, '
86
+ f'or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py. '
87
+ f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
88
+ )
66
89
  return _BinaryInfo(cdll=None)
67
90
 
68
91
 
69
92
  class _GlobalInfo:
70
93
  def __init__(self):
71
94
  self._binary_info: Optional[_BinaryInfo] = None
72
- self._last_id = 0
73
95
 
74
96
  @property
75
97
  def binary_info(self):
@@ -77,13 +99,11 @@ class _GlobalInfo:
77
99
  self._binary_info = _BinaryInfo.compute()
78
100
  return self._binary_info
79
101
 
80
- def next_id(self):
81
- self._last_id += 1
82
- return self._last_id
83
-
84
102
 
85
103
  _global_info = _GlobalInfo()
86
104
 
105
+ # Global singleton instance
106
+ torch_memory_saver = TorchMemorySaver()
87
107
 
88
108
  def get_binary_path():
89
109
  dir_package = Path(__file__).parent
@@ -92,7 +112,7 @@ def get_binary_path():
92
112
  for d in [dir_package, dir_package.parent]
93
113
  for p in d.glob('torch_memory_saver_cpp.*.so')
94
114
  ]
95
- assert len(candidates) == 1, f'{candidates=}'
115
+ assert len(candidates) == 1, f'Expected exactly one torch_memory_saver_cpp library, found: {candidates}'
96
116
  return candidates[0]
97
117
 
98
118
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: torch_memory_saver
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Requires-Python: >=3.9
5
5
  License-File: LICENSE
6
6
 
@@ -0,0 +1,7 @@
1
+ torch_memory_saver_cpp.abi3.so,sha256=kI54XLX1E_R-PgmfwrZoM2-emKOWQXRoabNEMZS0dCE,391672
2
+ torch_memory_saver/__init__.py,sha256=MfMSSGSNhP7xBo4jq6RPug6svnMVPtmua8-l0yWlkTg,4403
3
+ torch_memory_saver-0.0.8.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
4
+ torch_memory_saver-0.0.8.dist-info/METADATA,sha256=tUjtugtoDTFWeIct5kTvWRbGUUTlJlG85h_zSClzUdA,108
5
+ torch_memory_saver-0.0.8.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
6
+ torch_memory_saver-0.0.8.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
7
+ torch_memory_saver-0.0.8.dist-info/RECORD,,
Binary file
@@ -1,7 +0,0 @@
1
- torch_memory_saver_cpp.abi3.so,sha256=OCweTnvdmyg5zhUIMJjfH9NW0lYjtcNwqzS9-89cCvQ,315896
2
- torch_memory_saver/__init__.py,sha256=B3AXwxxJeUbNFKdrfaGzXvl3vTcgPOf2UjaFVtGCZ68,3072
3
- torch_memory_saver-0.0.6.dist-info/LICENSE,sha256=i806R5xShJFB4k9yNQJ2GYCcSBlu1frTx2vH_nWdWE8,1064
4
- torch_memory_saver-0.0.6.dist-info/METADATA,sha256=P21LYFkCJHFwaMAkxZBoiQRkhvIQmOBAKbHHoQdQiEI,108
5
- torch_memory_saver-0.0.6.dist-info/WHEEL,sha256=HUPiMa7ZA9BvJ9gdJRYwZIjK2rWbCcrqYvJ4Onw0owE,102
6
- torch_memory_saver-0.0.6.dist-info/top_level.txt,sha256=uJ27-bVSKHxdcfHRcakvEr_KQxnUlMia6v19fHbfHxA,42
7
- torch_memory_saver-0.0.6.dist-info/RECORD,,