torch-ddp 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +46 -0
- data/README.md +114 -0
- data/bin/torchrun +6 -0
- data/examples/benchmark/training.rb +374 -0
- data/examples/mnist/distributed.rb +240 -0
- data/ext/torch_ddp/distributed.cpp +348 -0
- data/ext/torch_ddp/ext.cpp +11 -0
- data/ext/torch_ddp/extconf.rb +155 -0
- data/lib/torch/ddp/monkey_patch.rb +325 -0
- data/lib/torch/ddp/version.rb +5 -0
- data/lib/torch/distributed.rb +466 -0
- data/lib/torch/nn/parallel/distributed_data_parallel.rb +115 -0
- data/lib/torch/torchrun.rb +531 -0
- data/lib/torch-ddp.rb +8 -0
- data/test/distributed_test.rb +243 -0
- data/test/support/net.rb +42 -0
- data/test/support/scripts/show_ranks.rb +7 -0
- data/test/support/tensor.pth +0 -0
- data/test/test_helper.rb +71 -0
- data/test/torchrun_test.rb +33 -0
- metadata +92 -0
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
require "mkmf-rice"
|
|
2
|
+
|
|
3
|
+
# Compile as C++17; $(optflags) is left for the generated Makefile to expand.
$CXXFLAGS += " -std=c++17 $(optflags)"

# change to 0 for Linux pre-cxx11 ABI version
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i

# Compiler-specific warning suppressions for noise from torch and rice headers.
suppressed_warnings =
  if apple_clang
    # silence torch warnings
    ["-Wno-deprecated-declarations"]
  else
    [
      "-Wno-noexcept-type",              # silence rice warnings
      "-Wno-duplicated-cond",            # silence torch warnings
      "-Wno-suggest-attribute=noreturn"  # silence torch warnings
    ]
  end
$CXXFLAGS += suppressed_warnings.map { |flag| " #{flag}" }.join
|
|
20
|
+
|
|
21
|
+
# Prefixes probed for LibTorch when dir_config("torch") yields nothing.
paths = %w[/usr/local /opt/homebrew /home/linuxbrew/.linuxbrew]

# Configured torch dirs win; otherwise take the first prefix that has an
# include/torch directory and a lib/*torch_cpu* library.
inc, lib = dir_config("torch")
inc ||= paths.map { |prefix| File.join(prefix, "include") }.find { |dir| Dir.exist?(File.join(dir, "torch")) }
lib ||= paths.map { |prefix| File.join(prefix, "lib") }.find { |dir| !Dir.glob(File.join(dir, "*torch_cpu*")).empty? }

abort "LibTorch not found" unless inc && lib

# CUDA / cuDNN locations, with fallbacks used only when not configured.
cuda_inc, cuda_lib = dir_config("cuda")
cuda_lib ||= "/usr/local/cuda/lib64"

cudnn_inc, cudnn_lib = dir_config("cudnn")
cudnn_lib ||= "/usr/local/cuda/lib"

# Only the include directory is used for gloo.
gloo_inc = dir_config("gloo").first
gloo_inc ||= "./vendor/gloo"
|
|
43
|
+
|
|
44
|
+
# Put the LibTorch lib directory on the link path, then require libtorch itself.
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
abort "LibTorch not found" unless have_library("torch")

# Optional libraries: the return values are not checked, so absence is not fatal.
have_library("mkldnn")
have_library("nnpack")

# CUDA support is only considered when the LibTorch lib directory contains a
# *torch_cuda* library, and is enabled only if both libcuda and libcudnn link.
with_cuda = false
if Dir["#{lib}/*torch_cuda*"].any?
  $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
  $INCFLAGS += " -I#{cuda_inc}" if cuda_inc && Dir.exist?(cuda_inc)
  # Avoid adding the same -L twice when cuDNN shares CUDA's lib directory.
  $LDFLAGS += " -L#{cudnn_lib}" if Dir.exist?(cudnn_lib) && cudnn_lib != cuda_lib
  with_cuda = have_library("cuda") && have_library("cudnn")
end
|
|
57
|
+
|
|
58
|
+
# LibTorch headers: the top-level include dir plus the C++ frontend API headers.
$INCFLAGS += " -I#{inc} -I#{inc}/torch/csrc/api/include"

# Route mkmf's compile/link probes through the C++ compiler with the C++ flags
# (presumably because the probe sources below are C++).
CONFIG["CC"] = CONFIG["CXX"]
$CFLAGS = $CXXFLAGS

