torch-ddp 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +46 -0
- data/README.md +114 -0
- data/bin/torchrun +6 -0
- data/examples/benchmark/training.rb +374 -0
- data/examples/mnist/distributed.rb +240 -0
- data/ext/torch_ddp/distributed.cpp +348 -0
- data/ext/torch_ddp/ext.cpp +11 -0
- data/ext/torch_ddp/extconf.rb +155 -0
- data/lib/torch/ddp/monkey_patch.rb +325 -0
- data/lib/torch/ddp/version.rb +5 -0
- data/lib/torch/distributed.rb +466 -0
- data/lib/torch/nn/parallel/distributed_data_parallel.rb +115 -0
- data/lib/torch/torchrun.rb +531 -0
- data/lib/torch-ddp.rb +8 -0
- data/test/distributed_test.rb +243 -0
- data/test/support/net.rb +42 -0
- data/test/support/scripts/show_ranks.rb +7 -0
- data/test/support/tensor.pth +0 -0
- data/test/test_helper.rb +71 -0
- data/test/torchrun_test.rb +33 -0
- metadata +92 -0
data/examples/mnist/distributed.rb
@@ -0,0 +1,240 @@
+# Distributed MNIST training with Torch::Distributed + DistributedDataParallel
+# Run with: ruby examples/mnist/distributed.rb --gpus 2
+
+require "bundler/setup"
+require "optparse"
+require "torch"
+require "torch/distributed"
+require "torch/nn/parallel/distributed_data_parallel"
+require "torchvision"
+require "tmpdir"
+
+unless Torch::Distributed.available?
+  abort "torch.distributed was not built in this binary"
+end
+
+DEFAULT_CHECKPOINT_PATH = File.join(Dir.tmpdir, "mnist_ddp_checkpoint.pt")
+DEFAULT_BACKEND = if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available?
+  "nccl"
+else
+  Torch::Distributed.get_default_backend_for_device(Torch::Accelerator.current_accelerator) || "gloo"
+end
+
+class MyNet < Torch::NN::Module
+  def initialize
+    super()
+    @conv1 = Torch::NN::Conv2d.new(1, 32, 3, stride: 1)
+    @conv2 = Torch::NN::Conv2d.new(32, 64, 3, stride: 1)
+    @dropout1 = Torch::NN::Dropout2d.new(p: 0.25)
+    @dropout2 = Torch::NN::Dropout2d.new(p: 0.5)
+    @fc1 = Torch::NN::Linear.new(9216, 128)
+    @fc2 = Torch::NN::Linear.new(128, 10)
+  end
+
+  def forward(x)
+    x = Torch::NN::F.relu(@conv1.call(x))
+    x = Torch::NN::F.relu(@conv2.call(x))
+    x = Torch::NN::F.max_pool2d(x, 2)
+    x = @dropout1.call(x)
+    x = Torch.flatten(x, start_dim: 1)
+    x = Torch::NN::F.relu(@fc1.call(x))
+    x = @dropout2.call(x)
+    Torch::NN::F.log_softmax(@fc2.call(x), 1)
+  end
+end
+
+def parse_options
+  defaults = {
+    epochs: 5,
+    batch_size: 64,
+    lr: 1.0,
+    gamma: 0.7,
+    backend: DEFAULT_BACKEND,
+    gpus: Torch::CUDA.available? ? [Torch::CUDA.device_count, 1].max : 1,
+    log_interval: 20,
+    data_dir: File.join(__dir__, "data"),
+    checkpoint_path: DEFAULT_CHECKPOINT_PATH,
+    resume: false
+  }
+
+  OptionParser.new do |opts|
+    opts.banner = "Usage: ruby distributed.rb [options]"
+    opts.on("--epochs N", Integer, "Number of epochs (default: #{defaults[:epochs]})") { |v| defaults[:epochs] = v }
+    opts.on("--batch-size N", Integer, "Batch size per process (default: #{defaults[:batch_size]})") { |v| defaults[:batch_size] = v }
+    opts.on("--lr FLOAT", Float, "Learning rate (default: #{defaults[:lr]})") { |v| defaults[:lr] = v }
+    opts.on("--gamma FLOAT", Float, "LR scheduler gamma (default: #{defaults[:gamma]})") { |v| defaults[:gamma] = v }
+    opts.on("--backend NAME", String, "Process group backend (default: #{defaults[:backend]})") { |v| defaults[:backend] = v }
+    opts.on("--gpus N", Integer, "Number of GPUs/processes to use") { |v| defaults[:gpus] = v }
+    opts.on("--log-interval N", Integer, "Batches between log statements") { |v| defaults[:log_interval] = v }
+    opts.on("--data-dir PATH", String, "Directory for cached MNIST data") { |v| defaults[:data_dir] = v }
+    opts.on("--checkpoint PATH", String, "Checkpoint file to save to (default: #{defaults[:checkpoint_path]})") { |v| defaults[:checkpoint_path] = v }
+    opts.on("--resume", "Load checkpoint weights before training if the file exists") { defaults[:resume] = true }
+  end.parse!(ARGV)
+
+  defaults
+end
+
+def load_datasets(rank, data_dir)
+  transforms = TorchVision::Transforms::Compose.new([
+    TorchVision::Transforms::ToTensor.new,
+    TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
+  ])
+
+  if rank.zero?
+    train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: true, transform: transforms)
+    test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: true, transform: transforms)
+    Torch::Distributed.barrier
+  else
+    Torch::Distributed.barrier
+    train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: false, transform: transforms)
+    test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: false, transform: transforms)
+  end
+
+  [train, test]
+end
+
+def subset_for_rank(dataset, rank, world_size)
+  indices = rank.step(dataset.size - 1, world_size).to_a
+  Torch::Utils::Data::Subset.new(dataset, indices)
+end
+
+def checkpoint_map_location(device, rank)
+  accelerator_device = Torch::Accelerator.current_accelerator
+  return nil unless accelerator_device
+
+  accelerator_type = accelerator_device.type
+  target_index = device.index
+  if target_index.nil? && Torch::Accelerator.respond_to?(:device_count)
+    count = Torch::Accelerator.device_count
+    target_index = count.positive? ? rank % count : 0
+  end
+  { "#{accelerator_type}:0" => "#{accelerator_type}:#{target_index}" }
+end
+
+def load_checkpoint_if_present(ddp, device, rank, path)
+  return false unless path && File.exist?(path)
+
+  Torch::Distributed.barrier
+  kwargs = { weights_only: true }
+  map_location = checkpoint_map_location(device, rank)
+  kwargs[:map_location] = map_location if map_location
+  state_dict = Torch.load(path, **kwargs)
+  ddp.module.load_state_dict(state_dict)
+  true
+end
+
+def save_checkpoint(ddp, path, rank)
+  return unless path
+
+  Torch.save(ddp.module.state_dict, path) if rank.zero?
+  Torch::Distributed.barrier
+  puts "Saved checkpoint to #{path}" if rank.zero?
+end
+
+def train_epoch(model, device, loader, optimizer, epoch, rank, log_interval)
+  model.train
+  loader.each_with_index do |(data, target), batch_idx|
+    data = data.to(device)
+    target = target.to(device)
+
+    optimizer.zero_grad
+    loss = Torch::NN::F.nll_loss(model.call(data), target)
+    loss.backward
+    optimizer.step
+
+    next unless rank.zero? && (batch_idx % log_interval).zero?
+
+    processed = batch_idx * data.size(0)
+    total = loader.dataset.size
+    percent = 100.0 * processed / total
puts "Rank #{rank} | Epoch #{epoch} [#{processed}/#{total} (#{percent.round})%] Loss: #{'%.4f' % loss.item}"
+  end
+end
+
+def evaluate(model, device, loader)
+  model.eval
+  loss = 0.0
+  correct = 0
+  Torch.no_grad do
+    loader.each do |data, target|
+      data = data.to(device)
+      target = target.to(device)
+      output = model.call(data)
+      loss += Torch::NN::F.nll_loss(output, target, reduction: "sum").item
+      pred = output.argmax(1, keepdim: true)
+      correct += pred.eq(target.view_as(pred)).sum.item
+    end
+  end
+
+  loss /= loader.dataset.size
+  acc = 100.0 * correct / loader.dataset.size
+  puts "Test set: Average loss: #{format('%.4f', loss)}, Accuracy: #{correct}/#{loader.dataset.size} (#{format('%.1f', acc)}%)"
+end
+
+def run_worker(rank, world_size, port, options)
+  store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?)
+  accelerator = Torch::Accelerator.current_accelerator
+  backend = options[:backend] || Torch::Distributed.get_default_backend_for_device(accelerator) || DEFAULT_BACKEND
+  Torch::Distributed.init_process_group(backend, store: store, rank: rank, world_size: world_size)
+
+  device = if Torch::CUDA.available? && options[:gpus] > 0
+    Torch.device("cuda:#{rank % Torch::CUDA.device_count}")
+  else
+    Torch.device("cpu")
+  end
+
+  model = MyNet.new.to(device)
+  ddp = Torch::NN::Parallel::DistributedDataParallel.new(model, device_ids: device.type == "cuda" ? [device.index] : nil)
+  optimizer = Torch::Optim::Adadelta.new(ddp.module.parameters, lr: options[:lr])
+  scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: options[:gamma])
+
+  train_dataset, test_dataset = load_datasets(rank, options[:data_dir])
+  train_subset = subset_for_rank(train_dataset, rank, world_size)
+  train_loader = Torch::Utils::Data::DataLoader.new(train_subset, batch_size: options[:batch_size], shuffle: true)
+  test_loader = Torch::Utils::Data::DataLoader.new(test_dataset, batch_size: options[:batch_size], shuffle: false) if rank.zero?
+  checkpoint_path = options[:checkpoint_path]
+
+  if options[:resume]
+    loaded = load_checkpoint_if_present(ddp, device, rank, checkpoint_path)
+    if rank.zero?
+      if loaded
+        puts "Loaded checkpoint weights from #{checkpoint_path}"
+      else
+        puts "No checkpoint found at #{checkpoint_path}, starting from random initialization"
+      end
+    end
+  end
+
+  options[:epochs].times do |epoch_idx|
+    epoch = epoch_idx + 1
+    train_epoch(ddp, device, train_loader, optimizer, epoch, rank, options[:log_interval])
+    if rank.zero?
+      evaluate(ddp.module, device, test_loader)
+    end
+    save_checkpoint(ddp, checkpoint_path, rank) if checkpoint_path
+  end
+
+  Torch::Distributed.destroy_process_group
+end
+
+options = parse_options
+world_size = options[:gpus]
+raise "Number of GPUs requested must be >= 1" if world_size < 1
+if Torch::CUDA.available?
+  max_devices = Torch::CUDA.device_count
+  if world_size > max_devices
+    raise "Requested #{world_size} GPUs but only #{max_devices} visible"
+  end
+else
+  puts "CUDA not available, running #{world_size} CPU workers"
+end
+
+Torch.manual_seed(1)
+
+if world_size == 1
+  run_worker(0, 1, Torch::Distributed.free_port, options)
+else
+  Torch::Distributed.fork_world(world_size, start_method: :spawn) do |rank, port|
+    run_worker(rank, world_size, port, options)
+  end
+end
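A thin Ruby layer (data/lib/torch/distributed.rb, included in the gem but not reproduced in this excerpt) sits between the example above and the native extension below. For orientation, here is a minimal sketch of driving the collectives directly, assuming that layer exposes the native `_all_reduce` as `all_reduce(tensor, op)` — the wrapper name and argument order are assumptions; only the underscore-prefixed native functions appear in this diff:

    require "torch"
    require "torch/distributed"

    world_size = 2
    Torch::Distributed.fork_world(world_size, start_method: :spawn) do |rank, port|
      # Rank 0 hosts the rendezvous store; the other ranks connect to it,
      # mirroring what run_worker does in the MNIST example.
      store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?)
      Torch::Distributed.init_process_group("gloo", store: store, rank: rank, world_size: world_size)

      tensor = Torch.ones(2) * (rank + 1)
      # SUM all-reduce: each rank contributes (rank + 1), so every rank ends with 3.0.
      Torch::Distributed.all_reduce(tensor, Torch::Distributed::ReduceOp::SUM) # assumed wrapper over _all_reduce

      Torch::Distributed.barrier
      Torch::Distributed.destroy_process_group
    end
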
data/ext/torch_ddp/distributed.cpp
@@ -0,0 +1,348 @@
+#include <algorithm>
+#include <chrono>
+#include <cctype>
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+#include <torch/torch.h>
+#if defined(USE_C10D) && defined(USE_C10D_NCCL)
+#include <torch/cuda.h>
+#include <c10/cuda/CUDAFunctions.h>
+#endif
+
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
+
+static_assert(
+    TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 9,
+    "Incompatible LibTorch version");
+
+#ifdef USE_C10D
+#include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
+#include <torch/csrc/distributed/c10d/TCPStore.hpp>
+#include <torch/csrc/distributed/c10d/FileStore.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+#include <torch/csrc/distributed/c10d/Types.hpp>
+#endif
+
+#if defined(USE_C10D) && defined(USE_C10D_NCCL)
+#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+#endif
+
+#if defined(USE_C10D) && !defined(_WIN32)
+#include <torch/csrc/distributed/c10d/HashStore.hpp>
+#endif
+
+namespace {
+
+#ifdef USE_C10D
+
+using StorePtr = c10::intrusive_ptr<::c10d::Store>;
+using ProcessGroupPtr = c10::intrusive_ptr<::c10d::Backend>;
+
+struct StoreWrapper {
+  StoreWrapper() = default;
+  explicit StoreWrapper(StorePtr store) : store_(std::move(store)) {}
+
+  StorePtr store_;
+};
+
+struct ProcessGroupWrapper {
+  ProcessGroupWrapper() = default;
+  explicit ProcessGroupWrapper(ProcessGroupPtr pg) : pg_(std::move(pg)) {}
+
+  ProcessGroupPtr pg_;
+};
+
+ProcessGroupPtr default_process_group;
+std::once_flag default_pg_cleanup_once;
+
+void shutdown_default_process_group() {
+  if (default_process_group) {
+    try {
+      default_process_group->shutdown();
+    } catch (...) {
+      // best effort; ensure reset still happens
+    }
+    default_process_group.reset();
+  }
+}
+
+void register_default_pg_cleanup() {
+  std::call_once(default_pg_cleanup_once, []() {
+    std::atexit([]() { shutdown_default_process_group(); });
+  });
+}
+
+ProcessGroupPtr resolve_process_group(Rice::Object pg_obj) {
+  if (pg_obj.is_nil()) {
+    if (!default_process_group) {
+      rb_raise(rb_eRuntimeError, "Distributed process group not initialized");
+    }
+    return default_process_group;
+  }
+  auto& wrapper = Rice::detail::From_Ruby<ProcessGroupWrapper&>().convert(pg_obj.value());
+  if (!wrapper.pg_) {
+    rb_raise(rb_eRuntimeError, "Invalid process group");
+  }
+  return wrapper.pg_;
+}
+
+int reduce_op_from_int(int code) {
+  if (code < 0 || code > static_cast<int>(::c10d::ReduceOp::UNUSED)) {
+    rb_raise(rb_eArgError, "Unknown reduce op code");
+  }
+  return code;
+}
+
+#endif
+
+} // namespace
+
+void init_distributed(Rice::Module& m) {
+  auto rb_mDistributed = Rice::define_module_under(m, "Distributed");
+#ifdef USE_C10D
+  register_default_pg_cleanup();
+  rb_mDistributed.define_singleton_function("available?", []() { return true; });
+
+  auto rb_cStore = Rice::define_class_under<StoreWrapper>(rb_mDistributed, "Store");
+  rb_cStore.define_method(
+      "_native?",
+      [](StoreWrapper& self) {
+        return static_cast<bool>(self.store_);
+      });
+
+  auto rb_cProcessGroup = Rice::define_class_under<ProcessGroupWrapper>(rb_mDistributed, "ProcessGroup")
+      .define_method(
+          "rank",
+          [](ProcessGroupWrapper& self) {
+            return self.pg_ ? self.pg_->getRank() : -1;
+          })
+      .define_method(
+          "size",
+          [](ProcessGroupWrapper& self) {
+            return self.pg_ ? self.pg_->getSize() : 0;
+          })
+      .define_method(
+          "backend",
+          [](ProcessGroupWrapper& self) {
+            if (!self.pg_) {
+              return std::string();
+            }
+            return self.pg_->getBackendName();
+          });
+
+  rb_mDistributed.define_singleton_function(
+      "_create_tcp_store",
+      [rb_cStore](const std::string& host,
+                  int port,
+                  int world_size,
+                  bool is_master,
+                  int64_t timeout_millis,
+                  bool wait_for_workers) -> Rice::Object {
+        ::c10d::TCPStoreOptions opts;
+        opts.port = static_cast<uint16_t>(port);
+        opts.isServer = is_master;
+        opts.numWorkers = world_size;
+        opts.waitWorkers = wait_for_workers;
+        opts.timeout = std::chrono::milliseconds(timeout_millis);
+        auto store = c10::make_intrusive<::c10d::TCPStore>(host, opts);
+        // Pass ownership first, then the Ruby class so Rice doesn't treat the class as the owner flag
+        return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_create_file_store",
+      [rb_cStore](const std::string& path, int world_size) -> Rice::Object {
+        auto store = c10::make_intrusive<::c10d::FileStore>(path, world_size);
+        return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+      });
+
+#if !defined(_WIN32)
+  rb_mDistributed.define_singleton_function(
+      "_create_hash_store",
+      [rb_cStore]() -> Rice::Object {
+        auto store = c10::make_intrusive<::c10d::HashStore>();
+        return Rice::Data_Object<StoreWrapper>(new StoreWrapper(store), true, rb_cStore);
+      });
+#endif
+
+  rb_mDistributed.define_singleton_function(
+      "_init_process_group",
+      [rb_cProcessGroup](const std::string& backend,
+                         StoreWrapper& store_wrapper,
+                         int rank,
+                         int world_size,
+                         int64_t timeout_millis,
+                         int device_id) -> Rice::Object {
+        StorePtr store = store_wrapper.store_;
+        if (!store) {
+          rb_raise(rb_eArgError, "Store is required for init_process_group");
+        }
+
+        std::string backend_lower = backend;
+        std::transform(backend_lower.begin(), backend_lower.end(), backend_lower.begin(), ::tolower);
+
+        ProcessGroupPtr pg;
+        if (backend_lower == "gloo") {
+#ifdef USE_C10D_GLOO
+          auto options = ::c10d::ProcessGroupGloo::Options::create();
+          options->timeout = std::chrono::milliseconds(timeout_millis);
+          options->devices.push_back(::c10d::ProcessGroupGloo::createDefaultDevice());
+          pg = c10::make_intrusive<::c10d::ProcessGroupGloo>(store, rank, world_size, options);
+#else
+          rb_raise(rb_eRuntimeError, "Gloo backend is not available in this build");
+#endif
+        } else if (backend_lower == "nccl") {
+#if defined(USE_C10D_NCCL)
+          auto options = c10::make_intrusive<::c10d::ProcessGroupNCCL::Options>();
+          options->timeout = std::chrono::milliseconds(timeout_millis);
+          pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(store, rank, world_size, options);
+#else
+          rb_raise(rb_eRuntimeError, "NCCL backend is not available in this build");
+#endif
+        } else {
+          rb_raise(rb_eArgError, "Unsupported backend: %s", backend.c_str());
+        }
+
+        if (device_id >= 0 && backend_lower == "nccl") {
+#if defined(USE_C10D_NCCL)
+          if (!torch::cuda::is_available()) {
+            rb_raise(rb_eRuntimeError, "CUDA is not available for NCCL backend");
+          }
+          auto device_count = torch::cuda::device_count();
+          if (device_id >= static_cast<int>(device_count)) {
+            rb_raise(
+                rb_eArgError,
+                "Invalid device_id %d for NCCL backend (available devices: %d)",
+                device_id,
+                static_cast<int>(device_count));
+          }
+          c10::cuda::set_device(device_id);
+          pg->setBoundDeviceId(c10::Device(c10::kCUDA, device_id));
+#endif
+        }
+
+        default_process_group = pg;
+        return Rice::Data_Object<ProcessGroupWrapper>(new ProcessGroupWrapper(pg), true, rb_cProcessGroup);
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_destroy_process_group",
+      []() {
+        shutdown_default_process_group();
+        return Rice::Nil;
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_initialized?",
+      []() {
+        return static_cast<bool>(default_process_group);
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_default_process_group",
+      [rb_cProcessGroup]() -> Rice::Object {
+        if (!default_process_group) {
+          return Rice::Nil;
+        }
+        return Rice::Data_Object<ProcessGroupWrapper>(new ProcessGroupWrapper(default_process_group), true, rb_cProcessGroup);
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_get_world_size",
+      [](Rice::Object pg_obj) {
+        auto pg = resolve_process_group(pg_obj);
+        return pg->getSize();
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_get_rank",
+      [](Rice::Object pg_obj) {
+        auto pg = resolve_process_group(pg_obj);
+        return pg->getRank();
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_barrier",
+      [](Rice::Object pg_obj) {
+        auto pg = resolve_process_group(pg_obj);
+        ::c10d::BarrierOptions opts;
+        auto work = pg->barrier(opts);
+        work->wait();
+        return Rice::Nil;
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_all_reduce",
+      [](torch::Tensor& tensor, int op_code, Rice::Object pg_obj) {
+        auto pg = resolve_process_group(pg_obj);
+        ::c10d::AllreduceOptions opts;
+        opts.reduceOp = ::c10d::ReduceOp(static_cast<::c10d::ReduceOp::RedOpType>(reduce_op_from_int(op_code)));
+        std::vector<at::Tensor> tensors{tensor};
+        auto work = pg->allreduce(tensors, opts);
+        work->wait();
+        return tensor;
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_broadcast",
+      [](torch::Tensor& tensor, int src, Rice::Object pg_obj) {
+        auto pg = resolve_process_group(pg_obj);
+        ::c10d::BroadcastOptions opts;
+        opts.rootRank = src;
+        std::vector<at::Tensor> tensors{tensor};
+        auto work = pg->broadcast(tensors, opts);
+        work->wait();
+        return tensor;
+      });
+
+  rb_mDistributed.define_singleton_function(
+      "_register_ddp_hook",
+      [](torch::Tensor& tensor, ProcessGroupWrapper& pg_wrapper, int world_size) -> unsigned {
+        if (!pg_wrapper.pg_) {
+          rb_raise(rb_eArgError, "Process group is required for DDP hook registration");
+        }
+        if (world_size <= 0) {
+          rb_raise(rb_eArgError, "world_size must be positive");
+        }
+
+        auto pg = pg_wrapper.pg_;
+        // Register a native autograd hook that all-reduces gradients and scales
+        // them by the world size. This avoids calling back into Ruby from
+        // autograd worker threads.
+        unsigned handle = tensor.register_hook([pg, world_size](const at::Tensor& grad) {
+          ::c10d::AllreduceOptions opts;
+          opts.reduceOp = ::c10d::ReduceOp::SUM;
+          std::vector<at::Tensor> tensors{grad};
+          auto work = pg->allreduce(tensors, opts);
+          work->wait();
+          grad.div_(static_cast<double>(world_size));
+          return grad;
+        });
+
+        return handle;
+      });
+
+  auto rb_mReduceOp = Rice::define_module_under(rb_mDistributed, "ReduceOp");
+  rb_mReduceOp.const_set("SUM", INT2NUM(static_cast<int>(::c10d::ReduceOp::SUM)));
+  rb_mReduceOp.const_set("AVG", INT2NUM(static_cast<int>(::c10d::ReduceOp::AVG)));
+  rb_mReduceOp.const_set("PRODUCT", INT2NUM(static_cast<int>(::c10d::ReduceOp::PRODUCT)));
+  rb_mReduceOp.const_set("MIN", INT2NUM(static_cast<int>(::c10d::ReduceOp::MIN)));
+  rb_mReduceOp.const_set("MAX", INT2NUM(static_cast<int>(::c10d::ReduceOp::MAX)));
+  rb_mReduceOp.const_set("BAND", INT2NUM(static_cast<int>(::c10d::ReduceOp::BAND)));
+  rb_mReduceOp.const_set("BOR", INT2NUM(static_cast<int>(::c10d::ReduceOp::BOR)));
+  rb_mReduceOp.const_set("BXOR", INT2NUM(static_cast<int>(::c10d::ReduceOp::BXOR)));
+  rb_mReduceOp.const_set("PREMUL_SUM", INT2NUM(static_cast<int>(::c10d::ReduceOp::PREMUL_SUM)));
+
rb_mDistributed.const_set("DEFAULT_TIMEOUT", INT2NUM(::kProcessGroupDefaultTimeout.count() / 1000));
+#else
+  rb_mDistributed.define_singleton_function("available?", []() { return false; });
+  rb_mDistributed.const_set("DEFAULT_TIMEOUT", INT2NUM(30 * 60));
+#endif
+}
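The `_register_ddp_hook` binding above is the core of the gem's DDP implementation: each hooked tensor gets a native autograd hook that, during backward, all-reduces the gradient with SUM and divides by the world size, leaving every rank with the cross-rank mean grad = (1/W) * sum_r grad_r. A minimal sketch of how the Ruby side might register it per parameter — only the native signature `_register_ddp_hook(tensor, process_group, world_size)` is confirmed by this diff; the surrounding calls are assumptions:

    # Hypothetical per-parameter registration, mirroring the native signature shown above.
    pg = Torch::Distributed._default_process_group
    world_size = Torch::Distributed._get_world_size(nil) # nil resolves the default group

    model.parameters.each do |param|
      next unless param.requires_grad
      # After loss.backward, param.grad on every rank holds the averaged gradient.
      Torch::Distributed._register_ddp_hook(param, pg, world_size)
    end

Doing the reduction inside a native hook rather than a Ruby block matters because autograd may fire hooks from worker threads that must not re-enter the Ruby VM, as the comment in the binding notes.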