torch-ddp 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +46 -0
- data/README.md +114 -0
- data/bin/torchrun +6 -0
- data/examples/benchmark/training.rb +374 -0
- data/examples/mnist/distributed.rb +240 -0
- data/ext/torch_ddp/distributed.cpp +348 -0
- data/ext/torch_ddp/ext.cpp +11 -0
- data/ext/torch_ddp/extconf.rb +155 -0
- data/lib/torch/ddp/monkey_patch.rb +325 -0
- data/lib/torch/ddp/version.rb +5 -0
- data/lib/torch/distributed.rb +466 -0
- data/lib/torch/nn/parallel/distributed_data_parallel.rb +115 -0
- data/lib/torch/torchrun.rb +531 -0
- data/lib/torch-ddp.rb +8 -0
- data/test/distributed_test.rb +243 -0
- data/test/support/net.rb +42 -0
- data/test/support/scripts/show_ranks.rb +7 -0
- data/test/support/tensor.pth +0 -0
- data/test/test_helper.rb +71 -0
- data/test/torchrun_test.rb +33 -0
- metadata +92 -0
data/lib/torch/distributed.rb
@@ -0,0 +1,466 @@
require "torch"
require "torch/ddp_ext"
require "socket"
require "rbconfig"

module Torch
  module Distributed
    DEFAULT_DEVICE_BACKENDS = {
      "cpu" => "gloo",
      "cuda" => "nccl",
      "xpu" => "xccl",
      "mps" => "gloo"
    }.freeze

    DEFAULT_TIMEOUT = 30 * 60 unless const_defined?(:DEFAULT_TIMEOUT)

    unless const_defined?(:ReduceOp)
      module ReduceOp
        SUM = 0
      end
    end

    SPAWN_ENV_KEY = "TORCH_DISTRIBUTED_SPAWNED".freeze
    SPAWN_RANK_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_RANK".freeze
    SPAWN_WORLD_SIZE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_WORLD_SIZE".freeze
    SPAWN_PORT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PORT".freeze
    SPAWN_PIPE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PIPE".freeze
    SPAWN_SCRIPT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_SCRIPT".freeze
    SPAWN_TEST_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_TEST".freeze
    SPAWN_ARGV = ARGV.dup.freeze

    class << self
      def initialized?
        _initialized?
      end

      def init_process_group(backend = nil, init_method: "env://", store: nil, rank: nil, world_size: nil, timeout: DEFAULT_TIMEOUT, wait_for_workers: true, device_id: nil)
        raise Torch::Error, "torch.distributed is not available" unless available?

        backend ||= default_backend_for(device_id)

        if store.nil?
          case init_method
          when "env://"
            rank = Integer(ENV.fetch("RANK")) if rank.nil?
            world_size = Integer(ENV.fetch("WORLD_SIZE")) if world_size.nil?
            master_addr = ENV.fetch("MASTER_ADDR", "127.0.0.1")
            master_port = Integer(ENV.fetch("MASTER_PORT", "29500"))
            raise ArgumentError, "rank is required" if rank.nil?
            raise ArgumentError, "world_size is required" if world_size.nil?
            is_master = rank.zero?
            store = TCPStore.new(master_addr, master_port, world_size, is_master, wait_for_workers: wait_for_workers, timeout: timeout)
          else
            raise ArgumentError, "store is required when using init_method=#{init_method.inspect}"
          end
        end

        raise ArgumentError, "rank is required" if rank.nil?
        raise ArgumentError, "world_size is required" if world_size.nil?

        device_id ||= default_device_id_for_backend(backend, rank, world_size)

        timeout_ms = (timeout * 1000).to_i
        bound_device_id = device_id.nil? ? -1 : Integer(device_id)
        if backend == "nccl" && bound_device_id >= 0 && Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device)
          device_count = Torch::CUDA.device_count if Torch::CUDA.respond_to?(:device_count)
          # Only attempt to switch devices when the requested id exists to avoid
          # raising on hosts with fewer GPUs than the provided local rank.
          Torch::CUDA.set_device(bound_device_id) if device_count.nil? || bound_device_id < device_count
        end
        pg = _init_process_group(backend, store, rank, world_size, timeout_ms, bound_device_id)
        warmup_process_group(pg, backend)
      end

      def destroy_process_group
        _destroy_process_group
      end

      def default_process_group
        _default_process_group
      end

      def get_world_size(group = nil)
        ensure_process_group!(group)
        _get_world_size(group)
      end

      def get_rank(group = nil)
        ensure_process_group!(group)
        _get_rank(group)
      end

      def barrier(group: nil)
        ensure_process_group!(group)
        _barrier(group)
      end

      def all_reduce(tensor, op: ReduceOp::SUM, group: nil)
        ensure_process_group!(group)
        _all_reduce(tensor, op, group)
      end

      def broadcast(tensor, src:, group: nil)
        ensure_process_group!(group)
        _broadcast(tensor, src, group)
      end

      def register_ddp_hook(tensor, process_group, world_size)
        ensure_process_group!(process_group)
        _register_ddp_hook(tensor, process_group, Integer(world_size))
      rescue NoMethodError
        # Fallback for environments built without the native helper; this may
        # still call back into Ruby from autograd threads.
        tensor.register_hook do |grad|
          all_reduce(grad, group: process_group)
          grad.div!(world_size.to_f)
        end
      end

      def get_default_backend_for_device(device)
        backend = DEFAULT_DEVICE_BACKENDS[device_type_from(device)]
        raise ArgumentError, "Default backend not registered for device: #{device.inspect}" unless backend
        backend
      end

      def fork_world(world_size, host: "127.0.0.1", start_method: :fork, &block)
        raise ArgumentError, "world_size must be positive" unless world_size.to_i.positive?
        raise ArgumentError, "block required" unless block

        start_method = normalize_start_method(start_method)
        return run_spawn_worker(&block) if start_method == :spawn && spawn_worker?

        fork_spawn_world(world_size, host: host, start_method: start_method, &block)
      end

      def fork_spawn_world(world_size, host:, start_method:, &block)
        port = free_port(host: host)
        readers = []
        pids = []
        pgid = nil
        completed = false

        begin
          world_size.times do |rank|
            reader, writer = IO.pipe
            begin
              case start_method
              when :fork
                pids << fork_worker(reader, writer, rank, port, world_size, &block)
              when :spawn
                pid, pgid = spawn_worker(reader, writer, rank, port, host: host, world_size: world_size, pgid: pgid)
                pids << pid
              else
                raise ArgumentError, "Unsupported start_method: #{start_method.inspect}"
              end
              readers << reader
              writer.close unless writer.closed?
            rescue Exception
              reader.close unless reader.closed?
              writer.close unless writer.closed?
              raise
            end
          end

          read_failure = Object.new

          outputs = readers.map do |reader|
            begin
              Marshal.load(reader)
            rescue EOFError
              read_failure
            ensure
              reader.close unless reader.closed?
            end
          end

          statuses = pids.each_with_index.map do |pid, idx|
            _pid, status = Process.wait2(pid)
            [idx, pid, status]
          end

          statuses.each do |idx, pid, status|
            output = outputs[idx]
            if output.equal?(read_failure)
              raise Torch::Error, "Child #{pid} closed pipe before sending result (status #{status.exitstatus})"
            end
            if !status.success? || (output.is_a?(Hash) && output[:error])
              message = if output.is_a?(Hash) && output[:error]
                "Child #{pid} failed: #{output[:error]}\n#{Array(output[:backtrace]).join("\n")}"
              else
                "Child #{pid} exited with status #{status.exitstatus}"
              end
              raise Torch::Error, message
            end
          end

          completed = true
          outputs
        ensure
          # Ensure child workers are cleaned up if an interrupt or error occurs.
          terminate_processes(pids, pgid: pgid) unless completed
        end
      end

      def free_port(host: "127.0.0.1")
        server = TCPServer.new(host, 0)
        port = server.addr[1]
        server.close
        port
      end

      private

      def ensure_process_group!(group)
        return if group || initialized?

        raise Torch::Error, "Default process group is not initialized"
      end

      def default_device_id_for_backend(backend, rank, world_size)
        return unless backend == "nccl"

        default_local_rank(rank, world_size)
      end

      def warmup_process_group(pg, backend)
        return pg unless backend == "nccl"

        # Only warm up when a native process group was returned.
        # Test helpers may stub out `_init_process_group` and return arbitrary
        # Ruby objects, which cannot be passed to the C++ bindings.
        return pg unless pg.nil? || (defined?(Torch::Distributed::ProcessGroup) && pg.is_a?(Torch::Distributed::ProcessGroup))

        # Prime NCCL communicators so the first user-visible collective is fast
        _barrier(pg)
        pg
      rescue
        _destroy_process_group
        raise
      end

      def default_local_rank(rank, world_size)
        local_rank = env_integer("LOCAL_RANK")
        return local_rank unless local_rank.nil?

        local_world_size = env_integer("LOCAL_WORLD_SIZE") || world_size
        return unless local_world_size && rank

        rank % local_world_size if local_world_size.positive?
      end

      def env_integer(key)
        Integer(ENV[key]) if ENV.key?(key)
      rescue ArgumentError
        nil
      end

      def default_backend_for(device_id)
        get_default_backend_for_device(device_id)
      end

      def device_type_from(device)
        case device
        when Torch::Device
          device.type
        when NilClass
          accelerator_type || "cpu"
        when String
          Torch.device(device).type
        when Integer
          return accelerator_type || "cpu" if device.negative?
          if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:device_count)
            max = Torch::CUDA.device_count
            return accelerator_type || "cpu" if max <= 0 || device >= max
            return Torch.device("cuda:#{device}").type
          end
          accelerator_type || "cpu"
        else
          return device.type if device.respond_to?(:type)
          Torch.device(device).type
        end
      rescue => e
        raise ArgumentError, "Invalid device #{device.inspect}: #{e.message}"
      end

      def accelerator_type
        acc = Torch::Accelerator.current_accelerator
        acc.type if acc && acc.respond_to?(:type)
      rescue
        nil
      end

      def normalize_start_method(start_method)
        method = start_method&.to_sym
        return method if [:fork, :spawn].include?(method)

        raise ArgumentError, "start_method must be :fork or :spawn (got #{start_method.inspect})"
      end

      def spawn_worker?
        ENV[SPAWN_ENV_KEY] == "1"
      end

      def run_spawn_worker(&block)
        rank = Integer(ENV.fetch(SPAWN_RANK_ENV_KEY))
        port = Integer(ENV.fetch(SPAWN_PORT_ENV_KEY))
        pipe_fd = Integer(ENV.fetch(SPAWN_PIPE_ENV_KEY))

        writer = IO.new(pipe_fd, "wb")
        writer.binmode
        writer.sync = true

        result = block.call(rank, port)
        Marshal.dump(result, writer)
        writer.flush
        writer.close
        Process.exit!(0)
      rescue Exception => e
        begin
          if defined?(writer) && writer && !writer.closed?
            Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer)
            writer.flush
            writer.close
          end
        rescue StandardError
          # best-effort error reporting back to parent
        ensure
          Process.exit!(1)
        end
      end

      def fork_worker(reader, writer, rank, port, world_size, &block)
        fork do
          reader.close
          begin
            ENV["LOCAL_RANK"] = rank.to_s
            ENV["LOCAL_WORLD_SIZE"] = world_size.to_s
            ENV["RANK"] = rank.to_s
            ENV["WORLD_SIZE"] = world_size.to_s
            writer.binmode
            writer.sync = true
            result = block.call(rank, port)
            Marshal.dump(result, writer)
            writer.flush
            writer.close
            Process.exit!(0)
          rescue => e
            Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer)
            writer.flush
            writer.close
            Process.exit!(1)
          ensure
            writer.close unless writer.closed?
          end
        end
      end

      def spawn_worker(reader, writer, rank, port, host:, world_size:, pgid: nil)
        writer.binmode
        writer.close_on_exec = false

        script = ENV[SPAWN_SCRIPT_ENV_KEY] || $0
        env = {
          SPAWN_ENV_KEY => "1",
          SPAWN_RANK_ENV_KEY => rank.to_s,
          SPAWN_WORLD_SIZE_ENV_KEY => world_size.to_s,
          SPAWN_PORT_ENV_KEY => port.to_s,
          SPAWN_PIPE_ENV_KEY => writer.fileno.to_s,
          "LOCAL_RANK" => rank.to_s,
          "LOCAL_WORLD_SIZE" => world_size.to_s,
          "MASTER_ADDR" => host,
          "MASTER_PORT" => port.to_s,
          "RANK" => rank.to_s,
          "WORLD_SIZE" => world_size.to_s
        }
        env["RUBYLIB"] = [ENV["RUBYLIB"], $LOAD_PATH.join(File::PATH_SEPARATOR)].compact.reject(&:empty?).join(File::PATH_SEPARATOR)

        spawn_opts = {close_others: false}
        spawn_opts[:pgroup] = pgid ? pgid : true

        pid = Process.spawn(env, RbConfig.ruby, script, *spawn_argv, spawn_opts)
        pgid ||= pid
        [pid, pgid]
      rescue SystemCallError => e
        raise Torch::Error, "failed to spawn worker #{rank}: #{e.message}"
      end

      def spawn_argv
        test_filter = ENV[SPAWN_TEST_ENV_KEY]
        return SPAWN_ARGV unless test_filter
        return SPAWN_ARGV if SPAWN_ARGV.include?("-n")

        # Restrict child to the specific test that triggered the spawn
        SPAWN_ARGV + ["-n", test_filter]
      end

      def terminate_processes(pids, pgid: nil)
        return if pids.empty? && !pgid

        send_process_group_signal(pgid, "TERM")
        pids.each { |pid| safe_kill(pid, "TERM") }
        sleep(0.2)
        pids.each do |pid|
          next unless process_alive?(pid)

          safe_kill(pid, "KILL")
        end
        pids.each do |pid|
          begin
            Process.wait(pid)
          rescue Errno::ECHILD
          end
        end
      end

      def send_process_group_signal(pgid, sig)
        return unless pgid

        Process.kill(sig, -pgid)
      rescue Errno::ESRCH
      end

      def safe_kill(pid, sig)
        Process.kill(sig, pid)
      rescue Errno::ESRCH
      end

      def process_alive?(pid)
        Process.kill(0, pid)
        true
      rescue Errno::ESRCH
        false
      end
    end

    class TCPStore
      def self.new(host, port, world_size, is_master, wait_for_workers: true, timeout: DEFAULT_TIMEOUT)
        Torch::Distributed._create_tcp_store(host, port, world_size, is_master, (timeout * 1000).to_i, wait_for_workers)
      end
    end

    class FileStore
      def self.new(path, world_size)
        Torch::Distributed._create_file_store(path, world_size)
      end
    end

    if respond_to?(:_create_hash_store)
      class HashStore
        def self.new
          Torch::Distributed._create_hash_store
        end
      end
    end
  end
end

require "torch/nn/parallel/distributed_data_parallel"

at_exit do
  begin
    Torch::Distributed.destroy_process_group if Torch::Distributed.available? && Torch::Distributed.initialized?
  rescue Exception
    # best-effort cleanup to avoid leaked process groups
  end
end
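
As orientation for readers of this hunk: `fork_world` yields each worker its rank and a free TCP port, and `init_process_group` with the default `env://` method reads MASTER_ADDR/MASTER_PORT from the environment when no store is passed. Below is a minimal usage sketch, not taken from the gem's examples; the two-process world size, the explicit "gloo" backend, and wiring the yielded port into MASTER_PORT are assumptions made for illustration.

    # Hypothetical two-process all_reduce on CPU (gloo); values are illustrative.
    sums = Torch::Distributed.fork_world(2) do |rank, port|
      ENV["MASTER_PORT"] = port.to_s            # assumed wiring of the yielded port
      Torch::Distributed.init_process_group("gloo", rank: rank, world_size: 2)
      t = Torch.tensor([rank + 1.0])
      Torch::Distributed.all_reduce(t)          # defaults to ReduceOp::SUM on the default group
      Torch::Distributed.destroy_process_group
      t.item                                    # marshaled back to the parent process
    end
    # sums would be [3.0, 3.0] in this sketch

Each worker's block return value is marshaled over a pipe and collected into the array returned by `fork_world` in rank order.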
data/lib/torch/nn/parallel/distributed_data_parallel.rb
@@ -0,0 +1,115 @@
module Torch
  module NN
    module Parallel
      class DistributedDataParallel < Module
        attr_reader :module, :process_group

        def initialize(mod, device_ids: nil, process_group: nil, broadcast_buffers: true)
          super()
          raise Torch::Error, "torch.distributed is not available" unless Torch::Distributed.available?

          @module = mod
          @broadcast_buffers = broadcast_buffers
          @process_group = process_group || Torch::Distributed.default_process_group
          raise Torch::Error, "Process group must be initialized before using DistributedDataParallel" unless @process_group

          @world_size = Torch::Distributed.get_world_size(@process_group)
          @rank = Torch::Distributed.get_rank(@process_group)
          @device = normalize_device(Array(device_ids).compact.first)
          move_to_device(@device) if @device

          synchronize_parameters
          @hook_handles = register_parameter_hooks
        end

        def forward(*inputs, **kwargs)
          outputs = @module.call(*move_inputs(inputs), **move_kwargs(kwargs))
          broadcast_buffers_if_needed
          outputs
        end

        alias_method :call, :forward

        def train(mode = true)
          @module.train(mode)
          broadcast_buffers_if_needed
          self
        end

        private

        def normalize_device(device)
          return nil unless device
          return device if device.is_a?(Torch::Device)

          if device.is_a?(Integer)
            if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
              return Torch.device("cuda:#{device}")
            end
          end

          Torch.device(device)
        end

        def move_to_device(device)
          return unless device

          @module.to(device)
        end

        def move_inputs(inputs)
          return inputs unless @device

          inputs.map { |value| move_value(value, @device) }
        end

        def move_kwargs(kwargs)
          return kwargs unless @device

          kwargs.transform_values { |value| move_value(value, @device) }
        end

        def move_value(value, device)
          case value
          when Torch::Tensor
            value.to(device)
          when Array
            value.map { |v| move_value(v, device) }
          when Hash
            value.transform_values { |v| move_value(v, device) }
          else
            value
          end
        end

        def synchronize_parameters
          Torch::Distributed.barrier(group: @process_group)
          Torch.no_grad do
            @module.parameters.each do |param|
              Torch::Distributed.broadcast(param, src: 0, group: @process_group)
            end
            broadcast_buffers_if_needed
          end
        end

        def broadcast_buffers_if_needed
          return unless @broadcast_buffers

          Torch.no_grad do
            @module.buffers.each do |buffer|
              Torch::Distributed.broadcast(buffer, src: 0, group: @process_group)
            end
          end
        end

        def register_parameter_hooks
          @module.parameters.filter_map do |param|
            next unless param.requires_grad?

            Torch::Distributed.register_ddp_hook(param, @process_group, @world_size)
          end
        end
      end
    end
  end
end