# Is c10's CUDA API (c10::cuda::set_device) usable? Only checked on CUDA builds.
c10_cuda_probe = <<~CPP
#include <torch/torch.h>
#include <c10/cuda/CUDAFunctions.h>

int main() {
c10::cuda::set_device(0);
return 0;
}
CPP
supports_c10_cuda = with_cuda && try_compile(c10_cuda_probe)

$defs << " -DHAVE_C10_CUDA" if supports_c10_cuda
|
|
77
|
+
|
|
78
|
+
# Runtime search path for LibTorch's shared libraries.
$LDFLAGS += " -Wl,-rpath,#{lib}"
if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
  # No leading space: this continues the -Wl, option above, adding Homebrew's
  # libomp directory as a second rpath on Apple Silicon macOS.
  $LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
end
# No leading space: appends the CUDA stub and lib directories to the preceding
# rpath value as a colon-separated list.
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda

# https://github.com/pytorch/pytorch/blob/v2.9.0/torch/utils/cpp_extension.py#L1351-L1364
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
if with_cuda
  $LDFLAGS += " -lcuda -lnvrtc"
  # Link nvToolsExt only when its shared library exists in cuda_lib.
  $LDFLAGS += " -lnvToolsExt" if File.exist?("#{cuda_lib}/libnvToolsExt.so")
  $LDFLAGS += " -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
  # TODO figure out why this is needed
  # NOTE(review): presumably keeps libtorch.so as a runtime dependency even if
  # the linker would otherwise drop it under --as-needed; confirm.
  $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
end
|
|
93
|
+
|
|
94
|
+
# Detect c10d (distributed) support by linking a program that constructs a
# c10d::FileStore, compiled with -DUSE_C10D.
supports_c10d = try_link(<<~CPP, "-DUSE_C10D")
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>

int main() {
::c10d::FileStore store("unused", 1);
return 0;
}
CPP

if supports_c10d
  $defs << " -DUSE_C10D"
  puts "Building with distributed support"

  # Gloo headers are looked up under gloo_inc; a missing header only prints a hint.
  if find_header("gloo/algorithm.h", gloo_inc)
    # NOTE(review): mkmf's find_header may already append this -I flag when it
    # finds the header under gloo_inc; check whether this line is redundant.
    $INCFLAGS += " -I#{gloo_inc}"
  else
    puts "GLOO headers not found. Consider setting --with-gloo-include param"
  end
else
  puts "Building without distributed support"
end
|
|
116
|
+
|
|
117
|
+
# Detect the Gloo process group by linking a program that builds a
# ProcessGroupGloo (rank 0 of 1) on a FileStore. Requires c10d support.
supports_c10d_gloo = supports_c10d && try_link(<<~CPP, "-DUSE_C10D -DUSE_C10D_GLOO")
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/FileStore.hpp>

int main() {
auto store = c10::make_intrusive<::c10d::FileStore>("unused", 1);
auto opts = ::c10d::ProcessGroupGloo::Options::create();
opts->devices.push_back(::c10d::ProcessGroupGloo::createDefaultDevice());
::c10d::ProcessGroupGloo pg(store, 0, 1, opts);
return static_cast<int>(pg.getRank());
}
CPP

# Detect the NCCL process group. Only attempted when the CUDA libraries linked
# and the c10 CUDA probe succeeded.
supports_c10d_nccl = with_cuda && supports_c10_cuda && try_link(<<~CPP, "-DUSE_C10D -DUSE_C10D_NCCL")
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

int main() {
auto opts = c10::make_intrusive<::c10d::ProcessGroupNCCL::Options>();
opts->is_high_priority_stream = false;
return 0;
}
CPP

if supports_c10d_gloo
  $defs << "-DUSE_C10D_GLOO"
  puts "GLOO support detected"
end

# Explain why NCCL is off. supports_c10_cuda is always false on builds without
# CUDA (the c10 CUDA probe never runs), so reporting missing c10 CUDA headers in
# that case would be misleading.
if !with_cuda
  puts "CUDA not detected. NCCL is unavailable"
elsif !supports_c10_cuda
  puts "No c10 CUDA headers found. NCCL is unavailable"
end

if supports_c10d_nccl
  $defs << "-DUSE_C10D_NCCL"
  puts "NCCL support detected"
end

