torch-ddp 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +46 -0
- data/README.md +114 -0
- data/bin/torchrun +6 -0
- data/examples/benchmark/training.rb +374 -0
- data/examples/mnist/distributed.rb +240 -0
- data/ext/torch_ddp/distributed.cpp +348 -0
- data/ext/torch_ddp/ext.cpp +11 -0
- data/ext/torch_ddp/extconf.rb +155 -0
- data/lib/torch/ddp/monkey_patch.rb +325 -0
- data/lib/torch/ddp/version.rb +5 -0
- data/lib/torch/distributed.rb +466 -0
- data/lib/torch/nn/parallel/distributed_data_parallel.rb +115 -0
- data/lib/torch/torchrun.rb +531 -0
- data/lib/torch-ddp.rb +8 -0
- data/test/distributed_test.rb +243 -0
- data/test/support/net.rb +42 -0
- data/test/support/scripts/show_ranks.rb +7 -0
- data/test/support/tensor.pth +0 -0
- data/test/test_helper.rb +71 -0
- data/test/torchrun_test.rb +33 -0
- metadata +92 -0
data/lib/torch/torchrun.rb
@@ -0,0 +1,531 @@
+# frozen_string_literal: true
+
+require "optparse"
+require "socket"
+require "etc"
+require "securerandom"
+require "rbconfig"
+
+require "torch"
+require "torch/distributed"
+
+module Torch
+  module TorchRun
+    SIGNALS = %w[INT TERM QUIT].freeze
+
+    class Error < StandardError; end
+
+    class Parser
+      attr_reader :parser
+
+      def initialize
+        @parser = OptionParser.new
+      end
+
+      def parse(argv)
+        options = default_options
+
+        parser.banner = "Usage: torchrun [options] TRAINING_SCRIPT [script args]"
+        parser.separator ""
+        parser.separator "Launch parameters:"
+
+        parser.on("--nnodes MIN[:MAX]", String, "Number of nodes or range (default: #{options[:nnodes]})") do |value|
+          options[:nnodes] = value
+        end
+
+        parser.on("--nproc-per-node VALUE", String, "Processes per node (int, gpu, cpu, auto). Default: #{options[:nproc_per_node]}") do |value|
+          options[:nproc_per_node] = value
+        end
+
+        parser.on("--node-rank VALUE", Integer, "Rank of the node for multi-node jobs. Default: #{options[:node_rank]}") do |value|
+          options[:node_rank] = value
+        end
+
+        parser.on("--rdzv-backend NAME", String, "Rendezvous backend (static or c10d). Default: #{options[:rdzv_backend]}") do |value|
+          options[:rdzv_backend] = value
+        end
+
+        parser.on("--rdzv-endpoint HOST[:PORT]", String, "Rendezvous endpoint. Default: use --master-addr/--master-port") do |value|
+          options[:rdzv_endpoint] = value
+        end
+
+        parser.on("--rdzv-id ID", String, "User defined job id. Default: #{options[:rdzv_id]}") do |value|
+          options[:rdzv_id] = value
+        end
+
+        parser.on("--rdzv-conf CONF", String, "Additional rendezvous config (k=v,k2=v2)") do |value|
+          options[:rdzv_conf] = parse_kv_pairs(value)
+        end
+
+        parser.on("--standalone", "Start a local rendezvous store on a free port") do
+          options[:standalone] = true
+        end
+
+        parser.on("--max-restarts VALUE", Integer, "Restarts before failing. Default: #{options[:max_restarts]}") do |value|
+          options[:max_restarts] = value
+        end
+
+        parser.on("--monitor-interval SECONDS", Float, "Delay between restart attempts. Default: #{options[:monitor_interval]}") do |value|
+          options[:monitor_interval] = value
+        end
+
+        parser.on("--role NAME", String, "Role for the worker group. Default: #{options[:role]}") do |value|
+          options[:role] = value
+        end
+
+        parser.on("--master-addr HOST", String, "Master address for static rendezvous. Default: #{options[:master_addr]}") do |value|
+          options[:master_addr] = value
+        end
+
+        parser.on("--master-port PORT", Integer, "Master port for static rendezvous. Default: #{options[:master_port]}") do |value|
+          options[:master_port] = value
+        end
+
+        parser.on("--pass-local-rank-arg", "Append --local-rank to the training script invocation") do
+          options[:pass_local_rank_arg] = true
+        end
+
+        parser.on("--no-ruby", "Execute the training script directly instead of `#{RbConfig.ruby}`") do
+          options[:no_ruby] = true
+        end
+
+        parser.on("-h", "--help", "Prints this help") do
+          puts parser
+          exit
+        end
+
+        rest = parser.parse!(argv)
+        raise OptionParser::MissingArgument, "training_script" if rest.empty?
+
+        training_script = rest.shift
+        [options, training_script, rest]
+      end
+
+      def to_s
+        parser.to_s
+      end
+
+      private
+
+      def default_options
+        {
+          nnodes: "1:1",
+          nproc_per_node: "1",
+          node_rank: 0,
+          rdzv_backend: "static",
+          rdzv_endpoint: "",
+          rdzv_id: "none",
+          rdzv_conf: {},
+          standalone: false,
+          max_restarts: 0,
+          monitor_interval: 1.0,
+          role: "default",
+          master_addr: "127.0.0.1",
+          master_port: 29_500,
+          pass_local_rank_arg: false,
+          no_ruby: false
+        }
+      end
+
+      def parse_kv_pairs(value)
+        return {} if value.nil? || value.strip.empty?
+
+        value.split(",").each_with_object({}) do |pair, acc|
+          key, val = pair.split("=", 2)
+          raise OptionParser::InvalidArgument, "Invalid rendezvous config entry: #{pair.inspect}" unless key && val
+
+          acc[key.strip] = val.strip
+        end
+      end
+    end
+
+    module_function
+
+    def start(argv, out: $stdout, err: $stderr)
+      parser = Parser.new
+      options, script, script_args = parser.parse(argv)
+      status = Launcher.new(options, script, script_args, out: out, err: err).run
+      exit(status)
+    rescue OptionParser::ParseError => e
+      err.puts(e.message)
+      err.puts(parser)
+      exit(2)
+    rescue Error => e
+      err.puts("torchrun: #{e.message}")
+      exit(1)
+    end
+
+    class Launcher
+      def initialize(options, script, script_args, out: $stdout, err: $stderr)
+        @options = options
+        @script = script
+        @script_args = script_args
+        @out = out
+        @err = err
+
+        @local_world_size = determine_local_world_size(@options[:nproc_per_node])
+        @min_nodes, @max_nodes = parse_nnodes(@options[:nnodes])
+        @num_nodes = ensure_fixed_nnodes(@min_nodes, @max_nodes)
+        @node_rank = @options[:node_rank]
+        @max_restarts = [@options[:max_restarts], 0].max
+        @monitor_interval = [@options[:monitor_interval], 0.0].max
+        @role = @options[:role]
+        @pass_local_rank_arg = @options[:pass_local_rank_arg]
+        @no_ruby = @options[:no_ruby]
+        validate_node_rank!
+
+        setup_rendezvous!
+      end
+
+      def run
+        restarts = 0
+
+        loop do
+          status = launch_worker_group(restarts)
+          return status if status.zero? || @signal_received
+          return status if restarts >= @max_restarts
+
+          restarts += 1
+          log("Worker group failed (exit #{status}). Restarting #{restarts}/#{@max_restarts} ...")
+          sleep(@monitor_interval) if @monitor_interval.positive?
+        end
+      end
+
+      private
+
+      def launch_worker_group(restart_count)
+        @signal_received = nil
+        @current_pids = spawn_workers(restart_count)
+        handler_state = setup_signal_handlers
+        status = monitor_workers(@current_pids.dup)
+        cleanup_workers(@current_pids)
+        restore_signal_handlers(handler_state)
+        return signal_exit_status if @signal_received
+
+        status
+      ensure
+        @worker_pgid = nil
+        @current_pids = []
+      end
+
+      def spawn_workers(restart_count)
+        base_env = base_environment(restart_count)
+        pgid = nil
+        workers = Array.new(@local_world_size) do |local_rank|
+          env = base_env.merge(rank_environment(local_rank))
+          pid, pgid = spawn_worker(env, local_rank, pgid)
+          pid
+        end
+        @worker_pgid = pgid
+        workers
+      end
+
+      def spawn_worker(env, local_rank, pgid)
+        args = command_arguments(local_rank)
+        spawn_opts = pgid ? { pgroup: pgid } : { pgroup: true }
+        pid = Process.spawn(env, *args, spawn_opts)
+        pgid ||= pid
+        [pid, pgid]
+      rescue SystemCallError => e
+        raise Error, "failed to launch worker #{local_rank}: #{e.message}"
+      end
+
+      def command_arguments(local_rank)
+        cmd = []
+        if @no_ruby
+          cmd << @script
+        else
+          cmd << RbConfig.ruby
+          cmd << @script
+        end
+        cmd.concat(@script_args)
+        cmd << "--local-rank=#{local_rank}" if @pass_local_rank_arg
+        cmd
+      end
+
+      def base_environment(restart_count)
+        endpoint = "#{@master_addr}:#{@master_port}"
+        env = {
+          "MASTER_ADDR" => @master_addr,
+          "MASTER_PORT" => @master_port.to_s,
+          "WORLD_SIZE" => world_size.to_s,
+          "LOCAL_WORLD_SIZE" => @local_world_size.to_s,
+          "GROUP_RANK" => @node_rank.to_s,
+          "TORCHRUN_ROLE" => @role,
+          "TORCHRUN_NNODES" => @num_nodes.to_s,
+          "TORCHRUN_NPROC_PER_NODE" => @local_world_size.to_s,
+          "TORCHELASTIC_RUN_ID" => @rdzv_id,
+          "TORCHRUN_RDZV_BACKEND" => @rdzv_backend,
+          "TORCHRUN_RDZV_ENDPOINT" => endpoint,
+          "TORCHELASTIC_RESTART_COUNT" => restart_count.to_s,
+          "TORCHRUN_STANDALONE" => @standalone ? "1" : "0"
+        }
+        unless @rdzv_conf.empty?
+          env["TORCHRUN_RDZV_CONF"] = @rdzv_conf.map { |k, v| "#{k}=#{v}" }.join(",")
+        end
+        ENV.to_h.merge(env)
+      end
+
+      def rank_environment(local_rank)
+        rank = @node_rank * @local_world_size + local_rank
+        {
+          "LOCAL_RANK" => local_rank.to_s,
+          "RANK" => rank.to_s
+        }
+      end
+
+      def monitor_workers(pids)
+        exit_code = 0
+        remaining = pids.dup
+        until remaining.empty?
+          pid, status = Process.wait2
+          next unless pid
+
+          remaining.delete(pid)
+          unless status.success?
+            exit_code = exit_status_from(status)
+            terminate_workers(remaining)
+            break
+          end
+        end
+        exit_code
+      rescue Errno::ECHILD
+        0
+      end
+
+      def terminate_workers(pids)
+        return if pids.empty?
+
+        send_process_group_signal("TERM")
+        pids.each { |pid| send_signal(pid, "TERM") }
+        sleep(0.2)
+        pids.each do |pid|
+          next unless process_alive?(pid)
+
+          send_signal(pid, "KILL")
+        end
+        pids.each do |pid|
+          begin
+            Process.wait(pid)
+          rescue Errno::ECHILD
+          end
+        end
+      end
+
+      def process_alive?(pid)
+        Process.kill(0, pid)
+        true
+      rescue Errno::ESRCH
+        false
+      end
+
+      def setup_signal_handlers
+        SIGNALS.each_with_object({}) do |sig, acc|
+          next unless Signal.list.key?(sig)
+
+          previous = Signal.trap(sig) do
+            @signal_received = sig
+            forward_signal(sig)
+          end
+          acc[sig] = previous
+        end
+      end
+
+      def forward_signal(sig)
+        send_process_group_signal(sig)
+        (@current_pids || []).each { |pid| send_signal(pid, sig) }
+      end
+
+      def restore_signal_handlers(state)
+        return unless state
+
+        state.each do |sig, previous|
+          Signal.trap(sig, previous)
+        end
+      end
+
+      def send_signal(pid, sig)
+        Process.kill(sig, pid)
+      rescue Errno::ESRCH
+        nil
+      end
+
+      def send_process_group_signal(sig)
+        return unless @worker_pgid
+
+        Process.kill(sig, -@worker_pgid)
+      rescue Errno::ESRCH
+        nil
+      end
+
+      def cleanup_workers(pids)
+        pids.each do |pid|
+          next unless process_alive?(pid)
+
+          begin
+            Process.wait(pid)
+          rescue Errno::ECHILD
+          end
+        end
+      end
+
+      def signal_exit_status
+        return 0 unless @signal_received
+
+        128 + Signal.list.fetch(@signal_received, 0)
+      end
+
+      def exit_status_from(status)
+        if status.exited?
+          status.exitstatus
+        elsif status.signaled?
+          128 + status.termsig
+        else
+          1
+        end
+      end
+
+      def determine_local_world_size(value)
+        spec = value.to_s.strip.downcase
+        case spec
+        when "", "1"
+          1
+        when /\A\d+\z/
+          amount = spec.to_i
+          raise Error, "nproc-per-node must be >= 1" if amount < 1
+
+          amount
+        when "gpu"
+          gpu_count = cuda_device_count
+          raise Error, "CUDA is not available for --nproc-per-node=gpu" if gpu_count.zero?
+
+          gpu_count
+        when "auto"
+          gpu_count = cuda_device_count
+          return gpu_count if gpu_count.positive?
+
+          cpu_count
+        when "cpu"
+          cpu_count
+        else
+          raise Error, "Unsupported --nproc-per-node value: #{value}"
+        end
+      end
+
+      def cuda_device_count
+        return 0 unless defined?(Torch::CUDA)
+        return 0 unless Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
+        return 0 unless Torch::CUDA.respond_to?(:device_count)
+
+        Torch::CUDA.device_count
+      rescue StandardError
+        0
+      end
+
+      def cpu_count
+        Etc.respond_to?(:nprocessors) ? (Etc.nprocessors || 1) : 1
+      rescue StandardError
+        1
+      end
+
+      def parse_nnodes(value)
+        parts = value.split(":")
+        nums = parts.map do |part|
+          Integer(part, exception: false)
+        end
+        raise Error, "Invalid --nnodes value: #{value.inspect}" if nums.any?(&:nil?)
+
+        if nums.length == 1
+          [nums.first, nums.first]
+        elsif nums.length == 2
+          [nums.first, nums.last]
+        else
+          raise Error, "Invalid --nnodes value: #{value.inspect}"
+        end
+      end
+
+      def ensure_fixed_nnodes(min_nodes, max_nodes)
+        raise Error, "--nnodes minimum must be >= 1" if min_nodes < 1
+        raise Error, "--nnodes maximum must be >= minimum" if max_nodes < min_nodes
+        raise Error, "Elastic nnodes ranges are not supported yet (got #{min_nodes}:#{max_nodes})" if min_nodes != max_nodes
+
+        min_nodes
+      end
+
+      def world_size
+        @world_size ||= @num_nodes * @local_world_size
+      end
+
+      def validate_node_rank!
+        raise Error, "--node-rank must be >= 0" if @node_rank.negative?
+        raise Error, "--node-rank (#{@node_rank}) must be less than --nnodes (#{@num_nodes})" if @node_rank >= @num_nodes
+      end
+
+      def setup_rendezvous!
+        @rdzv_backend = normalize_backend(@options[:rdzv_backend])
+        @rdzv_conf = @options[:rdzv_conf] || {}
+        if @options[:standalone]
+          configure_standalone_rendezvous
+        else
+          configure_static_rendezvous
+        end
+      end
+
+      def normalize_backend(value)
+        backend = value.to_s.downcase
+        raise Error, "Unsupported rendezvous backend: #{value.inspect}" unless %w[static c10d].include?(backend)
+
+        backend
+      end
+
+      def configure_standalone_rendezvous
+        @standalone = true
+        @rdzv_backend = "c10d"
+        @rdzv_id = SecureRandom.uuid
+        @master_addr = "127.0.0.1"
+        @master_port = find_free_port(@master_addr)
+        log(<<~MSG)
+
+          **************************************
+          Rendezvous info:
+          --rdzv-backend=#{@rdzv_backend}
+          --rdzv-endpoint=#{@master_addr}:#{@master_port}
+          --rdzv-id=#{@rdzv_id}
+          **************************************
+
+        MSG
+      end
+
+      def configure_static_rendezvous
+        @standalone = false
+        endpoint_host, endpoint_port = parse_endpoint(@options[:rdzv_endpoint])
+        @master_addr = endpoint_host || @options[:master_addr]
+        @master_port = endpoint_port || @options[:master_port]
+        @rdzv_id = @options[:rdzv_id]
+        raise Error, "MASTER_ADDR must be provided" if @master_addr.to_s.empty?
+        raise Error, "MASTER_PORT must be > 0" unless @master_port.to_i.positive?
+      end
+
+      def parse_endpoint(value)
+        return [nil, nil] if value.nil? || value.strip.empty?
+
+        host, port_str = value.split(":", 2)
+        port = port_str ? Integer(port_str, exception: false) : nil
+        raise Error, "Invalid rendezvous endpoint: #{value.inspect}" if host.to_s.empty? || (port_str && port.nil?)
+
+        [host, port]
+      end
+
+      def find_free_port(host)
+        server = TCPServer.new(host, 0)
+        server.addr[1]
+      ensure
+        server&.close
+      end
+
+      def log(message)
+        @out.puts(message)
+      end
+    end
+  end
+end
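
For orientation, a worker spawned by this launcher only needs the environment variables exported in base_environment and rank_environment above. The sketch below is illustrative and not part of the gem: train.rb is a hypothetical script name, and the torchrun invocation assumes the bundled data/bin/torchrun binstub dispatches ARGV to Torch::TorchRun.start.

    # train.rb -- hypothetical worker script; the launcher exports these
    # variables before Process.spawn'ing each worker process.
    rank       = Integer(ENV.fetch("RANK"))        # global rank across all nodes
    local_rank = Integer(ENV.fetch("LOCAL_RANK"))  # rank within this node
    world_size = Integer(ENV.fetch("WORLD_SIZE"))  # nnodes * nproc-per-node
    endpoint   = "#{ENV.fetch('MASTER_ADDR')}:#{ENV.fetch('MASTER_PORT')}"
    puts "worker #{rank}/#{world_size} (local #{local_rank}), rendezvous at #{endpoint}"

Run as, for example, torchrun --standalone --nproc-per-node 2 train.rb: per configure_standalone_rendezvous, --standalone switches the backend to c10d, picks a free port on 127.0.0.1, generates a random --rdzv-id, and spawns two workers with RANK=0 and RANK=1.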