torch-ddp 0.1.0

@@ -0,0 +1,531 @@
+ # frozen_string_literal: true
+
+ require "optparse"
+ require "socket"
+ require "etc"
+ require "securerandom"
+ require "rbconfig"
+
+ require "torch"
+ require "torch/distributed"
+
+ module Torch
+   module TorchRun
+     SIGNALS = %w[INT TERM QUIT].freeze
+
+     class Error < StandardError; end
+
+     class Parser
+       attr_reader :parser
+
+       def initialize
+         @parser = OptionParser.new
+       end
+
+       def parse(argv)
+         options = default_options
+
+         parser.banner = "Usage: torchrun [options] TRAINING_SCRIPT [script args]"
+         parser.separator ""
+         parser.separator "Launch parameters:"
+
+         parser.on("--nnodes MIN[:MAX]", String, "Number of nodes or range (default: #{options[:nnodes]})") do |value|
+           options[:nnodes] = value
+         end
+
+         parser.on("--nproc-per-node VALUE", String, "Processes per node (int, gpu, cpu, auto). Default: #{options[:nproc_per_node]}") do |value|
+           options[:nproc_per_node] = value
+         end
+
+         parser.on("--node-rank VALUE", Integer, "Rank of the node for multi-node jobs. Default: #{options[:node_rank]}") do |value|
+           options[:node_rank] = value
+         end
+
+         parser.on("--rdzv-backend NAME", String, "Rendezvous backend (static or c10d). Default: #{options[:rdzv_backend]}") do |value|
+           options[:rdzv_backend] = value
+         end
+
+         parser.on("--rdzv-endpoint HOST[:PORT]", String, "Rendezvous endpoint. Default: use --master-addr/--master-port") do |value|
+           options[:rdzv_endpoint] = value
+         end
+
+         parser.on("--rdzv-id ID", String, "User defined job id. Default: #{options[:rdzv_id]}") do |value|
+           options[:rdzv_id] = value
+         end
+
+         parser.on("--rdzv-conf CONF", String, "Additional rendezvous config (k=v,k2=v2)") do |value|
+           options[:rdzv_conf] = parse_kv_pairs(value)
+         end
+
+         parser.on("--standalone", "Start a local rendezvous store on a free port") do
+           options[:standalone] = true
+         end
+
+         parser.on("--max-restarts VALUE", Integer, "Restarts before failing. Default: #{options[:max_restarts]}") do |value|
+           options[:max_restarts] = value
+         end
+
+         parser.on("--monitor-interval SECONDS", Float, "Delay between restart attempts. Default: #{options[:monitor_interval]}") do |value|
+           options[:monitor_interval] = value
+         end
+
+         parser.on("--role NAME", String, "Role for the worker group. Default: #{options[:role]}") do |value|
+           options[:role] = value
+         end
+
+         parser.on("--master-addr HOST", String, "Master address for static rendezvous. Default: #{options[:master_addr]}") do |value|
+           options[:master_addr] = value
+         end
+
+         parser.on("--master-port PORT", Integer, "Master port for static rendezvous. Default: #{options[:master_port]}") do |value|
+           options[:master_port] = value
+         end
+
+         parser.on("--pass-local-rank-arg", "Append --local-rank to the training script invocation") do
+           options[:pass_local_rank_arg] = true
+         end
+
+         parser.on("--no-ruby", "Execute the training script directly instead of `#{RbConfig.ruby}`") do
+           options[:no_ruby] = true
+         end
+
+         parser.on("-h", "--help", "Prints this help") do
+           puts parser
+           exit
+         end
+
+         rest = parser.parse!(argv)
+         raise OptionParser::MissingArgument, "training_script" if rest.empty?
+
+         training_script = rest.shift
+         [options, training_script, rest]
+       end
+
+       def to_s
+         parser.to_s
+       end
+
+       private
+
+       def default_options
+         {
+           nnodes: "1:1",
+           nproc_per_node: "1",
+           node_rank: 0,
+           rdzv_backend: "static",
+           rdzv_endpoint: "",
+           rdzv_id: "none",
+           rdzv_conf: {},
+           standalone: false,
+           max_restarts: 0,
+           monitor_interval: 1.0,
+           role: "default",
+           master_addr: "127.0.0.1",
+           master_port: 29_500,
+           pass_local_rank_arg: false,
+           no_ruby: false
+         }
+       end
+
+       def parse_kv_pairs(value)
+         return {} if value.nil? || value.strip.empty?
+
+         value.split(",").each_with_object({}) do |pair, acc|
+           key, val = pair.split("=", 2)
+           raise OptionParser::InvalidArgument, "Invalid rendezvous config entry: #{pair.inspect}" unless key && val
+
+           acc[key.strip] = val.strip
+         end
+       end
+     end
+
+     module_function
+
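+     # CLI entry point: parses argv, runs a Launcher, and exits with the
+     # worker group's status (2 for usage errors, 1 for launcher errors).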
+     def start(argv, out: $stdout, err: $stderr)
+       parser = Parser.new
+       options, script, script_args = parser.parse(argv)
+       status = Launcher.new(options, script, script_args, out: out, err: err).run
+       exit(status)
+     rescue OptionParser::ParseError => e
+       err.puts(e.message)
+       err.puts(parser)
+       exit(2)
+     rescue Error => e
+       err.puts("torchrun: #{e.message}")
+       exit(1)
+     end
+
+     class Launcher
+       def initialize(options, script, script_args, out: $stdout, err: $stderr)
+         @options = options
+         @script = script
+         @script_args = script_args
+         @out = out
+         @err = err
+
+         @local_world_size = determine_local_world_size(@options[:nproc_per_node])
+         @min_nodes, @max_nodes = parse_nnodes(@options[:nnodes])
+         @num_nodes = ensure_fixed_nnodes(@min_nodes, @max_nodes)
+         @node_rank = @options[:node_rank]
+         @max_restarts = [@options[:max_restarts], 0].max
+         @monitor_interval = [@options[:monitor_interval], 0.0].max
+         @role = @options[:role]
+         @pass_local_rank_arg = @options[:pass_local_rank_arg]
+         @no_ruby = @options[:no_ruby]
+         validate_node_rank!
+
+         setup_rendezvous!
+       end
+
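+       # Runs the worker group and relaunches it on failure, up to
+       # --max-restarts times; a trapped signal stops the restart loop.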
+       def run
+         restarts = 0
+
+         loop do
+           status = launch_worker_group(restarts)
+           return status if status.zero? || @signal_received
+           return status if restarts >= @max_restarts
+
+           restarts += 1
+           log("Worker group failed (exit #{status}). Restarting #{restarts}/#{@max_restarts} ...")
+           sleep(@monitor_interval) if @monitor_interval.positive?
+         end
+       end
+
+       private
+
+       def launch_worker_group(restart_count)
+         @signal_received = nil
+         @current_pids = spawn_workers(restart_count)
+         handler_state = setup_signal_handlers
+         status = monitor_workers(@current_pids.dup)
+         cleanup_workers(@current_pids)
+         restore_signal_handlers(handler_state)
+         return signal_exit_status if @signal_received
+
+         status
+       ensure
+         @worker_pgid = nil
+         @current_pids = []
+       end
+
+       def spawn_workers(restart_count)
+         base_env = base_environment(restart_count)
+         pgid = nil
+         workers = Array.new(@local_world_size) do |local_rank|
+           env = base_env.merge(rank_environment(local_rank))
+           pid, pgid = spawn_worker(env, local_rank, pgid)
+           pid
+         end
+         @worker_pgid = pgid
+         workers
+       end
+
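+       # The first spawn starts a new process group (pgroup: true); later
+       # workers join it (pgroup: pgid) so one signal can reach the whole group.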
+       def spawn_worker(env, local_rank, pgid)
+         args = command_arguments(local_rank)
+         spawn_opts = pgid ? { pgroup: pgid } : { pgroup: true }
+         pid = Process.spawn(env, *args, spawn_opts)
+         pgid ||= pid
+         [pid, pgid]
+       rescue SystemCallError => e
+         raise Error, "failed to launch worker #{local_rank}: #{e.message}"
+       end
+
+       def command_arguments(local_rank)
+         cmd = []
+         if @no_ruby
+           cmd << @script
+         else
+           cmd << RbConfig.ruby
+           cmd << @script
+         end
+         cmd.concat(@script_args)
+         cmd << "--local-rank=#{local_rank}" if @pass_local_rank_arg
+         cmd
+       end
+
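+       # Environment shared by all workers. The MASTER_*, WORLD_SIZE,
+       # LOCAL_WORLD_SIZE, GROUP_RANK, and TORCHELASTIC_* names match what
+       # Python's torchrun exports; the TORCHRUN_* entries are gem-specific.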
+       def base_environment(restart_count)
+         endpoint = "#{@master_addr}:#{@master_port}"
+         env = {
+           "MASTER_ADDR" => @master_addr,
+           "MASTER_PORT" => @master_port.to_s,
+           "WORLD_SIZE" => world_size.to_s,
+           "LOCAL_WORLD_SIZE" => @local_world_size.to_s,
+           "GROUP_RANK" => @node_rank.to_s,
+           "TORCHRUN_ROLE" => @role,
+           "TORCHRUN_NNODES" => @num_nodes.to_s,
+           "TORCHRUN_NPROC_PER_NODE" => @local_world_size.to_s,
+           "TORCHELASTIC_RUN_ID" => @rdzv_id,
+           "TORCHRUN_RDZV_BACKEND" => @rdzv_backend,
+           "TORCHRUN_RDZV_ENDPOINT" => endpoint,
+           "TORCHELASTIC_RESTART_COUNT" => restart_count.to_s,
+           "TORCHRUN_STANDALONE" => @standalone ? "1" : "0"
+         }
+         unless @rdzv_conf.empty?
+           env["TORCHRUN_RDZV_CONF"] = @rdzv_conf.map { |k, v| "#{k}=#{v}" }.join(",")
+         end
+         ENV.to_h.merge(env)
+       end
+
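+       # Per-worker variables: the global RANK is
+       # node_rank * local_world_size + local_rank.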
+       def rank_environment(local_rank)
+         rank = @node_rank * @local_world_size + local_rank
+         {
+           "LOCAL_RANK" => local_rank.to_s,
+           "RANK" => rank.to_s
+         }
+       end
+
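+       # Reaps children as they exit. The first failure triggers teardown of
+       # the remaining workers (fail-fast) and its status becomes the result.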
+       def monitor_workers(pids)
+         exit_code = 0
+         remaining = pids.dup
+         until remaining.empty?
+           pid, status = Process.wait2
+           next unless pid
+
+           remaining.delete(pid)
+           unless status.success?
+             exit_code = exit_status_from(status)
+             terminate_workers(remaining)
+             break
+           end
+         end
+         exit_code
+       rescue Errno::ECHILD
+         0
+       end
+
+       def terminate_workers(pids)
+         return if pids.empty?
+
+         send_process_group_signal("TERM")
+         pids.each { |pid| send_signal(pid, "TERM") }
+         sleep(0.2)
+         pids.each do |pid|
+           next unless process_alive?(pid)
+
+           send_signal(pid, "KILL")
+         end
+         pids.each do |pid|
+           begin
+             Process.wait(pid)
+           rescue Errno::ECHILD
+           end
+         end
+       end
+
+       def process_alive?(pid)
+         Process.kill(0, pid)
+         true
+       rescue Errno::ESRCH
+         false
+       end
+
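+       # Traps INT/TERM/QUIT, records the signal, and forwards it to the
+       # workers; previous handlers are returned so they can be restored.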
+       def setup_signal_handlers
+         SIGNALS.each_with_object({}) do |sig, acc|
+           next unless Signal.list.key?(sig)
+
+           previous = Signal.trap(sig) do
+             @signal_received = sig
+             forward_signal(sig)
+           end
+           acc[sig] = previous
+         end
+       end
+
+       def forward_signal(sig)
+         send_process_group_signal(sig)
+         (@current_pids || []).each { |pid| send_signal(pid, sig) }
+       end
+
+       def restore_signal_handlers(state)
+         return unless state
+
+         state.each do |sig, previous|
+           Signal.trap(sig, previous)
+         end
+       end
+
+       def send_signal(pid, sig)
+         Process.kill(sig, pid)
+       rescue Errno::ESRCH
+         nil
+       end
+
+       def send_process_group_signal(sig)
+         return unless @worker_pgid
+
+         Process.kill(sig, -@worker_pgid)
+       rescue Errno::ESRCH
+         nil
+       end
+
+       def cleanup_workers(pids)
+         pids.each do |pid|
+           next unless process_alive?(pid)
+
+           begin
+             Process.wait(pid)
+           rescue Errno::ECHILD
+           end
+         end
+       end
+
+       def signal_exit_status
+         return 0 unless @signal_received
+
+         128 + Signal.list.fetch(@signal_received, 0)
+       end
+
+       def exit_status_from(status)
+         if status.exited?
+           status.exitstatus
+         elsif status.signaled?
+           128 + status.termsig
+         else
+           1
+         end
+       end
+
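+       # Resolves --nproc-per-node: a positive integer, "gpu" (CUDA device
+       # count), "cpu" (Etc.nprocessors), or "auto" (GPUs if present,
+       # otherwise CPUs).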
+       def determine_local_world_size(value)
+         spec = value.to_s.strip.downcase
+         case spec
+         when "", "1"
+           1
+         when /\A\d+\z/
+           amount = spec.to_i
+           raise Error, "nproc-per-node must be >= 1" if amount < 1
+
+           amount
+         when "gpu"
+           gpu_count = cuda_device_count
+           raise Error, "CUDA is not available for --nproc-per-node=gpu" if gpu_count.zero?
+
+           gpu_count
+         when "auto"
+           gpu_count = cuda_device_count
+           return gpu_count if gpu_count.positive?
+
+           cpu_count
+         when "cpu"
+           cpu_count
+         else
+           raise Error, "Unsupported --nproc-per-node value: #{value}"
+         end
+       end
+
+       def cuda_device_count
+         return 0 unless defined?(Torch::CUDA)
+         return 0 unless Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
+         return 0 unless Torch::CUDA.respond_to?(:device_count)
+
+         Torch::CUDA.device_count
+       rescue StandardError
+         0
+       end
+
+       def cpu_count
+         Etc.respond_to?(:nprocessors) ? (Etc.nprocessors || 1) : 1
+       rescue StandardError
+         1
+       end
+
+       def parse_nnodes(value)
+         parts = value.split(":")
+         nums = parts.map do |part|
+           Integer(part, exception: false)
+         end
+         raise Error, "Invalid --nnodes value: #{value.inspect}" if nums.any?(&:nil?)
+
+         if nums.length == 1
+           [nums.first, nums.first]
+         elsif nums.length == 2
+           [nums.first, nums.last]
+         else
+           raise Error, "Invalid --nnodes value: #{value.inspect}"
+         end
+       end
+
+       def ensure_fixed_nnodes(min_nodes, max_nodes)
+         raise Error, "--nnodes minimum must be >= 1" if min_nodes < 1
+         raise Error, "--nnodes maximum must be >= minimum" if max_nodes < min_nodes
+         raise Error, "Elastic nnodes ranges are not supported yet (got #{min_nodes}:#{max_nodes})" if min_nodes != max_nodes
+
+         min_nodes
+       end
+
+       def world_size
+         @world_size ||= @num_nodes * @local_world_size
+       end
+
+       def validate_node_rank!
+         raise Error, "--node-rank must be >= 0" if @node_rank.negative?
+         raise Error, "--node-rank (#{@node_rank}) must be less than --nnodes (#{@num_nodes})" if @node_rank >= @num_nodes
+       end
+
+       def setup_rendezvous!
+         @rdzv_backend = normalize_backend(@options[:rdzv_backend])
+         @rdzv_conf = @options[:rdzv_conf] || {}
+         if @options[:standalone]
+           configure_standalone_rendezvous
+         else
+           configure_static_rendezvous
+         end
+       end
+
+       def normalize_backend(value)
+         backend = value.to_s.downcase
+         raise Error, "Unsupported rendezvous backend: #{value.inspect}" unless %w[static c10d].include?(backend)
+
+         backend
+       end
+
+       def configure_standalone_rendezvous
+         @standalone = true
+         @rdzv_backend = "c10d"
+         @rdzv_id = SecureRandom.uuid
+         @master_addr = "127.0.0.1"
+         @master_port = find_free_port(@master_addr)
+         log(<<~MSG)
+
+           **************************************
+           Rendezvous info:
+           --rdzv-backend=#{@rdzv_backend}
+           --rdzv-endpoint=#{@master_addr}:#{@master_port}
+           --rdzv-id=#{@rdzv_id}
+           **************************************
+
+         MSG
+       end
+
+       def configure_static_rendezvous
+         @standalone = false
+         endpoint_host, endpoint_port = parse_endpoint(@options[:rdzv_endpoint])
+         @master_addr = endpoint_host || @options[:master_addr]
+         @master_port = endpoint_port || @options[:master_port]
+         @rdzv_id = @options[:rdzv_id]
+         raise Error, "MASTER_ADDR must be provided" if @master_addr.to_s.empty?
+         raise Error, "MASTER_PORT must be > 0" unless @master_port.to_i.positive?
+       end
+
+       def parse_endpoint(value)
+         return [nil, nil] if value.nil? || value.strip.empty?
+
+         host, port_str = value.split(":", 2)
+         port = port_str ? Integer(port_str, exception: false) : nil
+         raise Error, "Invalid rendezvous endpoint: #{value.inspect}" if host.to_s.empty? || (port_str && port.nil?)
+
+         [host, port]
+       end
+
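+       # Binds port 0 so the OS assigns a free port; note it is released on
+       # close, so another process could in principle grab it before the
+       # workers bind.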
+       def find_free_port(host)
+         server = TCPServer.new(host, 0)
+         server.addr[1]
+       ensure
+         server&.close
+       end
+
+       def log(message)
+         @out.puts(message)
+       end
+     end
+   end
+ end
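
For a sense of how the launcher is driven, here is a minimal sketch. It relies only on Torch::TorchRun.start as defined above; the script name train.rb is a placeholder, and note that start terminates the calling process via exit.

    # launch.rb -- a minimal sketch; train.rb is a placeholder script.
    require "torch/torchrun"

    # Equivalent to the CLI invocation:
    #   torchrun --standalone --nproc-per-node 2 train.rb
    # --standalone picks a free local port and a random rendezvous id; the two
    # workers then start with RANK=0/1, LOCAL_RANK=0/1, and WORLD_SIZE=2 set.
    # (Flags meant for the script itself would need a "--" separator, since
    # OptionParser would otherwise try to consume them.)
    Torch::TorchRun.start(%w[--standalone --nproc-per-node 2 train.rb])
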
data/lib/torch-ddp.rb ADDED
@@ -0,0 +1,8 @@
+ require "torch"
+ require "torch/ddp/monkey_patch"
+ require "torch/ddp/version"
+ require "torch/distributed"
+ require "torch/nn/parallel/distributed_data_parallel"
+ require "torch/torchrun"
+
+ Torch::DDP::MonkeyPatch.apply_if_needed
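
On the worker side, a training script needs nothing beyond the environment variables the launcher exports. A hypothetical sketch (the process-group initialization is left out because Torch::Distributed's API is not shown in this diff):

    # train.rb -- hypothetical worker script driven by the launcher above.
    rank       = Integer(ENV.fetch("RANK"))
    world_size = Integer(ENV.fetch("WORLD_SIZE"))
    local_rank = Integer(ENV.fetch("LOCAL_RANK"))

    # MASTER_ADDR/MASTER_PORT are also preset, ready for an env-style rendezvous.
    puts "worker #{rank}/#{world_size} (local rank #{local_rank}) " \
         "rendezvous at #{ENV["MASTER_ADDR"]}:#{ENV["MASTER_PORT"]}"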