# create makefile
create_makefile("torch/ddp_ext")
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
require "fiddle"
|
|
2
|
+
|
|
3
|
+
module Torch
  module DDP
    # Runtime compatibility shims for older releases of the torch gem.
    #
    # MonkeyPatch.apply_if_needed probes for capabilities torch-ddp relies on
    # (Torch::CUDA.set_device, the map_location:/weights_only: keywords of
    # Torch.load, and Tensor#item returning a scalar) and installs Ruby-level
    # replacements only for the ones that are missing.
    module MonkeyPatch
      WARNING_PREFIX = "[torch-ddp]".freeze

      # Value classes (besides tensors, arrays and hashes) accepted by the
      # patched `Torch.load(..., weights_only: true)`.
      #
      # This lives in a lexically enclosing module on purpose: the helpers that
      # patch_load_helpers installs are `def`s inside a `class_eval` block, and
      # such methods resolve constants lexically. A constant that only exists on
      # Torch.singleton_class is invisible to them and raises NameError.
      WEIGHTS_ONLY_PRIMITIVE_CLASSES = [NilClass, TrueClass, FalseClass, Integer, Float, String].freeze

      class << self
        # Installs every compatibility patch the loaded torch gem needs and
        # warns when it does so. Returns immediately if nothing is missing or if
        # patches were already applied by an earlier call.
        def apply_if_needed
          return if defined?(@applied) && @applied

          missing = missing_features
          return if missing.empty?

          warn("#{WARNING_PREFIX} Applying torch compatibility patch for: #{missing.join(', ')}. Please upgrade the torch gem for native support.")
          patch_cuda_set_device if missing.include?(:cuda_set_device)
          # The Device/Tensor#device helpers are installed whenever any patch is applied.
          patch_device_helpers
          patch_load if missing.include?(:load_keywords)
          patch_tensor_item if missing.include?(:tensor_item_scalar)
          @applied = true
        end

        private

        # Returns the symbols of the torch features that need patching.
        def missing_features
          missing = []
          missing << :cuda_set_device unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device)
          missing << :load_keywords unless load_supports_map_location_and_weights_only?
          missing << :tensor_item_scalar unless tensor_item_returns_scalar?
          missing
        end

        # True when Torch.load declares both map_location: and weights_only:
        # keywords; false if they are absent or Torch.load is undefined.
        def load_supports_map_location_and_weights_only?
          params = Torch.method(:load).parameters
          keyword_names = params.select { |kind, _| [:key, :keyreq].include?(kind) }.map(&:last)
          keyword_names.include?(:map_location) && keyword_names.include?(:weights_only)
        rescue NameError
          false
        end

        # True when Tensor#item on a one-element 2-D tensor yields a number or
        # boolean. If the probe itself raises, native support is assumed.
        def tensor_item_returns_scalar?
          value = Torch.tensor([[1]]).item
          value.is_a?(Numeric) || value == true || value == false
        rescue StandardError
          true
        end

        # Defines Torch::CUDA.set_device (backed by cudaSetDevice via Fiddle) when
        # Torch::CUDA exists.
        def patch_cuda_set_device
          return unless Torch.const_defined?(:CUDA)

          Torch::CUDA.singleton_class.class_eval do
            define_method(:set_device) do |device_id|
              Torch::DDP::MonkeyPatch.cuda_set_device!(device_id)
            end
          end
        end

        # Selects the current CUDA device with cudaSetDevice. Raises Torch::Error
        # if the call fails or libcudart could not be loaded. Public so the
        # patched Torch::CUDA.set_device can call it.
        def cuda_set_device!(device_id)
          cuda_set_device_proc.call(Integer(device_id))
        end
        public :cuda_set_device!

        # Lazily builds (and memoizes) a callable that invokes cudaSetDevice from
        # libcudart through Fiddle. Candidates are tried in order: the
        # LIBCUDART_PATH environment variable, common /usr/local/cuda locations,
        # then bare library names. If no candidate opens and exposes the symbol,
        # the returned callable raises Torch::Error when invoked.
        def cuda_set_device_proc
          @cuda_set_device_proc ||= begin
            candidates = [
              ENV["LIBCUDART_PATH"],
              "/usr/local/cuda/lib64/libcudart.so",
              "/usr/local/cuda/lib/libcudart.so",
              "/usr/local/cuda/lib/libcudart.dylib",
              "libcudart.so.12",
              "libcudart.so.11",
              "libcudart.so",
              "libcudart.dylib"
            ].compact

            function = nil
            candidates.each do |path|
              begin
                handle = Fiddle.dlopen(path)
                # int cudaSetDevice(int device)
                function = Fiddle::Function.new(handle["cudaSetDevice"], [Fiddle::TYPE_INT], Fiddle::TYPE_INT)
                break
              rescue Fiddle::DLError
                next
              end
            end

            if function
              ->(device_id) do
                result = function.call(device_id)
                raise Torch::Error, "cudaSetDevice(#{device_id}) failed with code #{result}" unless result.zero?
                nil
              end
            else
              ->(device_id) do
                raise Torch::Error, "Torch::CUDA.set_device is unavailable; ensure torch is built with CUDA or upgrade torch."
              end
            end
          end
        end

        # Makes Torch::Device#to_s return the device string. Makes
        # Torch::Tensor#device return a Torch::DeviceString (a String that also
        # answers #type and #index), defining that class if it does not exist.
        def patch_device_helpers
          Torch::Device.class_eval do
            define_method(:to_s) { _str }
          end

          unless Torch.const_defined?(:DeviceString)
            Torch.const_set(
              :DeviceString,
              Class.new(String) do
                def initialize(device)
                  @device = device
                  super(device._str)
                end

                def type
                  @device.type
                end

                def index
                  @device.index
                end
              end
            )
          end

          Torch::Tensor.class_eval do
            define_method(:device) { Torch::DeviceString.new(_device) }
          end
        end

        # Wraps Tensor#item so that an Array result is unwrapped to its first
        # (flattened) element. The original is kept as _torch_ddp_original_item.
        def patch_tensor_item
          Torch::Tensor.class_eval do
            alias_method :_torch_ddp_original_item, :item unless method_defined?(:_torch_ddp_original_item)

            def item
              value = _torch_ddp_original_item
              value.is_a?(Array) ? value.flatten.first : value
            end
          end
        end

        # Replaces Torch.load with a version accepting map_location: and
        # weights_only:. The original is kept as Torch._torch_ddp_original_load.
        #
        # map_location may be a Device, a String/Symbol, a Hash (source device ->
        # target) or a callable(tensor, source_tag). When it resolves to a single
        # device and torch exposes _load_with_device, the file is loaded onto
        # that device directly. map_location is then applied to the result in
        # every case. weights_only: true rejects anything other than tensors,
        # primitives, arrays and hashes; the check runs after loading.
        def patch_load
          patch_load_helpers

          Torch.singleton_class.class_eval do
            alias_method :_torch_ddp_original_load, :load unless method_defined?(:_torch_ddp_original_load)

            def load(filename, map_location: nil, weights_only: false)
              load_device = map_location_device(map_location) if map_location
              result =
                if load_device && respond_to?(:_load_with_device)
                  Torch::DDP::MonkeyPatch.load_with_device(filename, load_device)
                else
                  _torch_ddp_original_load(filename)
                end

              ensure_weights_only_contents!(result) if weights_only
              result = apply_map_location(result, map_location) if map_location
              result
            end
          end
        end

        # Defines the helper singleton methods on Torch used by the patched
        # Torch.load, skipping any that method_defined? reports as present.
        def patch_load_helpers
          Torch.singleton_class.class_eval do
            # Still published on Torch's singleton class for code that reads it
            # there; the helpers below reference the MonkeyPatch constant.
            const_set(
              :WEIGHTS_ONLY_PRIMITIVE_CLASSES,
              Torch::DDP::MonkeyPatch::WEIGHTS_ONLY_PRIMITIVE_CLASSES
            ) unless const_defined?(:WEIGHTS_ONLY_PRIMITIVE_CLASSES)

            unless method_defined?(:ensure_weights_only_contents!)
              # Recursively verifies obj holds only tensors, primitive values,
              # arrays and hashes (keys included); raises Torch::Error otherwise.
              # Returns obj.
              def ensure_weights_only_contents!(obj)
                case obj
                when *Torch::DDP::MonkeyPatch::WEIGHTS_ONLY_PRIMITIVE_CLASSES, Tensor
                  obj
                when Array
                  obj.each { |value| ensure_weights_only_contents!(value) }
                when Hash
                  obj.each do |key, value|
                    ensure_weights_only_contents!(key)
                    ensure_weights_only_contents!(value)
                  end
                else
                  raise Error, "weights_only load supports tensors, primitive Ruby types, arrays, and hashes (found #{obj.class.name})"
                end
              end
            end

            unless method_defined?(:map_location_device)
              # Returns the single Device that map_location targets, or nil when
              # there is no single target: callables, unsupported values, or a
              # Hash whose normalizable values are not one unique entry.
              def map_location_device(map_location)
                case map_location
                when Device, String, Symbol
                  normalize_map_location_device(map_location)
                when Hash
                  devices = map_location.values.filter_map do |value|
                    begin
                      normalize_map_location_device(value)
                    rescue StandardError
                      nil
                    end
                  end
                  return nil if devices.empty?
                  devices.uniq!
                  devices.one? ? devices.first : nil
                else
                  nil
                end
              end
            end

            unless method_defined?(:apply_map_location)
              # Walks arrays and hashes (keys and values), relocating every
              # tensor with map_tensor_location; other values pass through.
              def apply_map_location(obj, map_location)
                case obj
                when Tensor
                  map_tensor_location(obj, map_location)
                when Array
                  obj.map { |value| apply_map_location(value, map_location) }
                when Hash
                  obj.each_with_object({}) do |(key, value), memo|
                    memo[apply_map_location(key, map_location)] = apply_map_location(value, map_location)
                  end
                else
                  obj
                end
              end
            end

            unless method_defined?(:map_tensor_location)
              # Relocates one tensor: nil leaves it alone, a Hash remaps by the
              # tensor's current device tag, a callable decides itself, anything
              # else is normalized to a device and passed to Tensor#to.
              def map_tensor_location(tensor, map_location)
                case map_location
                when nil
                  tensor
                when Hash
                  target = lookup_map_location_target(map_location, tensor.device)
                  return tensor if target.nil?
                  map_tensor_location(tensor, target)
                else
                  return map_tensor_location_callable(tensor, map_location) if map_location.respond_to?(:call)
                  device = normalize_map_location_device(map_location)
                  tensor.to(device)
                end
              end
            end

            unless method_defined?(:map_tensor_location_callable)
              # Calls callable(tensor, device_tag); nil keeps the tensor, any
              # non-Tensor result raises Torch::Error.
              def map_tensor_location_callable(tensor, callable)
                mapped = callable.call(tensor, map_location_device_tag(tensor.device))
                return tensor if mapped.nil?
                unless mapped.is_a?(Tensor)
                  raise Error, "map_location callable must return a Tensor or nil (got #{mapped.class.name})"
                end
                mapped
              end
            end

            unless method_defined?(:lookup_map_location_target)
              # Returns the mapping value whose key matches device's tag
              # (Device keys are tagged, String/Symbol keys stringified), or nil.
              def lookup_map_location_target(mapping, device)
                key = map_location_device_tag(device)
                mapping.each do |candidate, value|
                  candidate_key =
                    case candidate
                    when Device
                      map_location_device_tag(candidate)
                    when String, Symbol
                      candidate.to_s
                    else
                      candidate
                    end
                  return value if candidate_key == key
                end
                nil
              end
            end

            unless method_defined?(:map_location_device_tag)
              # Converts a Device to "type" or "type:index", and a String/Symbol
              # to its string form; raises Torch::Error for anything else.
              def map_location_device_tag(device)
                case device
                when Device
                  tag = device.type
                  tag += ":#{device.index}" unless device.index.nil?
                  tag
                when String, Symbol
                  device.to_s
                else
                  raise Error, "Unknown device reference: #{device.inspect}"
                end
              end
            end

            unless method_defined?(:normalize_map_location_device)
              # Returns a Device for a Device/String/Symbol location (via
              # Torch.device for strings); raises Torch::Error otherwise.
              def normalize_map_location_device(location)
                case location
                when Device
                  location
                when String, Symbol
                  device(location.to_s)
                else
                  raise Error, "Unsupported map_location: #{location.inspect}"
                end
              end
            end
          end
        end
      end

      module_function

      # Loads filename onto device via torch's native _load_with_device and
      # converts the result with Torch.to_ruby. Falls back to the original
      # (unpatched) Torch.load when _load_with_device is absent or any
      # StandardError is raised.
      def load_with_device(filename, device)
        fallback_load =
          if Torch.respond_to?(:_torch_ddp_original_load)
            Torch.method(:_torch_ddp_original_load)
          else
            Torch.method(:load)
          end

        return fallback_load.call(filename) unless Torch.respond_to?(:_load_with_device)

        device_str = device.respond_to?(:_str) ? device._str : device.to_s
        Torch.send(:to_ruby, Torch._load_with_device(filename, device_str))
      rescue StandardError
        fallback_load.call(filename)
      end
    end
  end
end
|