torch-ddp 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,243 @@
1
+ require_relative "test_helper"
2
+ require "torch/distributed"
3
+ require "socket"
4
+ require "timeout"
5
+
6
+ class DistributedInitProcessGroupTest < Minitest::Test
7
+ def setup
8
+ skip "Distributed backend not available" unless Torch::Distributed.available?
9
+ skip "CUDA not available for NCCL backend" unless cuda_available?
10
+ end
11
+
12
+ def test_defaults_nccl_device_id_from_local_rank_env
13
+ calls = []
14
+ with_stubbed_init_process_group(calls) do
15
+ ENV["LOCAL_RANK"] = "2"
16
+ Torch::Distributed.init_process_group("nccl", store: Object.new, rank: 5, world_size: 8)
17
+ ensure
18
+ ENV.delete("LOCAL_RANK")
19
+ end
20
+
21
+ assert_equal 1, calls.size
22
+ assert_equal 2, calls.first[:device_id]
23
+ end
24
+
25
+ def test_falls_back_to_local_world_size_modulo
26
+ calls = []
27
+ with_stubbed_init_process_group(calls) do
28
+ ENV["LOCAL_WORLD_SIZE"] = "2"
29
+ Torch::Distributed.init_process_group("nccl", store: Object.new, rank: 3, world_size: 4)
30
+ ensure
31
+ ENV.delete("LOCAL_WORLD_SIZE")
32
+ end
33
+
34
+ assert_equal 1, calls.size
35
+ assert_equal 1, calls.first[:device_id]
36
+ end
37
+
38
+ def test_uses_world_size_when_env_missing
39
+ calls = []
40
+ with_stubbed_init_process_group(calls) do
41
+ Torch::Distributed.init_process_group("nccl", store: Object.new, rank: 1, world_size: 2)
42
+ end
43
+
44
+ assert_equal 1, calls.size
45
+ assert_equal 1, calls.first[:device_id]
46
+ end
47
+
48
+ private
49
+
50
+ # Stub out low-level init to capture arguments without starting a real process group
51
+ # Used for upper-level tests that don't require actual process group spawning
52
+ def with_stubbed_init_process_group(calls)
53
+ original = Torch::Distributed.method(:_init_process_group)
54
+ Torch::Distributed.singleton_class.define_method(:_init_process_group) do |backend, store, rank, world_size, timeout_ms, device_id|
55
+ calls << {backend: backend, rank: rank, world_size: world_size, timeout_ms: timeout_ms, device_id: device_id}
56
+ :stub
57
+ end
58
+ yield
59
+ ensure
60
+ Torch::Distributed.singleton_class.define_method(:_init_process_group, original)
61
+ end
62
+
63
+ def cuda_available?
64
+ Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
65
+ end
66
+ end
67
+
68
+ class DistributedSpawnStartMethodTest < Minitest::Test
69
+ def test_spawn_worker_env_runs_block
70
+ reader, writer = IO.pipe
71
+ writer.close_on_exec = false
72
+
73
+ pid = fork do
74
+ reader.close
75
+ ENV[Torch::Distributed::SPAWN_ENV_KEY] = "1"
76
+ ENV[Torch::Distributed::SPAWN_RANK_ENV_KEY] = "0"
77
+ ENV[Torch::Distributed::SPAWN_WORLD_SIZE_ENV_KEY] = "1"
78
+ ENV[Torch::Distributed::SPAWN_PORT_ENV_KEY] = "1234"
79
+ ENV[Torch::Distributed::SPAWN_PIPE_ENV_KEY] = writer.fileno.to_s
80
+ Torch::Distributed.fork_world(1, start_method: :spawn) { |rank, port| [rank, port] }
81
+ end
82
+
83
+ writer.close
84
+ result = Marshal.load(reader)
85
+ reader.close
86
+
87
+ _pid, status = Process.wait2(pid)
88
+ assert status.success?
89
+ assert_equal [0, 1234], result
90
+ end
91
+ end
92
+
93
+ class DistributedBackendTest < Minitest::Test
94
+ BACKEND = nil
95
+
96
+ def setup
97
+ super
98
+ skip "Distributed backend not available" unless Torch::Distributed.available?
99
+ skip "No backend configured for test" unless backend
100
+ skip_unless_backend_available!
101
+ end
102
+
103
+ def backend
104
+ self.class::BACKEND
105
+ end
106
+
107
+ def tensor_options
108
+ {}
109
+ end
110
+
111
+ def skip_unless_backend_available!
112
+ skip "#{backend} backend not available" unless backend_available?
113
+ end
114
+
115
+ def backend_available?
116
+ timeout = distributed_timeout
117
+ port = Torch::Distributed.free_port
118
+ store = Torch::Distributed::TCPStore.new("127.0.0.1", port, 1, true, wait_for_workers: false, timeout: timeout)
119
+ Torch::Distributed.init_process_group(backend, store: store, rank: 0, world_size: 1, timeout: timeout)
120
+ true
121
+ rescue StandardError => e
122
+ return false if e.message =~ /not available/i || e.message =~ /unsupported backend/i
123
+ raise
124
+ ensure
125
+ Torch::Distributed.destroy_process_group if Torch::Distributed.initialized?
126
+ end
127
+
128
+ def fork_with_backend(world_size: 2, start_method: :spawn)
129
+ timeout = distributed_timeout
130
+ original_filter = ENV[Torch::Distributed::SPAWN_TEST_ENV_KEY]
131
+ original_script = ENV[Torch::Distributed::SPAWN_SCRIPT_ENV_KEY]
132
+ ENV[Torch::Distributed::SPAWN_TEST_ENV_KEY] = name if start_method == :spawn
133
+ ENV[Torch::Distributed::SPAWN_SCRIPT_ENV_KEY] = File.expand_path(__FILE__) if start_method == :spawn
134
+ Timeout.timeout(timeout, Timeout::Error, "distributed test exceeded #{timeout}s") do
135
+ Torch::Distributed.fork_world(world_size, start_method: start_method) do |rank, port|
136
+ Timeout.timeout(timeout, Timeout::Error, "distributed worker #{rank} exceeded #{timeout}s") do
137
+ store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?, timeout: timeout)
138
+ Torch::Distributed.init_process_group(
139
+ backend,
140
+ store: store,
141
+ rank: rank,
142
+ world_size: world_size,
143
+ device_id: rank,
144
+ timeout: timeout
145
+ )
146
+ begin
147
+ yield(rank)
148
+ ensure
149
+ Torch::Distributed.destroy_process_group
150
+ end
151
+ end
152
+ end
153
+ end
154
+ ensure
155
+ ENV[Torch::Distributed::SPAWN_TEST_ENV_KEY] = original_filter
156
+ ENV[Torch::Distributed::SPAWN_SCRIPT_ENV_KEY] = original_script
157
+ end
158
+
159
+ def test_all_reduce
160
+ results = fork_with_backend do |rank|
161
+ tensor = Torch.tensor([rank + 1.0], **tensor_options)
162
+ Torch::Distributed.all_reduce(tensor)
163
+ tensor.to_a
164
+ end
165
+
166
+ assert_equal [[3.0], [3.0]], results
167
+ end
168
+
169
+ def test_barrier
170
+ wait_times = fork_with_backend do |rank|
171
+ sleep 0.3 if rank.zero?
172
+ before = Process.clock_gettime(Process::CLOCK_MONOTONIC)
173
+ Torch::Distributed.barrier
174
+ after = Process.clock_gettime(Process::CLOCK_MONOTONIC)
175
+ after - before
176
+ end
177
+
178
+ assert_operator wait_times.first, :<, 0.1
179
+ assert_operator wait_times.last, :>=, 0.25
180
+ end
181
+
182
+ def test_broadcast
183
+ tensors = fork_with_backend do |rank|
184
+ tensor = Torch.tensor([rank + 1.0], **tensor_options)
185
+ Torch::Distributed.broadcast(tensor, src: 0)
186
+ tensor.to_a
187
+ end
188
+
189
+ assert_equal [[1.0], [1.0]], tensors
190
+ end
191
+
192
+ def test_ddp_gradient_sync
193
+ # autograd cannot run safely with fork-based multiprocessing; always use spawn here
194
+ grads = fork_with_backend(start_method: :spawn) do |rank|
195
+ device = tensor_options[:device]
196
+ model = Torch::NN::Linear.new(1, 1, bias: false)
197
+ model = model.to(device) if device
198
+ ddp = Torch::NN::Parallel::DistributedDataParallel.new(model)
199
+ input = Torch.tensor([[rank + 1.0]], **tensor_options)
200
+ output = ddp.call(input)
201
+ loss = output.sum
202
+ loss.backward
203
+
204
+ grad = model.parameters.first.grad
205
+ grad = grad.to("cpu") if device
206
+ grad.item
207
+ end
208
+
209
+ grads.each do |grad|
210
+ assert_in_delta 1.5, grad, 1e-6
211
+ end
212
+ end
213
+
214
+ def distributed_timeout
215
+ Integer(ENV.fetch("TORCH_DISTRIBUTED_TEST_TIMEOUT", "30"))
216
+ end
217
+ end
218
+
219
+ class DistributedGlooTest < DistributedBackendTest
220
+ BACKEND = "gloo"
221
+
222
+ def fork_with_backend(world_size: 2, start_method: :fork)
223
+ super(world_size: world_size, start_method: start_method)
224
+ end
225
+ end
226
+
227
+ class DistributedNcclTest < DistributedBackendTest
228
+ BACKEND = "nccl"
229
+
230
+ def setup
231
+ skip "CUDA not available for NCCL backend" unless Torch.const_defined?(:CUDA) && Torch::CUDA.available?
232
+ skip "Need at least 2 CUDA devices for NCCL tests" unless Torch::CUDA.device_count >= 2
233
+ super
234
+ end
235
+
236
+ def tensor_options
237
+ {device: "cuda"}
238
+ end
239
+
240
+ def fork_with_backend(world_size: 2, start_method: :spawn)
241
+ super(world_size: world_size, start_method: start_method)
242
+ end
243
+ end
@@ -0,0 +1,42 @@
1
+ class TestNet < Torch::NN::Module
2
+ def initialize
3
+ super()
4
+ @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
5
+ @conv2 = Torch::NN::Conv2d.new(6, 16, 3)
6
+ @fc1 = Torch::NN::Linear.new(16 * 6 * 6, 120)
7
+ @fc2 = Torch::NN::Linear.new(120, 84)
8
+ @fc3 = Torch::NN::Linear.new(84, 10)
9
+ end
10
+
11
+ def forward(x)
12
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv1.call(x)), [2, 2])
13
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv2.call(x)), 2)
14
+ x = Torch.flatten(x, 1)
15
+ x = Torch::NN::F.relu(@fc1.call(x))
16
+ x = Torch::NN::F.relu(@fc2.call(x))
17
+ @fc3.call(x)
18
+ end
19
+ end
20
+
21
+ class SimpleResidualBlock < Torch::NN::Module
22
+ def initialize
23
+ super()
24
+
25
+ @relu = Torch::NN::ReLU.new
26
+
27
+ @seq = Torch::NN::Sequential.new(
28
+ Torch::NN::Conv2d.new(64, 128, 3, padding: 1, bias: false),
29
+ Torch::NN::BatchNorm2d.new(128),
30
+ @relu,
31
+ Torch::NN::Conv2d.new(128, 128, 3, padding: 1, bias: false),
32
+ Torch::NN::BatchNorm2d.new(128),
33
+ @relu,
34
+ Torch::NN::Conv2d.new(128, 64, 3, bias: false),
35
+ Torch::NN::BatchNorm2d.new(64)
36
+ )
37
+ end
38
+
39
+ def forward(x)
40
+ @relu.forward(@seq.forward(x) + x)
41
+ end
42
+ end
@@ -0,0 +1,7 @@
1
+ # frozen_string_literal: true
2
+
3
+ $stdout.sync = true
4
+ rank = ENV.fetch("RANK", "unknown")
5
+ local_rank = ENV.fetch("LOCAL_RANK", "unknown")
6
+ world_size = ENV.fetch("WORLD_SIZE", "unknown")
7
+ puts "RANK=#{rank} LOCAL_RANK=#{local_rank} WORLD_SIZE=#{world_size}"
Binary file
@@ -0,0 +1,71 @@
1
+ spawn_worker = ENV["TORCH_DISTRIBUTED_SPAWNED"] == "1"
2
+
3
+ # Spawned distributed workers shouldn't try to load minitest plugins from the
4
+ # parent test environment.
5
+ ENV["MT_NO_PLUGINS"] = "1" if spawn_worker
6
+
7
+ require "bundler/setup"
8
+ Bundler.require(:default)
9
+ require "torch-ddp"
10
+ require "minitest/autorun"
11
+
12
+ if spawn_worker
13
+ module TorchDistributedSpawnTest
14
+ module QuietSummaryReporter
15
+ def start # :nodoc:
16
+ Minitest::StatisticsReporter.instance_method(:start).bind(self).call
17
+ self.sync = io.respond_to?(:"sync=")
18
+ self.old_sync, io.sync = io.sync, true if self.sync
19
+ end
20
+
21
+ def report # :nodoc:
22
+ super
23
+ ensure
24
+ io.sync = self.old_sync if self.sync
25
+ end
26
+ end
27
+ end
28
+
29
+ Minitest::SummaryReporter.prepend(TorchDistributedSpawnTest::QuietSummaryReporter)
30
+ end
31
+
32
+ # support
33
+ require_relative "support/net"
34
+
35
+ class Minitest::Test
36
+ def assert_elements_in_delta(expected, actual)
37
+ assert_equal expected.size, actual.size
38
+ expected.zip(actual) do |exp, act|
39
+ if exp.finite?
40
+ assert_in_delta exp, act
41
+ else
42
+ assert_equal exp, act
43
+ end
44
+ end
45
+ end
46
+
47
+ def assert_tensor(expected, actual, dtype: nil)
48
+ assert_kind_of Torch::Tensor, actual
49
+ assert_equal actual.dtype, dtype if dtype
50
+ if (actual.floating_point? || actual.complex?) && actual.dim < 2
51
+ assert_elements_in_delta expected, actual.to_a
52
+ else
53
+ assert_equal expected, actual.to_a
54
+ end
55
+ end
56
+
57
+ def mac?
58
+ RbConfig::CONFIG["host_os"] =~ /darwin/i
59
+ end
60
+
61
+ def stress_gc
62
+ previous = GC.stress
63
+ begin
64
+ GC.stress = true
65
+ yield
66
+ ensure
67
+ GC.stress = previous
68
+ GC.start
69
+ end
70
+ end
71
+ end
@@ -0,0 +1,33 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test_helper"
4
+
5
+ require "open3"
6
+ require "rbconfig"
7
+
8
+ class TorchRunTest < Minitest::Test
9
+ def test_standalone_launches_multiple_workers
10
+ script = File.expand_path("support/scripts/show_ranks.rb", __dir__)
11
+ torchrun = File.expand_path("../bin/torchrun", __dir__)
12
+ stdout, stderr, status = Open3.capture3(
13
+ {"TORCHRUN_TEST" => "1"},
14
+ RbConfig.ruby,
15
+ torchrun,
16
+ "--standalone",
17
+ "--nproc-per-node=2",
18
+ script
19
+ )
20
+
21
+ assert status.success?, "torchrun failed: #{stderr}"
22
+
23
+ lines = stdout.lines.map(&:strip).select { |line| line.start_with?("RANK=") }
24
+ assert_equal 2, lines.size, "expected two worker outputs, got: #{lines.inspect}"
25
+ ranks = lines.map do |line|
26
+ match = line.match(/RANK=(\d+)\s+LOCAL_RANK=(\d+)\s+WORLD_SIZE=(\d+)/)
27
+ raise "unexpected output: #{line}" unless match
28
+
29
+ [match[1].to_i, match[2].to_i, match[3].to_i]
30
+ end
31
+ assert_equal [[0, 0, 2], [1, 1, 2]], ranks.sort
32
+ end
33
+ end
metadata ADDED
@@ -0,0 +1,92 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: torch-ddp
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Ivan Razuvaev
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2025-12-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: torch-rb
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.22.2
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.22.2
27
+ - !ruby/object:Gem::Dependency
28
+ name: rice
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '4.7'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '4.7'
41
+ description:
42
+ email: i@orlando-labs.com
43
+ executables:
44
+ - torchrun
45
+ extensions:
46
+ - ext/torch_ddp/extconf.rb
47
+ extra_rdoc_files: []
48
+ files:
49
+ - LICENSE.txt
50
+ - README.md
51
+ - bin/torchrun
52
+ - examples/benchmark/training.rb
53
+ - examples/mnist/distributed.rb
54
+ - ext/torch_ddp/distributed.cpp
55
+ - ext/torch_ddp/ext.cpp
56
+ - ext/torch_ddp/extconf.rb
57
+ - lib/torch-ddp.rb
58
+ - lib/torch/ddp/monkey_patch.rb
59
+ - lib/torch/ddp/version.rb
60
+ - lib/torch/distributed.rb
61
+ - lib/torch/nn/parallel/distributed_data_parallel.rb
62
+ - lib/torch/torchrun.rb
63
+ - test/distributed_test.rb
64
+ - test/support/net.rb
65
+ - test/support/scripts/show_ranks.rb
66
+ - test/support/tensor.pth
67
+ - test/test_helper.rb
68
+ - test/torchrun_test.rb
69
+ homepage: https://github.com/ankane/torch.rb
70
+ licenses:
71
+ - BSD-3-Clause
72
+ metadata: {}
73
+ post_install_message:
74
+ rdoc_options: []
75
+ require_paths:
76
+ - lib
77
+ required_ruby_version: !ruby/object:Gem::Requirement
78
+ requirements:
79
+ - - ">="
80
+ - !ruby/object:Gem::Version
81
+ version: '3.2'
82
+ required_rubygems_version: !ruby/object:Gem::Requirement
83
+ requirements:
84
+ - - ">="
85
+ - !ruby/object:Gem::Version
86
+ version: '0'
87
+ requirements: []
88
+ rubygems_version: 3.5.22
89
+ signing_key:
90
+ specification_version: 4
91
+ summary: Distributed data parallel support for torch-rb
92
+ test_files: []