torch-ddp 0.1.0

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
@@ -0,0 +1,240 @@
+ # Distributed MNIST training with Torch::Distributed + DistributedDataParallel
+ # Run with: ruby examples/mnist/distributed.rb --gpus 2
+
+ require "bundler/setup"
+ require "optparse"
+ require "torch"
+ require "torch/distributed"
+ require "torch/nn/parallel/distributed_data_parallel"
+ require "torchvision"
+ require "tmpdir"
+
+ unless Torch::Distributed.available?
+   abort "torch.distributed was not built in this binary"
+ end
+
+ DEFAULT_CHECKPOINT_PATH = File.join(Dir.tmpdir, "mnist_ddp_checkpoint.pt")
+ DEFAULT_BACKEND = if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
+   "nccl"
+ else
+   Torch::Distributed.get_default_backend_for_device(Torch::Accelerator.current_accelerator) || "gloo"
+ end
+
+ class MyNet < Torch::NN::Module
+   def initialize
+     super()
+     @conv1 = Torch::NN::Conv2d.new(1, 32, 3, stride: 1)
+     @conv2 = Torch::NN::Conv2d.new(32, 64, 3, stride: 1)
+     @dropout1 = Torch::NN::Dropout2d.new(p: 0.25)
+     @dropout2 = Torch::NN::Dropout2d.new(p: 0.5)
+     @fc1 = Torch::NN::Linear.new(9216, 128)
+     @fc2 = Torch::NN::Linear.new(128, 10)
+   end
+
+   def forward(x)
+     x = Torch::NN::F.relu(@conv1.call(x))
+     x = Torch::NN::F.relu(@conv2.call(x))
+     x = Torch::NN::F.max_pool2d(x, 2)
+     x = @dropout1.call(x)
+     x = Torch.flatten(x, start_dim: 1)
+     x = Torch::NN::F.relu(@fc1.call(x))
+     x = @dropout2.call(x)
+     Torch::NN::F.log_softmax(@fc2.call(x), 1)
+   end
+ end
+
+ def parse_options
+   defaults = {
+     epochs: 5,
+     batch_size: 64,
+     lr: 1.0,
+     gamma: 0.7,
+     backend: DEFAULT_BACKEND,
+     gpus: Torch::CUDA.available? ? [Torch::CUDA.device_count, 1].max : 1,
+     log_interval: 20,
+     data_dir: File.join(__dir__, "data"),
+     checkpoint_path: DEFAULT_CHECKPOINT_PATH,
+     resume: false
+   }
+
+   OptionParser.new do |opts|
+     opts.banner = "Usage: ruby distributed.rb [options]"
+     opts.on("--epochs N", Integer, "Number of epochs (default: #{defaults[:epochs]})") { |v| defaults[:epochs] = v }
+     opts.on("--batch-size N", Integer, "Batch size per process (default: #{defaults[:batch_size]})") { |v| defaults[:batch_size] = v }
+     opts.on("--lr FLOAT", Float, "Learning rate (default: #{defaults[:lr]})") { |v| defaults[:lr] = v }
+     opts.on("--gamma FLOAT", Float, "LR scheduler gamma (default: #{defaults[:gamma]})") { |v| defaults[:gamma] = v }
+     opts.on("--backend NAME", String, "Process group backend (default: #{defaults[:backend]})") { |v| defaults[:backend] = v }
+     opts.on("--gpus N", Integer, "Number of GPUs/processes to use") { |v| defaults[:gpus] = v }
+     opts.on("--log-interval N", Integer, "Batches between log statements") { |v| defaults[:log_interval] = v }
+     opts.on("--data-dir PATH", String, "Directory for cached MNIST data") { |v| defaults[:data_dir] = v }
+     opts.on("--checkpoint PATH", String, "Checkpoint file to save to (default: #{defaults[:checkpoint_path]})") { |v| defaults[:checkpoint_path] = v }
+     opts.on("--resume", "Load checkpoint weights before training if the file exists") { defaults[:resume] = true }
+   end.parse!(ARGV)
+
+   defaults
+ end
+
+ def load_datasets(rank, data_dir)
+   transforms = TorchVision::Transforms::Compose.new([
+     TorchVision::Transforms::ToTensor.new,
+     TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
+   ])
+
+   if rank.zero?
+     train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: true, transform: transforms)
+     test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: true, transform: transforms)
+     Torch::Distributed.barrier
+   else
+     Torch::Distributed.barrier
+     train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: false, transform: transforms)
+     test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: false, transform: transforms)
+   end
+
+   [train, test]
+ end
+
+ def subset_for_rank(dataset, rank, world_size)
+   indices = rank.step(dataset.size - 1, world_size).to_a
+   Torch::Utils::Data::Subset.new(dataset, indices)
+ end
+
+ def checkpoint_map_location(device, rank)
+   accelerator_device = Torch::Accelerator.current_accelerator
+   return nil unless accelerator_device
+
+   accelerator_type = accelerator_device.type
+   target_index = device.index
+   if target_index.nil? && Torch::Accelerator.respond_to?(:device_count)
+     count = Torch::Accelerator.device_count
+     target_index = count.positive? ? rank % count : 0
+   end
+   { "#{accelerator_type}:0" => "#{accelerator_type}:#{target_index}" }
+ end
+
+ def load_checkpoint_if_present(ddp, device, rank, path)
+   return false unless path && File.exist?(path)
+
+   Torch::Distributed.barrier
+   kwargs = { weights_only: true }
+   map_location = checkpoint_map_location(device, rank)
+   kwargs[:map_location] = map_location if map_location
+   state_dict = Torch.load(path, **kwargs)
+   ddp.module.load_state_dict(state_dict)
+   true
+ end
+
+ def save_checkpoint(ddp, path, rank)
+   return unless path
+
+   Torch.save(ddp.module.state_dict, path) if rank.zero?
+   Torch::Distributed.barrier
+   puts "Saved checkpoint to #{path}" if rank.zero?
+ end
+
+ def train_epoch(model, device, loader, optimizer, epoch, rank, log_interval)
+   model.train
+   loader.each_with_index do |(data, target), batch_idx|
+     data = data.to(device)
+     target = target.to(device)
+
+     optimizer.zero_grad
+     loss = Torch::NN::F.nll_loss(model.call(data), target)
+     loss.backward
+     optimizer.step
+
+     next unless rank.zero? && (batch_idx % log_interval).zero?
+
+     processed = batch_idx * data.size(0)
+     total = loader.dataset.size
+     percent = 100.0 * processed / total
+     puts "Rank #{rank} | Epoch #{epoch} [#{processed}/#{total} (#{percent.round}%)] Loss: #{'%.4f' % loss.item}"
+   end
+ end
+
+ def evaluate(model, device, loader)
+   model.eval
+   loss = 0.0
+   correct = 0
+   Torch.no_grad do
+     loader.each do |data, target|
+       data = data.to(device)
+       target = target.to(device)
+       output = model.call(data)
+       loss += Torch::NN::F.nll_loss(output, target, reduction: "sum").item
+       pred = output.argmax(1, keepdim: true)
+       correct += pred.eq(target.view_as(pred)).sum.item
+     end
+   end
+
+   loss /= loader.dataset.size
+   acc = 100.0 * correct / loader.dataset.size
+   puts "Test set: Average loss: #{format('%.4f', loss)}, Accuracy: #{correct}/#{loader.dataset.size} (#{format('%.1f', acc)}%)"
+ end
+
+ def run_worker(rank, world_size, port, options)
+   store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?)
+   accelerator = Torch::Accelerator.current_accelerator
+   backend = options[:backend] || Torch::Distributed.get_default_backend_for_device(accelerator) || DEFAULT_BACKEND
+   Torch::Distributed.init_process_group(backend, store: store, rank: rank, world_size: world_size)
+
+   device = if Torch::CUDA.available? && options[:gpus] > 0
+     Torch.device("cuda:#{rank % Torch::CUDA.device_count}")
+   else
+     Torch.device("cpu")
+   end
+
+   model = MyNet.new.to(device)
+   ddp = Torch::NN::Parallel::DistributedDataParallel.new(model, device_ids: device.type == "cuda" ? [device.index] : nil)
+   optimizer = Torch::Optim::Adadelta.new(ddp.module.parameters, lr: options[:lr])
+   scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: options[:gamma])
+
+   train_dataset, test_dataset = load_datasets(rank, options[:data_dir])
+   train_subset = subset_for_rank(train_dataset, rank, world_size)
+   train_loader = Torch::Utils::Data::DataLoader.new(train_subset, batch_size: options[:batch_size], shuffle: true)
+   test_loader = Torch::Utils::Data::DataLoader.new(test_dataset, batch_size: options[:batch_size], shuffle: false) if rank.zero?
+   checkpoint_path = options[:checkpoint_path]
+
+   if options[:resume]
+     loaded = load_checkpoint_if_present(ddp, device, rank, checkpoint_path)
+     if rank.zero?
+       if loaded
+         puts "Loaded checkpoint weights from #{checkpoint_path}"
+       else
+         puts "No checkpoint found at #{checkpoint_path}, starting from random initialization"
+       end
+     end
+   end
+
+   options[:epochs].times do |epoch_idx|
+     epoch = epoch_idx + 1
+     train_epoch(ddp, device, train_loader, optimizer, epoch, rank, options[:log_interval])
+     if rank.zero?
+       evaluate(ddp.module, device, test_loader)
+     end
+     save_checkpoint(ddp, checkpoint_path, rank) if checkpoint_path
+     scheduler.step
+   end
+
+   Torch::Distributed.destroy_process_group
+ end
+
+ options = parse_options
+ world_size = options[:gpus]
+ raise "Number of GPUs requested must be >= 1" if world_size < 1
+ if Torch::CUDA.available?
+   max_devices = Torch::CUDA.device_count
+   if world_size > max_devices
+     raise "Requested #{world_size} GPUs but only #{max_devices} visible"
+   end
+ else
+   puts "CUDA not available, running #{world_size} CPU workers"
+ end
+
+ Torch.manual_seed(1)
+
+ if world_size == 1
+   run_worker(0, 1, Torch::Distributed.free_port, options)
+ else
+   Torch::Distributed.fork_world(world_size, start_method: :spawn) do |rank, port|
+     run_worker(rank, world_size, port, options)
+   end
+ end
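
The example above drives all synchronization through DistributedDataParallel. For orientation, a minimal sketch of the lower-level Torch::Distributed primitives it builds on (TCPStore rendezvous, process-group initialization, barrier, tear-down) could look like the following; the public all_reduce wrapper and its keyword signature are assumptions here, since only the native _all_reduce binding appears later in this diff:

  # Minimal single-process sketch (assumes a free local port; `all_reduce` is the
  # presumed public wrapper over the native `_all_reduce` binding shown below).
  require "torch"
  require "torch/distributed"

  store = Torch::Distributed::TCPStore.new("127.0.0.1", 29500, 1, true)
  Torch::Distributed.init_process_group("gloo", store: store, rank: 0, world_size: 1)

  t = Torch.ones([2, 2])
  Torch::Distributed.barrier
  # Sums the tensor across all ranks in place (a no-op with world_size == 1).
  Torch::Distributed.all_reduce(t, op: Torch::Distributed::ReduceOp::SUM)

  Torch::Distributed.destroy_process_group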
@@ -0,0 +1,348 @@
+ #include <algorithm>
+ #include <chrono>
+ #include <cctype>
+ #include <cstdlib>
+ #include <memory>
+ #include <mutex>
+ #include <string>
+ #include <vector>
+
+ #include <torch/torch.h>
+ #if defined(USE_C10D) && defined(USE_C10D_NCCL)
+ #include <torch/cuda.h>
+ #include <c10/cuda/CUDAFunctions.h>
+ #endif
+
+ #include <rice/rice.hpp>
+ #include <rice/stl.hpp>
+
+ static_assert(
+     TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 9,
+     "Incompatible LibTorch version");
+
+ #ifdef USE_C10D
+ #include <torch/csrc/distributed/c10d/Backend.hpp>
+ #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+ #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
+ #include <torch/csrc/distributed/c10d/TCPStore.hpp>
+ #include <torch/csrc/distributed/c10d/FileStore.hpp>
+ #include <torch/csrc/distributed/c10d/Work.hpp>
+ #include <torch/csrc/distributed/c10d/Types.hpp>
+ #endif
+
+ #if defined(USE_C10D) && defined(USE_C10D_NCCL)
+ #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+ #endif
+
+ #if defined(USE_C10D) && !defined(_WIN32)
+ #include <torch/csrc/distributed/c10d/HashStore.hpp>
+ #endif
+
+ namespace {
+
+ #ifdef USE_C10D
+
+ using StorePtr = c10::intrusive_ptr<::c10d::Store>;
+ using ProcessGroupPtr = c10::intrusive_ptr<::c10d::Backend>;
+
+ struct StoreWrapper {
+   StoreWrapper() = default;
+   explicit StoreWrapper(StorePtr store) : store_(std::move(store)) {}
+
+   StorePtr store_;
+ };
+
+ struct ProcessGroupWrapper {
+   ProcessGroupWrapper() = default;
+   explicit ProcessGroupWrapper(ProcessGroupPtr pg) : pg_(std::move(pg)) {}
+
+   ProcessGroupPtr pg_;
+ };
+
+ ProcessGroupPtr default_process_group;
+ std::once_flag default_pg_cleanup_once;
+
+ void shutdown_default_process_group() {
+   if (default_process_group) {
+     try {
+       default_process_group->shutdown();
+     } catch (...) {
+       // best effort; ensure reset still happens
+     }
+     default_process_group.reset();
+   }
+ }
+
+ void register_default_pg_cleanup() {
+   std::call_once(default_pg_cleanup_once, []() {
+     std::atexit([]() { shutdown_default_process_group(); });
+   });
+ }
+
+ ProcessGroupPtr resolve_process_group(Rice::Object pg_obj) {
+   if (pg_obj.is_nil()) {
+     if (!default_process_group) {
+       rb_raise(rb_eRuntimeError, "Distributed process group not initialized");
+     }
+     return default_process_group;
+   }
+   auto& wrapper = Rice::detail::From_Ruby<ProcessGroupWrapper&>().convert(pg_obj.value());
+   if (!wrapper.pg_) {
+     rb_raise(rb_eRuntimeError, "Invalid process group");
+   }
+   return wrapper.pg_;
+ }
+
+ int reduce_op_from_int(int code) {
+   if (code < 0 || code > static_cast<int>(::c10d::ReduceOp::UNUSED)) {
+     rb_raise(rb_eArgError, "Unknown reduce op code");
+   }
+   return code;
+ }
+
+ #endif
+
+ } // namespace
+
+ void init_distributed(Rice::Module& m) {
+   auto rb_mDistributed = Rice::define_module_under(m, "Distributed");
+ #ifdef USE_C10D
+   register_default_pg_cleanup();
+   rb_mDistributed.define_singleton_function("available?", []() { return true; });
+
+   auto rb_cStore = Rice::define_class_under<StoreWrapper>(rb_mDistributed, "Store");
+   rb_cStore.define_method(
+       "_native?",
+       [](StoreWrapper& self) {
+         return static_cast<bool>(self.store_);
+       });
+
+   auto rb_cProcessGroup = Rice::define_class_under<ProcessGroupWrapper>(rb_mDistributed, "ProcessGroup")
+     .define_method(
+         "rank",
+         [](ProcessGroupWrapper& self) {
+           return self.pg_ ? self.pg_->getRank() : -1;
+         })
+     .define_method(
+         "size",
+         [](ProcessGroupWrapper& self) {
+           return self.pg_ ? self.pg_->getSize() : 0;
+         })
+     .define_method(
+         "backend",
+         [](ProcessGroupWrapper& self) {
+           if (!self.pg_) {
+             return std::string();
+           }
+           return self.pg_->getBackendName();
+         });
+
+   rb_mDistributed.define_singleton_function(
+       "_create_tcp_store",
+       [rb_cStore](const std::string& host,
+                   int port,
+                   int world_size,
+                   bool is_master,
+                   int64_t timeout_millis,
+                   bool wait_for_workers) -> Rice::Object {
+         ::c10d::TCPStoreOptions opts;
+         opts.port = static_cast<uint16_t>(port);
+         opts.isServer = is_master;
+         opts.numWorkers = world_size;
+         opts.waitWorkers = wait_for_workers;
+         opts.timeout = std::chrono::milliseconds(timeout_millis);
+         auto store = c10::make_intrusive<::c10d::TCPStore>(host, opts);
+         // Pass ownership first, then the Ruby class so Rice doesn't treat the class as the owner flag
+         return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_create_file_store",
+       [rb_cStore](const std::string& path, int world_size) -> Rice::Object {
+         auto store = c10::make_intrusive<::c10d::FileStore>(path, world_size);
+         return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+       });
+
+ #if !defined(_WIN32)
+   rb_mDistributed.define_singleton_function(
+       "_create_hash_store",
+       [rb_cStore]() -> Rice::Object {
+         auto store = c10::make_intrusive<::c10d::HashStore>();
+         return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+       });
+ #endif
+
+   rb_mDistributed.define_singleton_function(
+       "_init_process_group",
+       [rb_cProcessGroup](const std::string& backend,
+                          StoreWrapper& store_wrapper,
+                          int rank,
+                          int world_size,
+                          int64_t timeout_millis,
+                          int device_id) -> Rice::Object {
+         StorePtr store = store_wrapper.store_;
+         if (!store) {
+           rb_raise(rb_eArgError, "Store is required for init_process_group");
+         }
+
+         std::string backend_lower = backend;
+         std::transform(backend_lower.begin(), backend_lower.end(), backend_lower.begin(), ::tolower);
+
+         ProcessGroupPtr pg;
+         if (backend_lower == "gloo") {
+ #ifdef USE_C10D_GLOO
+           auto options = ::c10d::ProcessGroupGloo::Options::create();
+           options->timeout = std::chrono::milliseconds(timeout_millis);
+           options->devices.push_back(::c10d::ProcessGroupGloo::createDefaultDevice());
+           pg = c10::make_intrusive<::c10d::ProcessGroupGloo>(store, rank, world_size, options);
+ #else
+           rb_raise(rb_eRuntimeError, "Gloo backend is not available in this build");
+ #endif
+         } else if (backend_lower == "nccl") {
+ #if defined(USE_C10D_NCCL)
+           auto options = c10::make_intrusive<::c10d::ProcessGroupNCCL::Options>();
+           options->timeout = std::chrono::milliseconds(timeout_millis);
+           pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(store, rank, world_size, options);
+ #else
+           rb_raise(rb_eRuntimeError, "NCCL backend is not available in this build");
+ #endif
+         } else {
+           rb_raise(rb_eArgError, "Unsupported backend: %s", backend.c_str());
+         }
+
+         if (device_id >= 0 && backend_lower == "nccl") {
+ #if defined(USE_C10D_NCCL)
+           if (!torch::cuda::is_available()) {
+             rb_raise(rb_eRuntimeError, "CUDA is not available for NCCL backend");
+           }
+           auto device_count = torch::cuda::device_count();
+           if (device_id >= static_cast<int>(device_count)) {
+             rb_raise(
+                 rb_eArgError,
+                 "Invalid device_id %d for NCCL backend (available devices: %d)",
+                 device_id,
+                 static_cast<int>(device_count));
+           }
+           c10::cuda::set_device(device_id);
+           pg->setBoundDeviceId(c10::Device(c10::kCUDA, device_id));
+ #endif
+         }
+
+         default_process_group = pg;
+         return Rice::Data_Object<ProcessGroupWrapper>(new ProcessGroupWrapper(pg), true, rb_cProcessGroup);
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_destroy_process_group",
+       []() {
+         shutdown_default_process_group();
+         return Rice::Nil;
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_initialized?",
+       []() {
+         return static_cast<bool>(default_process_group);
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_default_process_group",
+       [rb_cProcessGroup]() -> Rice::Object {
+         if (!default_process_group) {
+           return Rice::Nil;
+         }
+         return Rice::Data_Object<ProcessGroupWrapper>(new ProcessGroupWrapper(default_process_group), true, rb_cProcessGroup);
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_get_world_size",
+       [](Rice::Object pg_obj) {
+         auto pg = resolve_process_group(pg_obj);
+         return pg->getSize();
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_get_rank",
+       [](Rice::Object pg_obj) {
+         auto pg = resolve_process_group(pg_obj);
+         return pg->getRank();
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_barrier",
+       [](Rice::Object pg_obj) {
+         auto pg = resolve_process_group(pg_obj);
+         ::c10d::BarrierOptions opts;
+         auto work = pg->barrier(opts);
+         work->wait();
+         return Rice::Nil;
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_all_reduce",
+       [](torch::Tensor& tensor, int op_code, Rice::Object pg_obj) {
+         auto pg = resolve_process_group(pg_obj);
+         ::c10d::AllreduceOptions opts;
+         opts.reduceOp = ::c10d::ReduceOp(static_cast<::c10d::ReduceOp::RedOpType>(reduce_op_from_int(op_code)));
+         std::vector<at::Tensor> tensors{tensor};
+         auto work = pg->allreduce(tensors, opts);
+         work->wait();
+         return tensor;
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_broadcast",
+       [](torch::Tensor& tensor, int src, Rice::Object pg_obj) {
+         auto pg = resolve_process_group(pg_obj);
+         ::c10d::BroadcastOptions opts;
+         opts.rootRank = src;
+         std::vector<at::Tensor> tensors{tensor};
+         auto work = pg->broadcast(tensors, opts);
+         work->wait();
+         return tensor;
+       });
+
+   rb_mDistributed.define_singleton_function(
+       "_register_ddp_hook",
+       [](torch::Tensor& tensor, ProcessGroupWrapper& pg_wrapper, int world_size) -> unsigned {
+         if (!pg_wrapper.pg_) {
+           rb_raise(rb_eArgError, "Process group is required for DDP hook registration");
+         }
+         if (world_size <= 0) {
+           rb_raise(rb_eArgError, "world_size must be positive");
+         }
+
+         auto pg = pg_wrapper.pg_;
+         // Register a native autograd hook that all-reduces gradients and scales
+         // them by the world size. This avoids calling back into Ruby from
+         // autograd worker threads.
+         unsigned handle = tensor.register_hook([pg, world_size](const at::Tensor& grad) {
+           ::c10d::AllreduceOptions opts;
+           opts.reduceOp = ::c10d::ReduceOp::SUM;
+           std::vector<at::Tensor> tensors{grad};
+           auto work = pg->allreduce(tensors, opts);
+           work->wait();
+           grad.div_(static_cast<double>(world_size));
+           return grad;
+         });
+
+         return handle;
+       });
+
+   auto rb_mReduceOp = Rice::define_module_under(rb_mDistributed, "ReduceOp");
+   rb_mReduceOp.const_set("SUM", INT2NUM(static_cast<int>(::c10d::ReduceOp::SUM)));
+   rb_mReduceOp.const_set("AVG", INT2NUM(static_cast<int>(::c10d::ReduceOp::AVG)));
+   rb_mReduceOp.const_set("PRODUCT", INT2NUM(static_cast<int>(::c10d::ReduceOp::PRODUCT)));
+   rb_mReduceOp.const_set("MIN", INT2NUM(static_cast<int>(::c10d::ReduceOp::MIN)));
+   rb_mReduceOp.const_set("MAX", INT2NUM(static_cast<int>(::c10d::ReduceOp::MAX)));
+   rb_mReduceOp.const_set("BAND", INT2NUM(static_cast<int>(::c10d::ReduceOp::BAND)));
+   rb_mReduceOp.const_set("BOR", INT2NUM(static_cast<int>(::c10d::ReduceOp::BOR)));
+   rb_mReduceOp.const_set("BXOR", INT2NUM(static_cast<int>(::c10d::ReduceOp::BXOR)));
+   rb_mReduceOp.const_set("PREMUL_SUM", INT2NUM(static_cast<int>(::c10d::ReduceOp::PREMUL_SUM)));
+
+   rb_mDistributed.const_set("DEFAULT_TIMEOUT", INT2NUM(::kProcessGroupDefaultTimeout.count() / 1000));
+ #else
+   rb_mDistributed.define_singleton_function("available?", []() { return false; });
+   rb_mDistributed.const_set("DEFAULT_TIMEOUT", INT2NUM(30 * 60));
+ #endif
+ }
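
The _register_ddp_hook binding above is what keeps replicas in sync: the native autograd hook sums each parameter's gradient across ranks and divides by the world size, so every rank steps with the mean gradient. A rough Ruby sketch of that same averaging, written directly against the private entry points shown above (illustration only; the gem's DistributedDataParallel wrapper normally installs the hook, and the in-place div! call is assumed from torch.rb conventions):

  # Manual gradient averaging equivalent to the native DDP hook
  # (assumes an initialized default process group).
  world_size = Torch::Distributed._get_world_size(nil)
  model.parameters.each do |param|
    next if param.grad.nil?
    # Sum this gradient across every rank, then scale to the mean.
    Torch::Distributed._all_reduce(param.grad, Torch::Distributed::ReduceOp::SUM, nil)
    param.grad.div!(world_size)
  end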
@@ -0,0 +1,11 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ void init_distributed(Rice::Module& m);
+
+ extern "C"
+ void Init_ddp_ext() {
+   auto m = Rice::define_module("Torch");
+   init_distributed(m);
+ }