torch-ddp 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,466 @@
+require "torch"
+require "torch/ddp_ext"
+require "socket"
+require "rbconfig"
+
+module Torch
+  module Distributed
+    DEFAULT_DEVICE_BACKENDS = {
+      "cpu" => "gloo",
+      "cuda" => "nccl",
+      "xpu" => "xccl",
+      "mps" => "gloo"
+    }.freeze
+
+    DEFAULT_TIMEOUT = 30 * 60 unless const_defined?(:DEFAULT_TIMEOUT)
+
+    unless const_defined?(:ReduceOp)
+      module ReduceOp
+        SUM = 0
+      end
+    end
+
+    SPAWN_ENV_KEY = "TORCH_DISTRIBUTED_SPAWNED".freeze
+    SPAWN_RANK_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_RANK".freeze
+    SPAWN_WORLD_SIZE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_WORLD_SIZE".freeze
+    SPAWN_PORT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PORT".freeze
+    SPAWN_PIPE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PIPE".freeze
+    SPAWN_SCRIPT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_SCRIPT".freeze
+    SPAWN_TEST_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_TEST".freeze
+    SPAWN_ARGV = ARGV.dup.freeze
+
+    class << self
+      def initialized?
+        _initialized?
+      end
+
+      def init_process_group(backend = nil, init_method: "env://", store: nil, rank: nil, world_size: nil, timeout: DEFAULT_TIMEOUT, wait_for_workers: true, device_id: nil)
+        raise Torch::Error, "torch.distributed is not available" unless available?
+
+        backend ||= default_backend_for(device_id)
+
+        if store.nil?
+          case init_method
+          when "env://"
+            rank = Integer(ENV.fetch("RANK")) if rank.nil?
+            world_size = Integer(ENV.fetch("WORLD_SIZE")) if world_size.nil?
+            master_addr = ENV.fetch("MASTER_ADDR", "127.0.0.1")
+            master_port = Integer(ENV.fetch("MASTER_PORT", "29500"))
+            raise ArgumentError, "rank is required" if rank.nil?
+            raise ArgumentError, "world_size is required" if world_size.nil?
+            is_master = rank.zero?
+            store = TCPStore.new(master_addr, master_port, world_size, is_master, wait_for_workers: wait_for_workers, timeout: timeout)
+          else
+            raise ArgumentError, "store is required when using init_method=#{init_method.inspect}"
+          end
+        end
+
+        raise ArgumentError, "rank is required" if rank.nil?
+        raise ArgumentError, "world_size is required" if world_size.nil?
+
+        device_id ||= default_device_id_for_backend(backend, rank, world_size)
+
+        timeout_ms = (timeout * 1000).to_i
+        bound_device_id = device_id.nil? ? -1 : Integer(device_id)
+        if backend == "nccl" && bound_device_id >= 0 && Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device)
+          device_count = Torch::CUDA.device_count if Torch::CUDA.respond_to?(:device_count)
+          # Only attempt to switch devices when the requested id exists to avoid
+          # raising on hosts with fewer GPUs than the provided local rank.
+          Torch::CUDA.set_device(bound_device_id) if device_count.nil? || bound_device_id < device_count
+        end
+        pg = _init_process_group(backend, store, rank, world_size, timeout_ms, bound_device_id)
+        warmup_process_group(pg, backend)
+      end
+
+      def destroy_process_group
+        _destroy_process_group
+      end
+
+      def default_process_group
+        _default_process_group
+      end
+
+      def get_world_size(group = nil)
+        ensure_process_group!(group)
+        _get_world_size(group)
+      end
+
+      def get_rank(group = nil)
+        ensure_process_group!(group)
+        _get_rank(group)
+      end
+
+      def barrier(group: nil)
+        ensure_process_group!(group)
+        _barrier(group)
+      end
+
+      def all_reduce(tensor, op: ReduceOp::SUM, group: nil)
+        ensure_process_group!(group)
+        _all_reduce(tensor, op, group)
+      end
+
+      def broadcast(tensor, src:, group: nil)
+        ensure_process_group!(group)
+        _broadcast(tensor, src, group)
+      end
+
+      def register_ddp_hook(tensor, process_group, world_size)
+        ensure_process_group!(process_group)
+        _register_ddp_hook(tensor, process_group, Integer(world_size))
+      rescue NoMethodError
+        # Fallback for environments built without the native helper; this may
+        # still call back into Ruby from autograd threads.
+        tensor.register_hook do |grad|
+          all_reduce(grad, group: process_group)
+          grad.div!(world_size.to_f)
+        end
+      end
+
+      def get_default_backend_for_device(device)
+        backend = DEFAULT_DEVICE_BACKENDS[device_type_from(device)]
+        raise ArgumentError, "Default backend not registered for device: #{device.inspect}" unless backend
+        backend
+      end
+
+      def fork_world(world_size, host: "127.0.0.1", start_method: :fork, &block)
+        raise ArgumentError, "world_size must be positive" unless world_size.to_i.positive?
+        raise ArgumentError, "block required" unless block
+
+        start_method = normalize_start_method(start_method)
+        return run_spawn_worker(&block) if start_method == :spawn && spawn_worker?
+
+        fork_spawn_world(world_size, host: host, start_method: start_method, &block)
+      end
+
+      def fork_spawn_world(world_size, host:, start_method:, &block)
+        port = free_port(host: host)
+        readers = []
+        pids = []
+        pgid = nil
+        completed = false
+
+        begin
+          world_size.times do |rank|
+            reader, writer = IO.pipe
+            begin
+              case start_method
+              when :fork
+                pids << fork_worker(reader, writer, rank, port, world_size, &block)
+              when :spawn
+                pid, pgid = spawn_worker(reader, writer, rank, port, host: host, world_size: world_size, pgid: pgid)
+                pids << pid
+              else
+                raise ArgumentError, "Unsupported start_method: #{start_method.inspect}"
+              end
+              readers << reader
+              writer.close unless writer.closed?
+            rescue Exception
+              reader.close unless reader.closed?
+              writer.close unless writer.closed?
+              raise
+            end
+          end
+
+          read_failure = Object.new
+
+          outputs = readers.map do |reader|
+            begin
+              Marshal.load(reader)
+            rescue EOFError
+              read_failure
+            ensure
+              reader.close unless reader.closed?
+            end
+          end
+
+          statuses = pids.each_with_index.map do |pid, idx|
+            _pid, status = Process.wait2(pid)
+            [idx, pid, status]
+          end
+
+          statuses.each do |idx, pid, status|
+            output = outputs[idx]
+            if output.equal?(read_failure)
+              raise Torch::Error, "Child #{pid} closed pipe before sending result (status #{status.exitstatus})"
+            end
+            if !status.success? || (output.is_a?(Hash) && output[:error])
+              message = if output.is_a?(Hash) && output[:error]
+                "Child #{pid} failed: #{output[:error]}\n#{Array(output[:backtrace]).join("\n")}"
+              else
+                "Child #{pid} exited with status #{status.exitstatus}"
+              end
+              raise Torch::Error, message
+            end
+          end
+
+          completed = true
+          outputs
+        ensure
+          # Ensure child workers are cleaned up if an interrupt or error occurs.
+          terminate_processes(pids, pgid: pgid) unless completed
+        end
+      end
+
+      def free_port(host: "127.0.0.1")
+        server = TCPServer.new(host, 0)
+        port = server.addr[1]
+        server.close
+        port
+      end
+
+      private
+
+      def ensure_process_group!(group)
+        return if group || initialized?
+
+        raise Torch::Error, "Default process group is not initialized"
+      end
+
+      def default_device_id_for_backend(backend, rank, world_size)
+        return unless backend == "nccl"
+
+        default_local_rank(rank, world_size)
+      end
+
+      def warmup_process_group(pg, backend)
+        return pg unless backend == "nccl"
+
+        # Only warm up when a native process group was returned.
+        # Test helpers may stub out `_init_process_group` and return arbitrary
+        # Ruby objects, which cannot be passed to the C++ bindings.
+        return pg unless pg.nil? || (defined?(Torch::Distributed::ProcessGroup) && pg.is_a?(Torch::Distributed::ProcessGroup))
+
+        # Prime NCCL communicators so the first user-visible collective is fast
+        _barrier(pg)
+        pg
+      rescue
+        _destroy_process_group
+        raise
+      end
+
+      def default_local_rank(rank, world_size)
+        local_rank = env_integer("LOCAL_RANK")
+        return local_rank unless local_rank.nil?
+
+        local_world_size = env_integer("LOCAL_WORLD_SIZE") || world_size
+        return unless local_world_size && rank
+
+        rank % local_world_size if local_world_size.positive?
+      end
+
+      def env_integer(key)
+        Integer(ENV[key]) if ENV.key?(key)
+      rescue ArgumentError
+        nil
+      end
+
+      def default_backend_for(device_id)
+        get_default_backend_for_device(device_id)
+      end
+
+      def device_type_from(device)
+        case device
+        when Torch::Device
+          device.type
+        when NilClass
+          accelerator_type || "cpu"
+        when String
+          Torch.device(device).type
+        when Integer
+          return accelerator_type || "cpu" if device.negative?
+          if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:device_count)
+            max = Torch::CUDA.device_count
+            return accelerator_type || "cpu" if max <= 0 || device >= max
+            return Torch.device("cuda:#{device}").type
+          end
+          accelerator_type || "cpu"
+        else
+          return device.type if device.respond_to?(:type)
+          Torch.device(device).type
+        end
+      rescue => e
+        raise ArgumentError, "Invalid device #{device.inspect}: #{e.message}"
+      end
+
+      def accelerator_type
+        acc = Torch::Accelerator.current_accelerator
+        acc.type if acc && acc.respond_to?(:type)
+      rescue
+        nil
+      end
+
+      def normalize_start_method(start_method)
+        method = start_method&.to_sym
+        return method if [:fork, :spawn].include?(method)
+
+        raise ArgumentError, "start_method must be :fork or :spawn (got #{start_method.inspect})"
+      end
+
+      def spawn_worker?
+        ENV[SPAWN_ENV_KEY] == "1"
+      end
+
+      def run_spawn_worker(&block)
+        rank = Integer(ENV.fetch(SPAWN_RANK_ENV_KEY))
+        port = Integer(ENV.fetch(SPAWN_PORT_ENV_KEY))
+        pipe_fd = Integer(ENV.fetch(SPAWN_PIPE_ENV_KEY))
+
+        writer = IO.new(pipe_fd, "wb")
+        writer.binmode
+        writer.sync = true
+
+        result = block.call(rank, port)
+        Marshal.dump(result, writer)
+        writer.flush
+        writer.close
+        Process.exit!(0)
+      rescue Exception => e
+        begin
+          if defined?(writer) && writer && !writer.closed?
+            Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer)
+            writer.flush
+            writer.close
+          end
+        rescue StandardError
+          # best-effort error reporting back to parent
+        ensure
+          Process.exit!(1)
+        end
+      end
+
+      def fork_worker(reader, writer, rank, port, world_size, &block)
+        fork do
+          reader.close
+          begin
+            ENV["LOCAL_RANK"] = rank.to_s
+            ENV["LOCAL_WORLD_SIZE"] = world_size.to_s
+            ENV["RANK"] = rank.to_s
+            ENV["WORLD_SIZE"] = world_size.to_s
+            writer.binmode
+            writer.sync = true
+            result = block.call(rank, port)
+            Marshal.dump(result, writer)
+            writer.flush
+            writer.close
+            Process.exit!(0)
+          rescue => e
+            Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer)
+            writer.flush
+            writer.close
+            Process.exit!(1)
+          ensure
+            writer.close unless writer.closed?
+          end
+        end
+      end
+
+      def spawn_worker(reader, writer, rank, port, host:, world_size:, pgid: nil)
+        writer.binmode
+        writer.close_on_exec = false
+
+        script = ENV[SPAWN_SCRIPT_ENV_KEY] || $0
+        env = {
+          SPAWN_ENV_KEY => "1",
+          SPAWN_RANK_ENV_KEY => rank.to_s,
+          SPAWN_WORLD_SIZE_ENV_KEY => world_size.to_s,
+          SPAWN_PORT_ENV_KEY => port.to_s,
+          SPAWN_PIPE_ENV_KEY => writer.fileno.to_s,
+          "LOCAL_RANK" => rank.to_s,
+          "LOCAL_WORLD_SIZE" => world_size.to_s,
+          "MASTER_ADDR" => host,
+          "MASTER_PORT" => port.to_s,
+          "RANK" => rank.to_s,
+          "WORLD_SIZE" => world_size.to_s
+        }
+        env["RUBYLIB"] = [ENV["RUBYLIB"], $LOAD_PATH.join(File::PATH_SEPARATOR)].compact.reject(&:empty?).join(File::PATH_SEPARATOR)
+
+        spawn_opts = {close_others: false}
+        spawn_opts[:pgroup] = pgid ? pgid : true
+
+        pid = Process.spawn(env, RbConfig.ruby, script, *spawn_argv, spawn_opts)
+        pgid ||= pid
+        [pid, pgid]
+      rescue SystemCallError => e
+        raise Torch::Error, "failed to spawn worker #{rank}: #{e.message}"
+      end
+
+      def spawn_argv
+        test_filter = ENV[SPAWN_TEST_ENV_KEY]
+        return SPAWN_ARGV unless test_filter
+        return SPAWN_ARGV if SPAWN_ARGV.include?("-n")
+
+        # Restrict child to the specific test that triggered the spawn
+        SPAWN_ARGV + ["-n", test_filter]
+      end
+
+      def terminate_processes(pids, pgid: nil)
+        return if pids.empty? && !pgid
+
+        send_process_group_signal(pgid, "TERM")
+        pids.each { |pid| safe_kill(pid, "TERM") }
+        sleep(0.2)
+        pids.each do |pid|
+          next unless process_alive?(pid)
+
+          safe_kill(pid, "KILL")
+        end
+        pids.each do |pid|
+          begin
+            Process.wait(pid)
+          rescue Errno::ECHILD
+          end
+        end
+      end
+
+      def send_process_group_signal(pgid, sig)
+        return unless pgid
+
+        Process.kill(sig, -pgid)
+      rescue Errno::ESRCH
+      end
+
+      def safe_kill(pid, sig)
+        Process.kill(sig, pid)
+      rescue Errno::ESRCH
+      end
+
+      def process_alive?(pid)
+        Process.kill(0, pid)
+        true
+      rescue Errno::ESRCH
+        false
+      end
+    end
+
+    class TCPStore
+      def self.new(host, port, world_size, is_master, wait_for_workers: true, timeout: DEFAULT_TIMEOUT)
+        Torch::Distributed._create_tcp_store(host, port, world_size, is_master, (timeout * 1000).to_i, wait_for_workers)
+      end
+    end
+
+    class FileStore
+      def self.new(path, world_size)
+        Torch::Distributed._create_file_store(path, world_size)
+      end
+    end
+
+    if respond_to?(:_create_hash_store)
+      class HashStore
+        def self.new
+          Torch::Distributed._create_hash_store
+        end
+      end
+    end
+  end
+end
+
+require "torch/nn/parallel/distributed_data_parallel"
+
+at_exit do
+  begin
+    Torch::Distributed.destroy_process_group if Torch::Distributed.available? && Torch::Distributed.initialized?
+  rescue Exception
+    # best-effort cleanup to avoid leaked process groups
+  end
+end
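
For reference, the library above is typically driven through Torch::Distributed.fork_world, which starts one worker process per rank and hands each block its rank and a free TCP port. The following minimal sketch is not part of the package contents in this diff; it assumes torch-rb and the gem's native ddp_ext extension are installed and runs a two-process all-reduce on CPU with the gloo backend.

world_size = 2

results = Torch::Distributed.fork_world(world_size) do |rank, port|
  # Every worker builds a TCPStore on the port chosen by the parent;
  # rank 0 acts as the master that the other ranks connect to.
  store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?)
  Torch::Distributed.init_process_group("gloo", store: store, rank: rank, world_size: world_size)

  tensor = Torch.ones(3) * (rank + 1)
  Torch::Distributed.all_reduce(tensor)   # defaults to ReduceOp::SUM
  Torch::Distributed.destroy_process_group

  tensor.to_a                             # marshalled back to the parent process
end
# results should be [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]: every rank sees the summed tensor.
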
@@ -0,0 +1,115 @@
+module Torch
+  module NN
+    module Parallel
+      class DistributedDataParallel < Module
+        attr_reader :module, :process_group
+
+        def initialize(mod, device_ids: nil, process_group: nil, broadcast_buffers: true)
+          super()
+          raise Torch::Error, "torch.distributed is not available" unless Torch::Distributed.available?
+
+          @module = mod
+          @broadcast_buffers = broadcast_buffers
+          @process_group = process_group || Torch::Distributed.default_process_group
+          raise Torch::Error, "Process group must be initialized before using DistributedDataParallel" unless @process_group
+
+          @world_size = Torch::Distributed.get_world_size(@process_group)
+          @rank = Torch::Distributed.get_rank(@process_group)
+          @device = normalize_device(Array(device_ids).compact.first)
+          move_to_device(@device) if @device
+
+          synchronize_parameters
+          @hook_handles = register_parameter_hooks
+        end
+
+        def forward(*inputs, **kwargs)
+          outputs = @module.call(*move_inputs(inputs), **move_kwargs(kwargs))
+          broadcast_buffers_if_needed
+          outputs
+        end
+
+        alias_method :call, :forward
+
+        def train(mode = true)
+          @module.train(mode)
+          broadcast_buffers_if_needed
+          self
+        end
+
+        private
+
+        def normalize_device(device)
+          return nil unless device
+          return device if device.is_a?(Torch::Device)
+
+          if device.is_a?(Integer)
+            if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
+              return Torch.device("cuda:#{device}")
+            end
+          end
+
+          Torch.device(device)
+        end
+
+        def move_to_device(device)
+          return unless device
+
+          @module.to(device)
+        end
+
+        def move_inputs(inputs)
+          return inputs unless @device
+
+          inputs.map { |value| move_value(value, @device) }
+        end
+
+        def move_kwargs(kwargs)
+          return kwargs unless @device
+
+          kwargs.transform_values { |value| move_value(value, @device) }
+        end
+
+        def move_value(value, device)
+          case value
+          when Torch::Tensor
+            value.to(device)
+          when Array
+            value.map { |v| move_value(v, device) }
+          when Hash
+            value.transform_values { |v| move_value(v, device) }
+          else
+            value
+          end
+        end
+
+        def synchronize_parameters
+          Torch::Distributed.barrier(group: @process_group)
+          Torch.no_grad do
+            @module.parameters.each do |param|
+              Torch::Distributed.broadcast(param, src: 0, group: @process_group)
+            end
+            broadcast_buffers_if_needed
+          end
+        end
+
+        def broadcast_buffers_if_needed
+          return unless @broadcast_buffers
+
+          Torch.no_grad do
+            @module.buffers.each do |buffer|
+              Torch::Distributed.broadcast(buffer, src: 0, group: @process_group)
+            end
+          end
+        end
+
+        def register_parameter_hooks
+          @module.parameters.filter_map do |param|
+            next unless param.requires_grad?
+
+            Torch::Distributed.register_ddp_hook(param, @process_group, @world_size)
+          end
+        end
+      end
+    end
+  end
+end
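
The DistributedDataParallel wrapper above broadcasts parameters and buffers from rank 0 at construction time and registers per-parameter hooks that all-reduce and average gradients during backward. A minimal end-to-end sketch, again not part of the package and assuming torch-rb's standard Linear, MSELoss, and SGD classes plus a toy 4-in/2-out model, could look like this:

world_size = 2

Torch::Distributed.fork_world(world_size) do |rank, port|
  ENV["MASTER_PORT"] = port.to_s    # reuse the free port picked by fork_world
  Torch::Distributed.init_process_group("gloo", rank: rank, world_size: world_size)

  model = Torch::NN::Linear.new(4, 2)
  ddp = Torch::NN::Parallel::DistributedDataParallel.new(model)
  optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
  loss_fn = Torch::NN::MSELoss.new

  x = Torch.randn(8, 4)   # in practice each rank loads its own shard of the data
  y = Torch.randn(8, 2)

  loss = loss_fn.call(ddp.call(x), y)
  optimizer.zero_grad
  loss.backward           # DDP hooks all-reduce and average the gradients here
  optimizer.step

  Torch::Distributed.destroy_process_group
  loss.item
end

Because the gradient hooks divide by the world size, each rank applies the same averaged update, so the replicas stay in sync without any explicit parameter exchange after the initial broadcast.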