torch-memory-saver 0.0.6__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/PKG-INFO +1 -1
- torch_memory_saver-0.0.8/README.md +74 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/csrc/torch_memory_saver.cpp +47 -16
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/setup.py +11 -7
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver/__init__.py +37 -17
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/PKG-INFO +1 -1
- torch_memory_saver-0.0.6/README.md +0 -29
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/LICENSE +0 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/setup.cfg +0 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/SOURCES.txt +0 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/dependency_links.txt +0 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/top_level.txt +0 -0
torch_memory_saver-0.0.8/README.md ADDED
@@ -0,0 +1,74 @@
+# Torch Memory Saver
+
+A PyTorch library that allows tensor memory to be temporarily released and resumed later.
+
+During the pause:
+- Physical memory is released
+- Virtual address is preserved
+
+When resumed:
+- The virtual address is restored to the original one
+
+Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
+
+## Examples
+
+### Basic Example
+
+```python
+import torch_memory_saver
+
+memory_saver = torch_memory_saver.memory_saver
+
+# 1. For tensors that want to be paused, create them within `region`
+with memory_saver.region():
+    pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+# 2. After `pause`, CUDA memory is released for those tensors.
+# For example, check `nvidia-smi`'s memory usage to verify.
+memory_saver.pause()
+
+# 3. After `resume`, CUDA memory is re-occupied for those tensors.
+memory_saver.resume()
+```
+
+### Multiple Tags Example
+
+Please refer to https://github.com/sgl-project/sglang/issues/7009 for details.
+
+```python
+from torch_memory_saver import torch_memory_saver
+
+# 1. Create tensors with different tags
+with torch_memory_saver.region(tag="type1"):
+    tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+with torch_memory_saver.region(tag="type2"):
+    tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+# 2. Selectively pause and resume by tag
+torch_memory_saver.pause("type1")
+torch_memory_saver.pause("type2")
+
+torch_memory_saver.resume("type2")
+torch_memory_saver.resume("type1")
+
+torch_memory_saver.pause("type1")
+torch_memory_saver.resume("type1")
+```
+
+## Development
+
+```bash
+pip install -e .
+```
+
+A `torch_memory_saver_cpp.abi3.so` will be built under the `{your_workspace}/torch_memory_saver/` folder.
+
+You can use these commands for local testing:
+```bash
+LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/simple.py
+
+LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/rl_with_cuda_graph.py
+```
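A minimal end-to-end sketch of the lifecycle the new README describes (the `kv_cache` tag name is illustrative, the process must run under `LD_PRELOAD` as in the Development section, and since `pause` releases the physical pages while `resume` maps fresh pages at the same virtual address, tensor contents should be treated as lost across a pause):

```python
import torch
from torch_memory_saver import torch_memory_saver

with torch_memory_saver.region(tag="kv_cache"):  # "kv_cache" is an illustrative tag
    cache = torch.zeros(1_000_000_000, dtype=torch.uint8, device='cuda')

torch_memory_saver.pause("kv_cache")    # physical memory is returned to the driver
# ... the freed GPU memory can be used by something else here ...
torch_memory_saver.resume("kv_cache")   # same virtual address, fresh physical pages
cache.zero_()                           # contents were not preserved; re-initialize
```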
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/csrc/torch_memory_saver.cpp CHANGED
@@ -6,6 +6,7 @@
 #include <dlfcn.h>
 #include <unordered_map>
 #include <mutex>
+#include <string>
 
 // #define TMS_DEBUG_LOG
 
@@ -118,32 +119,32 @@ struct _AllocationMetadata {
     size_t size;
     CUdevice device;
     CUmemGenericAllocationHandle allocHandle;
+    std::string tag;
 };
 
 class TorchMemorySaver {
 public:
     TorchMemorySaver() {}
 
-    cudaError_t malloc(void **ptr, size_t size) {
+    cudaError_t malloc(void **ptr, size_t size, const std::string& tag) {
         CUdevice device;
         CURESULT_CHECK(cuCtxGetDevice(&device));
 
         CUmemGenericAllocationHandle allocHandle;
         CUDAUtils::cu_mem_create(&allocHandle, size, device);
-
         CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *) ptr, size, 0, 0, 0));
         CURESULT_CHECK(cuMemMap((CUdeviceptr) * ptr, size, 0, allocHandle, 0));
         CUDAUtils::cu_mem_set_access(*ptr, size, device);
 
         {
-            const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);
-            allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle});
+            const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+            allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle, tag});
         }
 
 #ifdef TMS_DEBUG_LOG
         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
                   << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
-                  << " allocHandle=" << allocHandle
+                  << " allocHandle=" << allocHandle << " tag=" << tag
                   << std::endl;
 #endif
 
@@ -166,39 +167,47 @@ public:
 #ifdef TMS_DEBUG_LOG
         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
                   << " ptr=" << ptr << " metadata.size=" << metadata.size
-                  << " metadata.allocHandle=" << metadata.allocHandle
+                  << " metadata.allocHandle=" << metadata.allocHandle << " tag=" << metadata.tag
                   << std::endl;
 #endif
 
         return cudaSuccess;
     }
 
-    void pause() {
+    void pause(const std::string& tag) {
         const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);
 
         for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
             void *ptr = it->first;
             _AllocationMetadata metadata = it->second;
 
+            if (!tag.empty() && metadata.tag != tag) {
+                continue;
+            }
+
             CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
             CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
 
 #ifdef TMS_DEBUG_LOG
             std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
                       << " ptr=" << ptr << " metadata.size=" << metadata.size << " metadata.allocHandle="
-                      << metadata.allocHandle
+                      << metadata.allocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                       << std::endl;
 #endif
         }
     }
 
-    void resume() {
+    void resume(const std::string& tag) {
         const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);
 
         for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
             void *ptr = it->first;
             _AllocationMetadata &metadata = it->second;
 
+            if (!tag.empty() && metadata.tag != tag) {
+                continue;
+            }
+
             CUmemGenericAllocationHandle newAllocHandle;
             CUDAUtils::cu_mem_create(&newAllocHandle, metadata.size, metadata.device);
 
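The filter added to `pause()`/`resume()` treats an empty tag as "match everything". A small Python analogue of that selection logic (illustrative names; the real work is the `cuMemUnmap`/`cuMemRelease` calls above):

```python
from dataclasses import dataclass

@dataclass
class AllocationMetadata:          # mirrors the C++ _AllocationMetadata
    size: int
    device: int
    alloc_handle: int
    tag: str

def select_for_pause(allocations: dict, filter_tag: str = "") -> list:
    """Empty filter_tag matches everything, otherwise only matching tags."""
    return [ptr for ptr, md in allocations.items()
            if not filter_tag or md.tag == filter_tag]

allocs = {0x1000: AllocationMetadata(1 << 20, 0, 1, "type1"),
          0x2000: AllocationMetadata(1 << 20, 0, 2, "type2")}
assert select_for_pause(allocs, "type1") == [0x1000]
assert select_for_pause(allocs) == [0x1000, 0x2000]   # empty tag = all
```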
@@ -210,7 +219,7 @@ public:
             std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
                       << " ptr=" << ptr << " metadata.size=" << metadata.size << " (old)metadata.allocHandle="
                       << metadata.allocHandle
-                      << " (new)newAllocHandle=" << newAllocHandle
+                      << " (new)newAllocHandle=" << newAllocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                       << std::endl;
 #endif
 
@@ -223,14 +232,18 @@ public:
         return instance;
     }
 
+
 private:
-    // Similar to torch's CUDACachingAllocator and CUDAPluggableAllocator
     std::mutex allocator_metadata_mutex_;
     std::unordered_map<void *, _AllocationMetadata> allocation_metadata_;
 };
 
+
+// ----------------------------------------------- region manager --------------------------------------------------
+
 namespace RegionManager {
     static thread_local bool is_interesting_region_ = false;
+    static thread_local std::string current_tag_ = "default";
 
     void enter() {
 #ifdef TMS_DEBUG_LOG
@@ -249,13 +262,21 @@ namespace RegionManager {
     bool is_interesting_region() {
         return is_interesting_region_;
     }
+
+    void set_current_tag(const std::string& tag) {
+        current_tag_ = tag;
+    }
+
+    const std::string& get_current_tag() {
+        return current_tag_;
+    }
 }
 
 // ------------------------------------------------- entrypoints ------------------------------------------------
 
 cudaError_t cudaMalloc(void **ptr, size_t size) {
     if (RegionManager::is_interesting_region()) {
-        return TorchMemorySaver::instance().malloc(ptr, size);
+        return TorchMemorySaver::instance().malloc(ptr, size, RegionManager::get_current_tag());
     } else {
         return APIForwarder::call_real_cuda_malloc(ptr, size);
     }
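Note that `current_tag_` is `thread_local`, so each thread carries its own tag and concurrent regions on different threads do not interfere. A hedged Python analogue of that semantics:

```python
import threading

_state = threading.local()          # per-thread storage, like the C++ thread_local

def set_current_tag(tag: str) -> None:
    _state.tag = tag

def get_current_tag() -> str:
    # "default" mirrors current_tag_'s initializer in the C++ code
    return getattr(_state, "tag", "default")

def worker(tag: str, results: dict) -> None:
    set_current_tag(tag)
    results[tag] = get_current_tag()   # each thread reads back only its own tag

results = {}
threads = [threading.Thread(target=worker, args=(t, results)) for t in ("a", "b")]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert results == {"a": "a", "b": "b"}
print("main thread tag:", get_current_tag())  # still "default"
```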
@@ -278,11 +299,21 @@ void tms_region_leave() {
     RegionManager::leave();
 }
 
-void tms_pause() {
-    TorchMemorySaver::instance().pause();
+void tms_set_current_tag(const char* tag) {
+    if (tag == nullptr) {
+        std::cerr << "[torch_memory_saver.cpp] FATAL: NULL tag passed to tms_set_current_tag" << std::endl;
+        exit(1);
+    }
+    RegionManager::set_current_tag(std::string(tag));
 }
 
-void tms_resume() {
-    TorchMemorySaver::instance().resume();
+void tms_pause(const char* tag) {
+    std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+    TorchMemorySaver::instance().pause(tag_str);
 }
+
+void tms_resume(const char* tag) {
+    std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+    TorchMemorySaver::instance().resume(tag_str);
 }
+}
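At the C boundary, `tms_set_current_tag` rejects NULL, while a NULL tag passed to `tms_pause`/`tms_resume` becomes the empty "match all" filter. A sketch of exercising these entrypoints directly through ctypes (the Python wrapper below does this for you; it assumes `LD_PRELOAD` contains just the path of the built library, mirroring how `_BinaryInfo.compute()` loads it):

```python
import ctypes
import os

lib = ctypes.CDLL(os.environ["LD_PRELOAD"])   # the preloaded torch_memory_saver_cpp .so
for fn in (lib.tms_set_current_tag, lib.tms_pause, lib.tms_resume):
    fn.argtypes = [ctypes.c_char_p]

lib.tms_set_current_tag(b"weights")   # tag applied to subsequent region allocations
lib.tms_pause(b"weights")             # pause only allocations tagged "weights"
lib.tms_resume(b"weights")
lib.tms_pause(None)                   # NULL -> empty filter -> pause everything
lib.tms_resume(None)
```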
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/setup.py CHANGED
@@ -1,8 +1,9 @@
+
 import logging
 import os
 import shutil
 from pathlib import Path
-
+import platform
 import setuptools
 from setuptools import setup
 
@@ -24,25 +25,28 @@ def _find_cuda_home():
         cuda_home = '/usr/local/cuda'
     return cuda_home
 
-
 cuda_home = Path(_find_cuda_home())
+
 include_dirs = [
-    str(cuda_home / 'include'),
+    str((cuda_home / 'include').resolve()),
 ]
+
 library_dirs = [
-    str(cuda_home / 'lib64'),
-    str(cuda_home / 'lib64/stubs'),
+    str((cuda_home / 'lib64').resolve()),
+    str((cuda_home / 'lib64/stubs').resolve()),
 ]
 
 setup(
     name='torch_memory_saver',
-    version='0.0.6',
+    version='0.0.8',
     ext_modules=[setuptools.Extension(
         'torch_memory_saver_cpp',
         ['csrc/torch_memory_saver.cpp'],
         include_dirs=include_dirs,
         library_dirs=library_dirs,
-        libraries=['cuda']
+        libraries=['cuda'],
+        define_macros=[('Py_LIMITED_API', '0x03090000')],
+        py_limited_api=True,
     )],
     python_requires=">=3.9",
     packages=['torch_memory_saver'],
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver/__init__.py CHANGED
@@ -14,29 +14,34 @@ logger = logging.getLogger(__name__)
 class TorchMemorySaver:
     def __init__(self):
         self._mem_pool = None
-        self._id = _global_info.next_id()
-        assert self._id == 1, 'Only support one single instance yet (multi-instance will be implemented later)'
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str = "default"):
+        """Context manager for memory saving with an optional tag"""
         if _global_info.binary_info.enabled:
             self._ensure_mem_pool()
             with torch.cuda.use_mem_pool(self._mem_pool):
+                _global_info.binary_info.cdll.tms_set_current_tag(tag.encode('utf-8'))
                 _global_info.binary_info.cdll.tms_region_enter()
                 try:
                     yield
                 finally:
+                    _global_info.binary_info.cdll.tms_set_current_tag(b"default")
                     _global_info.binary_info.cdll.tms_region_leave()
         else:
             yield
 
-    def pause(self):
+    def pause(self, tag: Optional[str] = None):
+        """Pause memory for a specific tag, or all memory if tag is None"""
         if _global_info.binary_info.enabled:
-            _global_info.binary_info.cdll.tms_pause()
+            tag_bytes = tag.encode('utf-8') if tag else None
+            _global_info.binary_info.cdll.tms_pause(tag_bytes)
 
-    def resume(self):
+    def resume(self, tag: Optional[str] = None):
+        """Resume memory for a specific tag, or all memory if tag is None"""
         if _global_info.binary_info.enabled:
-            _global_info.binary_info.cdll.tms_resume()
+            tag_bytes = tag.encode('utf-8') if tag else None
+            _global_info.binary_info.cdll.tms_resume(tag_bytes)
 
     @property
     def enabled(self):
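At the Python level, a `None` tag is passed through as a NULL `char*`, which the C++ side converts to the empty "match all" filter, so omitting the tag operates on every pauseable allocation at once:

```python
from torch_memory_saver import torch_memory_saver

torch_memory_saver.pause()          # no tag: pause every pauseable allocation
torch_memory_saver.resume()         # no tag: resume everything
torch_memory_saver.pause("type1")   # only allocations created under tag="type1"
torch_memory_saver.resume("type1")
```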
@@ -46,7 +51,6 @@ class TorchMemorySaver:
         if self._mem_pool is None:
             self._mem_pool = torch.cuda.MemPool()
 
-
 @dataclass
 class _BinaryInfo:
     cdll: Optional[ctypes.CDLL]
@@ -55,21 +59,39 @@ class _BinaryInfo:
     def enabled(self):
         return self.cdll is not None
 
+    @staticmethod
+    def _setup_function_signatures(cdll):
+        """Define function signatures for the C library"""
+        cdll.tms_region_enter.argtypes = []
+        cdll.tms_region_leave.argtypes = []
+        cdll.tms_set_current_tag.argtypes = [ctypes.c_char_p]
+        cdll.tms_pause.argtypes = [ctypes.c_char_p]
+        cdll.tms_resume.argtypes = [ctypes.c_char_p]
+
     @staticmethod
     def compute():
         env_ld_preload = os.environ.get('LD_PRELOAD', '')
         if 'torch_memory_saver' in env_ld_preload:
-            return _BinaryInfo(cdll=ctypes.CDLL(env_ld_preload))
+            try:
+                cdll = ctypes.CDLL(env_ld_preload)
+                _BinaryInfo._setup_function_signatures(cdll)
+                return _BinaryInfo(cdll=cdll)
+            except OSError as e:
+                logger.error(f'Failed to load CDLL from {env_ld_preload}: {e}')
+                return _BinaryInfo(cdll=None)
         else:
-            logger.warning(
-                f'TorchMemorySaver is disabled for the current process because invalid LD_PRELOAD')
+            print(
+                f'TorchMemorySaver is disabled for the current process because of an invalid LD_PRELOAD. '
+                f'You can use the configure_subprocess() utility, '
+                f'or directly specify `LD_PRELOAD=/path/to/torch_memory_saver_cpp.some-postfix.so python your_script.py`. '
+                f'(LD_PRELOAD="{env_ld_preload}" process_id={os.getpid()})'
+            )
             return _BinaryInfo(cdll=None)
 
 
 class _GlobalInfo:
     def __init__(self):
         self._binary_info: Optional[_BinaryInfo] = None
-        self._last_id = 0
 
     @property
     def binary_info(self):

@@ -77,13 +99,11 @@ class _GlobalInfo:
             self._binary_info = _BinaryInfo.compute()
         return self._binary_info
 
-    def next_id(self):
-        self._last_id += 1
-        return self._last_id
-
 
 _global_info = _GlobalInfo()
 
+# Global singleton instance
+torch_memory_saver = TorchMemorySaver()
 
 def get_binary_path():
     dir_package = Path(__file__).parent

@@ -92,7 +112,7 @@ def get_binary_path():
         for d in [dir_package, dir_package.parent]
         for p in d.glob('torch_memory_saver_cpp.*.so')
     ]
-    assert len(candidates) == 1, f'{candidates}'
+    assert len(candidates) == 1, f'Expected exactly one torch_memory_saver_cpp library, found: {candidates}'
     return candidates[0]
 
 
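The disabled-path message above points users at a `configure_subprocess()` utility (not shown in this diff). As an alternative sketch, the `get_binary_path()` helper visible here can supply `LD_PRELOAD` for a child process:

```python
import os
import subprocess
import sys

from torch_memory_saver import get_binary_path

# Launch a script with the interceptor preloaded; the parent process itself
# does not need LD_PRELOAD for this to work.
env = dict(os.environ, LD_PRELOAD=str(get_binary_path()))
subprocess.run([sys.executable, "examples/simple.py"], env=env, check=True)
```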
torch_memory_saver-0.0.6/README.md REMOVED
@@ -1,29 +0,0 @@
-# torch_memory_saver
-
-Allow torch tensor memory to be released and resumed later.
-
-API:
-
-```python
-memory_saver = TorchMemorySaver()
-
-# 1. For tensors that wants to be paused, create them within `region`
-with memory_saver.region():
-    x = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
-
-# 2. After `pause`, CUDA memory is released for those tensors.
-# For example, check `nvidia-smi`'s memory usage to verify.
-memory_saver.pause()
-
-# 3. After `resume`, CUDA memory is re-occupied for those tensors.
-memory_saver.resume()
-```
-
-Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
-
-TODO:
-
-- [x] Implementation
-- [x] Publish to pypi
-- [ ] More tests and infra
-- [ ] Documentation
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/LICENSE RENAMED (file without changes)
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/setup.cfg RENAMED (file without changes)
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/SOURCES.txt RENAMED (file without changes)
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/dependency_links.txt RENAMED (file without changes)
{torch_memory_saver-0.0.6 → torch_memory_saver-0.0.8}/torch_memory_saver.egg-info/top_level.txt RENAMED (file without changes)