torch-memory-saver 0.0.6__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torch_memory_saver
- Version: 0.0.6
+ Version: 0.0.8
  Summary: UNKNOWN
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -0,0 +1,74 @@
+ # Torch Memory Saver
+
+ A PyTorch library that allows tensor memory to be temporarily released and resumed later.
+
+ During the pause:
+ - Physical memory is released
+ - Virtual address is preserved
+
+ When resumed:
+ - Physical memory is re-allocated and mapped back to the preserved virtual address
+
+ Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
+
+ ## Examples
+
+ ### Basic Example
+
+ ```python
+ import torch
+ import torch_memory_saver
+
+ memory_saver = torch_memory_saver.memory_saver
+
+ # 1. For tensors that you want to be pausable, create them within `region`
+ with memory_saver.region():
+     pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ # 2. After `pause`, CUDA memory is released for those tensors.
+ #    For example, check `nvidia-smi`'s memory usage to verify.
+ memory_saver.pause()
+
+ # 3. After `resume`, CUDA memory is re-occupied for those tensors.
+ memory_saver.resume()
+ ```
+
+ ### Multiple Tags Example
+
+ Please refer to https://github.com/sgl-project/sglang/issues/7009 for details.
+
+ ```python
+ import torch
+ from torch_memory_saver import torch_memory_saver
+
+ # 1. Create tensors with different tags
+ with torch_memory_saver.region(tag="type1"):
+     tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ with torch_memory_saver.region(tag="type2"):
+     tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ # 2. Pause and resume the tags selectively
+ torch_memory_saver.pause("type1")
+ torch_memory_saver.pause("type2")
+
+ torch_memory_saver.resume("type2")
+ torch_memory_saver.resume("type1")
+
+ torch_memory_saver.pause("type1")
+ torch_memory_saver.resume("type1")
+ ```
+
+ ## Development
+
+ ```bash
+ pip install -e .
+ ```
+
+ A `torch_memory_saver_cpp.abi3.so` will be built under the `{your_workspace}/torch_memory_saver/` folder.
+
+ You can use these commands for local testing:
+ ```bash
+ LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/simple.py
+
+ LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/rl_with_cuda_graph.py
+ ```
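
As a side note to the README above (not part of the packaged file): a minimal sketch of how the release can be verified programmatically, assuming the process was launched with `LD_PRELOAD` pointing at the built library as in the Development section. `torch.cuda.mem_get_info()` asks the driver for free/total device memory, so the pause should be visible there even though PyTorch's own allocator statistics still count the tensor.

```python
import torch
from torch_memory_saver import torch_memory_saver  # singleton introduced in 0.0.8


def free_gib() -> float:
    free, _total = torch.cuda.mem_get_info()
    return free / 2**30


with torch_memory_saver.region():
    t = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')

print(f"before pause: {free_gib():.2f} GiB free")
torch_memory_saver.pause()   # physical memory released, virtual address kept
print(f"after pause:  {free_gib():.2f} GiB free")   # expect roughly +1 GiB
torch_memory_saver.resume()  # physical memory mapped back at the same address
print(f"after resume: {free_gib():.2f} GiB free")
```

Note that the resume path in the C++ changes below creates a new allocation handle and maps it at the preserved address, so the tensor is usable again but its previous contents should not be relied upon.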
@@ -6,6 +6,7 @@
  #include <dlfcn.h>
  #include <unordered_map>
  #include <mutex>
+ #include <string>

  // #define TMS_DEBUG_LOG

@@ -118,32 +119,32 @@ struct _AllocationMetadata {
      size_t size;
      CUdevice device;
      CUmemGenericAllocationHandle allocHandle;
+     std::string tag;
  };

  class TorchMemorySaver {
  public:
      TorchMemorySaver() {}

-     cudaError_t malloc(void **ptr, size_t size) {
+     cudaError_t malloc(void **ptr, size_t size, const std::string& tag) {
          CUdevice device;
          CURESULT_CHECK(cuCtxGetDevice(&device));

          CUmemGenericAllocationHandle allocHandle;
          CUDAUtils::cu_mem_create(&allocHandle, size, device);
-
          CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *) ptr, size, 0, 0, 0));
          CURESULT_CHECK(cuMemMap((CUdeviceptr) * ptr, size, 0, allocHandle, 0));
          CUDAUtils::cu_mem_set_access(*ptr, size, device);

          {
-             const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);
-             allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle});
+             const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+             allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle, tag});
          }

  #ifdef TMS_DEBUG_LOG
          std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
                    << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
-                   << " allocHandle=" << allocHandle
+                   << " allocHandle=" << allocHandle << " tag=" << tag
                    << std::endl;
  #endif

@@ -166,39 +167,47 @@ public:
  #ifdef TMS_DEBUG_LOG
          std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
                    << " ptr=" << ptr << " metadata.size=" << metadata.size
-                   << " metadata.allocHandle=" << metadata.allocHandle
+                   << " metadata.allocHandle=" << metadata.allocHandle << " tag=" << metadata.tag
                    << std::endl;
  #endif

          return cudaSuccess;
      }

-     void pause() {
+     void pause(const std::string& tag) {
          const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);

          for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
              void *ptr = it->first;
              _AllocationMetadata metadata = it->second;

+             if (!tag.empty() && metadata.tag != tag) {
+                 continue;
+             }
+
              CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
              CURESULT_CHECK(cuMemRelease(metadata.allocHandle));

  #ifdef TMS_DEBUG_LOG
              std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
                        << " ptr=" << ptr << " metadata.size=" << metadata.size << " metadata.allocHandle="
-                       << metadata.allocHandle
+                       << metadata.allocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                        << std::endl;
  #endif
          }
      }

-     void resume() {
+     void resume(const std::string& tag) {
          const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);

          for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
              void *ptr = it->first;
              _AllocationMetadata &metadata = it->second;

+             if (!tag.empty() && metadata.tag != tag) {
+                 continue;
+             }
+
              CUmemGenericAllocationHandle newAllocHandle;
              CUDAUtils::cu_mem_create(&newAllocHandle, metadata.size, metadata.device);

@@ -210,7 +219,7 @@ public:
              std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
                        << " ptr=" << ptr << " metadata.size=" << metadata.size << " (old)metadata.allocHandle="
                        << metadata.allocHandle
-                       << " (new)newAllocHandle=" << newAllocHandle
+                       << " (new)newAllocHandle=" << newAllocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                        << std::endl;
  #endif

@@ -223,14 +232,18 @@ public:
          return instance;
      }

+
  private:
-     // Similar to torch's CUDACachingAllocator and CUDAPluggableAllocator
      std::mutex allocator_metadata_mutex_;
      std::unordered_map<void *, _AllocationMetadata> allocation_metadata_;
  };

+
+ // ----------------------------------------------- region manager --------------------------------------------------
+
  namespace RegionManager {
      static thread_local bool is_interesting_region_ = false;
+     static thread_local std::string current_tag_ = "default";

      void enter() {
  #ifdef TMS_DEBUG_LOG
@@ -249,13 +262,21 @@ namespace RegionManager {
      bool is_interesting_region() {
          return is_interesting_region_;
      }
+
+     void set_current_tag(const std::string& tag) {
+         current_tag_ = tag;
+     }
+
+     const std::string& get_current_tag() {
+         return current_tag_;
+     }
  }

  // ------------------------------------------------- entrypoints ------------------------------------------------

  cudaError_t cudaMalloc(void **ptr, size_t size) {
      if (RegionManager::is_interesting_region()) {
-         return TorchMemorySaver::instance().malloc(ptr, size);
+         return TorchMemorySaver::instance().malloc(ptr, size, RegionManager::get_current_tag());
      } else {
          return APIForwarder::call_real_cuda_malloc(ptr, size);
      }
@@ -278,11 +299,21 @@ void tms_region_leave() {
      RegionManager::leave();
  }

- void tms_pause() {
-     TorchMemorySaver::instance().pause();
+ void tms_set_current_tag(const char* tag) {
+     if (tag == nullptr) {
+         std::cerr << "[torch_memory_saver.cpp] FATAL: NULL tag passed to tms_set_current_tag" << std::endl;
+         exit(1);
+     }
+     RegionManager::set_current_tag(std::string(tag));
  }

- void tms_resume() {
-     TorchMemorySaver::instance().resume();
+ void tms_pause(const char* tag) {
+     std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+     TorchMemorySaver::instance().pause(tag_str);
  }
+
+ void tms_resume(const char* tag) {
+     std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+     TorchMemorySaver::instance().resume(tag_str);
  }
+ }
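
In the entrypoints above, a NULL tag is forwarded as an empty `std::string`, and the pause/resume loops skip the tag filter when the string is empty, so an untagged call affects every allocation made inside a `region`. A minimal ctypes sketch of that behavior (assuming the preloaded `.so` is the single entry on `LD_PRELOAD`; in normal use the `TorchMemorySaver` Python wrapper shown further down does this for you):

```python
import ctypes
import os

# Hypothetical direct use of the preloaded binary, mirroring what the package's
# _BinaryInfo helper does when it loads the library from LD_PRELOAD.
lib = ctypes.CDLL(os.environ["LD_PRELOAD"])
lib.tms_set_current_tag.argtypes = [ctypes.c_char_p]
lib.tms_pause.argtypes = [ctypes.c_char_p]
lib.tms_resume.argtypes = [ctypes.c_char_p]

lib.tms_pause(b"type1")    # only allocations tagged "type1" are unmapped
lib.tms_resume(b"type1")   # remap just that tag
lib.tms_pause(None)        # NULL -> empty tag on the C++ side -> no filter: pause everything
lib.tms_resume(None)       # resume everything
```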
@@ -1,8 +1,9 @@
+
  import logging
  import os
  import shutil
  from pathlib import Path
-
+ import platform
  import setuptools
  from setuptools import setup

@@ -24,25 +25,28 @@ def _find_cuda_home():
          cuda_home = '/usr/local/cuda'
      return cuda_home

-
  cuda_home = Path(_find_cuda_home())
+
  include_dirs = [
-     str(cuda_home.resolve() / 'targets/x86_64-linux/include'),
+     str((cuda_home / 'include').resolve()),
  ]
+
  library_dirs = [
-     str(cuda_home.resolve() / 'lib64'),
-     str(cuda_home.resolve() / 'lib64/stubs'),
+     str((cuda_home / 'lib64').resolve()),
+     str((cuda_home / 'lib64/stubs').resolve()),
  ]

  setup(
      name='torch_memory_saver',
-     version='0.0.6',
+     version='0.0.8',
      ext_modules=[setuptools.Extension(
          'torch_memory_saver_cpp',
          ['csrc/torch_memory_saver.cpp'],
          include_dirs=include_dirs,
          library_dirs=library_dirs,
-         libraries=['cuda']
+         libraries=['cuda'],
+         define_macros=[('Py_LIMITED_API', '0x03090000')],
+         py_limited_api=True,
      )],
      python_requires=">=3.9",
      packages=['torch_memory_saver'],
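
The new `Py_LIMITED_API`/`py_limited_api` settings build the extension against CPython's stable ABI, which matches the `torch_memory_saver_cpp.abi3.so` name mentioned in the README and lets a single wheel per platform cover CPython 3.9 and newer. A small sketch (assuming an editable install in the current workspace, as in the README's Development section) to confirm the stable-ABI artifact is where the package's `get_binary_path()` will look for it:

```python
from pathlib import Path

# After `pip install -e .`, the extension should sit next to the package sources.
candidates = sorted(Path("torch_memory_saver").glob("torch_memory_saver_cpp.*.so"))
print(candidates)  # expected to contain torch_memory_saver_cpp.abi3.so
```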
@@ -14,29 +14,34 @@ logger = logging.getLogger(__name__)
  class TorchMemorySaver:
      def __init__(self):
          self._mem_pool = None
-         self._id = _global_info.next_id()
-         assert self._id == 1, 'Only support one single instance yet (multi-instance will be implemented later)'

      @contextmanager
-     def region(self):
+     def region(self, tag: str = "default"):
+         """Context manager for memory saving with optional tag"""
          if _global_info.binary_info.enabled:
              self._ensure_mem_pool()
              with torch.cuda.use_mem_pool(self._mem_pool):
+                 _global_info.binary_info.cdll.tms_set_current_tag(tag.encode('utf-8'))
                  _global_info.binary_info.cdll.tms_region_enter()
                  try:
                      yield
                  finally:
+                     _global_info.binary_info.cdll.tms_set_current_tag(b"default")
                      _global_info.binary_info.cdll.tms_region_leave()
          else:
              yield

-     def pause(self):
+     def pause(self, tag: Optional[str] = None):
+         """Pause memory for specific tag or all memory if tag is None"""
          if _global_info.binary_info.enabled:
-             _global_info.binary_info.cdll.tms_pause()
+             tag_bytes = tag.encode('utf-8') if tag else None
+             _global_info.binary_info.cdll.tms_pause(tag_bytes)

-     def resume(self):
+     def resume(self, tag: Optional[str] = None):
+         """Resume memory for specific tag or all memory if tag is None"""
          if _global_info.binary_info.enabled:
-             _global_info.binary_info.cdll.tms_resume()
+             tag_bytes = tag.encode('utf-8') if tag else None
+             _global_info.binary_info.cdll.tms_resume(tag_bytes)

      @property
      def enabled(self):
@@ -46,7 +51,6 @@ class TorchMemorySaver:
          if self._mem_pool is None:
              self._mem_pool = torch.cuda.MemPool()

-
  @dataclass
  class _BinaryInfo:
      cdll: Optional[ctypes.CDLL]
@@ -55,21 +59,39 @@ class _BinaryInfo:
      def enabled(self):
          return self.cdll is not None

+     @staticmethod
+     def _setup_function_signatures(cdll):
+         """Define function signatures for the C library"""
+         cdll.tms_region_enter.argtypes = []
+         cdll.tms_region_leave.argtypes = []
+         cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
+         cdll.tms_pause.argtypes = [ctypes.c_char_p]
+         cdll.tms_resume.argtypes = [ctypes.c_char_p]
+
      @staticmethod
      def compute():
          env_ld_preload = os.environ.get('LD_PRELOAD', '')
          if 'torch_memory_saver' in env_ld_preload:
-             return _BinaryInfo(cdll=ctypes.CDLL(env_ld_preload))
+             try:
+                 cdll = ctypes.CDLL(env_ld_preload)
+                 _BinaryInfo._setup_function_signatures(cdll)
+                 return _BinaryInfo(cdll=cdll)
+             except OSError as e:
+                 logger.error(f'Failed to load CDLL from {env_ld_preload}: {e}')
+                 return _BinaryInfo(cdll=None)
          else:
-             logger.warning(
-                 f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD="{env_ld_preload}" (process_id={os.getpid()})')
+             print(
+                 f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD. '
+                 f'You can use configure_subprocess() utility, '
+                 f'or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py. '
+                 f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
+             )
              return _BinaryInfo(cdll=None)


  class _GlobalInfo:
      def __init__(self):
          self._binary_info: Optional[_BinaryInfo] = None
-         self._last_id = 0

      @property
      def binary_info(self):
@@ -77,13 +99,11 @@ class _GlobalInfo:
              self._binary_info = _BinaryInfo.compute()
          return self._binary_info

-     def next_id(self):
-         self._last_id += 1
-         return self._last_id
-

  _global_info = _GlobalInfo()

+ # Global singleton instance
+ torch_memory_saver = TorchMemorySaver()

  def get_binary_path():
      dir_package = Path(__file__).parent
@@ -92,7 +112,7 @@ def get_binary_path():
          for d in [dir_package, dir_package.parent]
          for p in d.glob('torch_memory_saver_cpp.*.so')
      ]
-     assert len(candidates) == 1, f'{candidates=}'
+     assert len(candidates) == 1, f'Expected exactly one torch_memory_saver_cpp library, found: {candidates}'
      return candidates[0]


@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torch-memory-saver
- Version: 0.0.6
+ Version: 0.0.8
  Summary: UNKNOWN
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,29 +0,0 @@
- # torch_memory_saver
-
- Allow torch tensor memory to be released and resumed later.
-
- API:
-
- ```python
- memory_saver = TorchMemorySaver()
-
- # 1. For tensors that wants to be paused, create them within `region`
- with memory_saver.region():
-     x = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
-
- # 2. After `pause`, CUDA memory is released for those tensors.
- # For example, check `nvidia-smi`'s memory usage to verify.
- memory_saver.pause()
-
- # 3. After `resume`, CUDA memory is re-occupied for those tensors.
- memory_saver.resume()
- ```
-
- Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
-
- TODO:
-
- - [x] Implementation
- - [x] Publish to pypi
- - [ ] More tests and infra
- - [ ] Documentation