torch-ddp 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/torch_ddp/cuda.cpp +68 -0
- data/ext/torch_ddp/ext.cpp +2 -0
- data/ext/torch_ddp/extconf.rb +1 -0
- data/lib/torch/ddp/monkey_patch.rb +21 -36
- data/lib/torch/ddp/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: c04c1f358a671d251826b7bf9db798bd5d9f11e279b639cff382f9f8f07d4b5f
|
|
4
|
+
data.tar.gz: 14c9db6913aaf75f98242f06808db8a3b696f37e6d779e092c8ed59cd398a644
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 0c62affe04041abca2dc56d6e82cc511d0338c3a4460c076b2fde13f219b889903b376521c92a38a6df0f0447bbd058081075ec1fcd6d3fcfb914063a0338691
|
|
7
|
+
data.tar.gz: 11634e1c29274033be29c37f51eae9bcc62bab902afedb469bdfde39e74f9d9230a93295a1c0b5fb2ee8e303af0df20e659335ee5e21c3c9e6b89608c9600dde
|
data/ext/torch_ddp/cuda.cpp
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
#include <torch/torch.h>
|
|
2
|
+
|
|
3
|
+
#include <rice/rice.hpp>
|
|
4
|
+
|
|
5
|
+
#if defined(WITH_CUDA)
|
|
6
|
+
#include <cuda_runtime_api.h>
|
|
7
|
+
#include <c10/cuda/CUDACachingAllocator.h>
|
|
8
|
+
#endif
|
|
9
|
+
|
|
10
|
+
namespace {
|
|
11
|
+
|
|
12
|
+
void register_cuda_helpers(Rice::Module& m) {
|
|
13
|
+
auto rb_mDDP = Rice::define_module_under(m, "DDP");
|
|
14
|
+
|
|
15
|
+
rb_mDDP.define_singleton_function(
|
|
16
|
+
"_cuda_set_device",
|
|
17
|
+
[](int device_id) {
|
|
18
|
+
#if defined(WITH_CUDA)
|
|
19
|
+
int count = 0;
|
|
20
|
+
auto status = cudaGetDeviceCount(&count);
|
|
21
|
+
if (status != cudaSuccess) {
|
|
22
|
+
rb_raise(
|
|
23
|
+
rb_eRuntimeError,
|
|
24
|
+
"cudaGetDeviceCount failed with code %d",
|
|
25
|
+
static_cast<int>(status));
|
|
26
|
+
}
|
|
27
|
+
if (device_id < 0 || device_id >= count) {
|
|
28
|
+
rb_raise(
|
|
29
|
+
rb_eArgError,
|
|
30
|
+
"Invalid device_id %d for CUDA (available devices: %d)",
|
|
31
|
+
device_id,
|
|
32
|
+
count);
|
|
33
|
+
}
|
|
34
|
+
status = cudaSetDevice(device_id);
|
|
35
|
+
if (status != cudaSuccess) {
|
|
36
|
+
rb_raise(
|
|
37
|
+
rb_eRuntimeError,
|
|
38
|
+
"cudaSetDevice(%d) failed with code %d",
|
|
39
|
+
device_id,
|
|
40
|
+
static_cast<int>(status));
|
|
41
|
+
}
|
|
42
|
+
#else
|
|
43
|
+
rb_raise(
|
|
44
|
+
rb_eRuntimeError,
|
|
45
|
+
"Torch::DDP._cuda_set_device requires CUDA support");
|
|
46
|
+
#endif
|
|
47
|
+
return Rice::Nil;
|
|
48
|
+
});
|
|
49
|
+
|
|
50
|
+
rb_mDDP.define_singleton_function(
|
|
51
|
+
"_cuda_empty_cache",
|
|
52
|
+
[]() {
|
|
53
|
+
#if defined(WITH_CUDA)
|
|
54
|
+
c10::cuda::CUDACachingAllocator::emptyCache();
|
|
55
|
+
#else
|
|
56
|
+
rb_raise(
|
|
57
|
+
rb_eRuntimeError,
|
|
58
|
+
"Torch::DDP._cuda_empty_cache requires CUDA support");
|
|
59
|
+
#endif
|
|
60
|
+
return Rice::Nil;
|
|
61
|
+
});
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
} // namespace
|
|
65
|
+
|
|
66
|
+
void init_cuda_helpers(Rice::Module& m) {
|
|
67
|
+
register_cuda_helpers(m);
|
|
68
|
+
}
|
data/ext/torch_ddp/ext.cpp
CHANGED
data/ext/torch_ddp/extconf.rb
CHANGED
|
@@ -56,6 +56,7 @@ if Dir["#{lib}/*torch_cuda*"].any?
|
|
|
56
56
|
$LDFLAGS += " -L#{cudnn_lib}" if Dir.exist?(cudnn_lib) && cudnn_lib != cuda_lib
|
|
57
57
|
with_cuda = have_library("cuda") && have_library("cudnn")
|
|
58
58
|
end
|
|
59
|
+
$defs << "-DWITH_CUDA" if with_cuda
|
|
59
60
|
|
|
60
61
|
$INCFLAGS += " -I#{inc}"
|
|
61
62
|
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
data/lib/torch/ddp/monkey_patch.rb
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
require "fiddle"
|
|
2
|
-
|
|
3
1
|
module Torch
|
|
4
2
|
module DDP
|
|
5
3
|
module MonkeyPatch
|
|
@@ -14,6 +12,7 @@ module Torch
|
|
|
14
12
|
|
|
15
13
|
warn("#{WARNING_PREFIX} Applying torch compatibility patch for: #{missing.join(', ')}. Please upgrade the torch gem for native support.")
|
|
16
14
|
patch_cuda_set_device if missing.include?(:cuda_set_device)
|
|
15
|
+
patch_cuda_empty_cache if missing.include?(:cuda_empty_cache)
|
|
17
16
|
patch_device_helpers
|
|
18
17
|
patch_load if missing.include?(:load_keywords)
|
|
19
18
|
patch_tensor_item if missing.include?(:tensor_item_scalar)
|
|
@@ -25,6 +24,7 @@ module Torch
|
|
|
25
24
|
def missing_features
|
|
26
25
|
missing = []
|
|
27
26
|
missing << :cuda_set_device unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device)
|
|
27
|
+
missing << :cuda_empty_cache unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:empty_cache)
|
|
28
28
|
missing << :load_keywords unless load_supports_map_location_and_weights_only?
|
|
29
29
|
missing << :tensor_item_scalar unless tensor_item_returns_scalar?
|
|
30
30
|
missing
|
|
@@ -56,48 +56,33 @@ module Torch
|
|
|
56
56
|
end
|
|
57
57
|
|
|
58
58
|
def cuda_set_device!(device_id)
|
|
59
|
-
|
|
59
|
+
unless Torch.const_defined?(:DDP) && Torch::DDP.respond_to?(:_cuda_set_device)
|
|
60
|
+
raise Torch::Error, "Torch::CUDA.set_device is unavailable; ensure torch is built with CUDA or upgrade torch."
|
|
61
|
+
end
|
|
62
|
+
|
|
63
|
+
Torch::DDP._cuda_set_device(Integer(device_id))
|
|
60
64
|
end
|
|
61
65
|
public :cuda_set_device!
|
|
62
66
|
|
|
63
|
-
def
|
|
64
|
-
|
|
65
|
-
candidates = [
|
|
66
|
-
ENV["LIBCUDART_PATH"],
|
|
67
|
-
"/usr/local/cuda/lib64/libcudart.so",
|
|
68
|
-
"/usr/local/cuda/lib/libcudart.so",
|
|
69
|
-
"/usr/local/cuda/lib/libcudart.dylib",
|
|
70
|
-
"libcudart.so.12",
|
|
71
|
-
"libcudart.so.11",
|
|
72
|
-
"libcudart.so",
|
|
73
|
-
"libcudart.dylib"
|
|
74
|
-
].compact
|
|
75
|
-
|
|
76
|
-
function = nil
|
|
77
|
-
candidates.each do |path|
|
|
78
|
-
begin
|
|
79
|
-
handle = Fiddle.dlopen(path)
|
|
80
|
-
function = Fiddle::Function.new(handle["cudaSetDevice"], [Fiddle::TYPE_INT], Fiddle::TYPE_INT)
|
|
81
|
-
break
|
|
82
|
-
rescue Fiddle::DLError
|
|
83
|
-
next
|
|
84
|
-
end
|
|
85
|
-
end
|
|
67
|
+
def patch_cuda_empty_cache
|
|
68
|
+
return unless Torch.const_defined?(:CUDA)
|
|
86
69
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
raise Torch::Error, "cudaSetDevice(#{device_id}) failed with code #{result}" unless result.zero?
|
|
91
|
-
nil
|
|
92
|
-
end
|
|
93
|
-
else
|
|
94
|
-
->(device_id) do
|
|
95
|
-
raise Torch::Error, "Torch::CUDA.set_device is unavailable; ensure torch is built with CUDA or upgrade torch."
|
|
96
|
-
end
|
|
70
|
+
Torch::CUDA.singleton_class.class_eval do
|
|
71
|
+
define_method(:empty_cache) do
|
|
72
|
+
Torch::DDP::MonkeyPatch.cuda_empty_cache!
|
|
97
73
|
end
|
|
98
74
|
end
|
|
99
75
|
end
|
|
100
76
|
|
|
77
|
+
def cuda_empty_cache!
|
|
78
|
+
unless Torch.const_defined?(:DDP) && Torch::DDP.respond_to?(:_cuda_empty_cache)
|
|
79
|
+
raise Torch::Error, "Torch::CUDA.empty_cache is unavailable; ensure torch is built with CUDA or upgrade torch."
|
|
80
|
+
end
|
|
81
|
+
|
|
82
|
+
Torch::DDP._cuda_empty_cache
|
|
83
|
+
end
|
|
84
|
+
public :cuda_empty_cache!
|
|
85
|
+
|
|
101
86
|
def patch_device_helpers
|
|
102
87
|
Torch::Device.class_eval do
|
|
103
88
|
alias_method :_torch_ddp_original_to_s, :to_s unless method_defined?(:_torch_ddp_original_to_s)
|
data/lib/torch/ddp/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: torch-ddp
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.1.4
|
|
4
|
+
version: 0.2.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Ivan Razuvaev
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2025-12-
|
|
11
|
+
date: 2025-12-19 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: torch-rb
|
|
@@ -51,6 +51,7 @@ files:
|
|
|
51
51
|
- bin/torchrun
|
|
52
52
|
- examples/benchmark/training.rb
|
|
53
53
|
- examples/mnist/distributed.rb
|
|
54
|
+
- ext/torch_ddp/cuda.cpp
|
|
54
55
|
- ext/torch_ddp/distributed.cpp
|
|
55
56
|
- ext/torch_ddp/ext.cpp
|
|
56
57
|
- ext/torch_ddp/extconf.rb
|