torch-dl 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +21 -0
- data/README.md +115 -0
- data/lib/torch_dl/cuda.rb +76 -0
- data/lib/torch_dl/ext.rb +20 -0
- data/lib/torch_dl/nn/parallel/data_parallel.rb +273 -0
- data/lib/torch_dl/nn/parallel/parallel_apply.rb +60 -0
- data/lib/torch_dl/nn/parallel/replicate.rb +178 -0
- data/lib/torch_dl/version.rb +3 -0
- data/lib/torch_dl.rb +24 -0
- metadata +84 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 395b034959ba88dc532f1b9903dfcbf5404902117d5b833c112ee7e3851b20e5
|
|
4
|
+
data.tar.gz: f7dce1d7003eafc711e172f0877d7929b01055d879a4057314ba22b26cbbb53e
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: 2bb333d0eb7bc3d2e9785d7a66e82d1340cc25547d760296cb0ae47f20ba5517f4fa2c03709c81868612d1313ac6ed4e4e4808041d20de6a66a3e4e00c9f8454
|
|
7
|
+
data.tar.gz: 1c78b7ada3a7f669db54ea24fe385c0e2811418ce1546836db6fa31549979e14fee9c1466b5dd34b79eae0541705a4db7cf4de59d66ccc19a4a221ad17cf26f6
|
data/CHANGELOG.md
ADDED
data/LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Chris Hasinski
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# torch-dl
|
|
2
|
+
|
|
3
|
+
Multi-GPU training for [torch.rb](https://github.com/ankane/torch.rb). Split batches across multiple GPUs automatically.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
Add to your Gemfile:
|
|
8
|
+
|
|
9
|
+
```ruby
|
|
10
|
+
gem "torch-dl"
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Usage
|
|
14
|
+
|
|
15
|
+
### Basic Usage
|
|
16
|
+
|
|
17
|
+
```ruby
|
|
18
|
+
require "torch_dl"
|
|
19
|
+
|
|
20
|
+
# Create your model on the first GPU
|
|
21
|
+
model = MyModel.new.to("cuda:0")
|
|
22
|
+
|
|
23
|
+
# Wrap with DataParallel
|
|
24
|
+
dp_model = DataParallel.new(model, device_ids: [0, 1])
|
|
25
|
+
|
|
26
|
+
# Training loop
|
|
27
|
+
optimizer.zero_grad
|
|
28
|
+
output = dp_model.call(input)
|
|
29
|
+
loss = criterion.call(output, target)
|
|
30
|
+
loss.backward
|
|
31
|
+
optimizer.step
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### Models That Return Loss
|
|
35
|
+
|
|
36
|
+
If your model returns a scalar loss (e.g., GPT models returning `[logits, loss]`), use `dp_model.backward` instead of `loss.backward`:
|
|
37
|
+
|
|
38
|
+
```ruby
|
|
39
|
+
optimizer.zero_grad
|
|
40
|
+
logits, loss = dp_model.call(input, targets: targets)
|
|
41
|
+
dp_model.backward(scale: 1.0)
|
|
42
|
+
optimizer.step
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
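This is necessary because, when the batch is split across several GPUs, the returned `loss` is only a detached average of the per-replica losses (useful for logging), so calling `loss.backward` on it does not reach the model. The `backward` method calls backward on each replica's loss separately, then reduces gradients to the original module.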
|
|
46
|
+
|
|
47
|
+
### Gradient Accumulation
|
|
48
|
+
|
|
49
|
+
For gradient accumulation, scale the backward pass:
|
|
50
|
+
|
|
51
|
+
```ruby
|
|
52
|
+
gradient_accumulation_steps = 4
|
|
53
|
+
|
|
54
|
+
(0...gradient_accumulation_steps).each do |step|
|
|
55
|
+
logits, loss = dp_model.call(input_batch, targets: targets_batch)
|
|
56
|
+
dp_model.backward(scale: 1.0 / gradient_accumulation_steps)
|
|
57
|
+
end
|
|
58
|
+
optimizer.step
|
|
59
|
+
optimizer.zero_grad
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
## API Reference
|
|
63
|
+
|
|
64
|
+
### DataParallel
|
|
65
|
+
|
|
66
|
+
```ruby
|
|
67
|
+
DataParallel.new(model, device_ids: nil, output_device: nil, dim: 0)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
- `model` - The module to parallelize (must already be on the first device in `device_ids`, e.g. `cuda:0`)
|
|
71
|
+
- `device_ids` - Array of GPU indices to use (default: all available)
|
|
72
|
+
- `output_device` - GPU index for output (default: first device)
|
|
73
|
+
- `dim` - Dimension to scatter inputs along (default: 0, batch dimension)
|
|
74
|
+
|
|
75
|
+
#### Methods
|
|
76
|
+
|
|
77
|
+
- `call(*inputs, **kwargs)` - Forward pass with automatic scattering/gathering
|
|
78
|
+
- `backward(scale: 1.0)` - Backward pass for loss-returning models
|
|
79
|
+
- `module` / `wrapped_module` - Access the underlying model
|
|
80
|
+
- `parameters` - Access model parameters (for optimizer)
|
|
81
|
+
- `state_dict` / `load_state_dict` - Save/load model state
|
|
82
|
+
- `train` / `eval` - Set training/evaluation mode
|
|
83
|
+
|
|
84
|
+
### Low-level Functions
|
|
85
|
+
|
|
86
|
+
```ruby
|
|
87
|
+
# Split tensor across devices
|
|
88
|
+
TorchDL.scatter(tensor, ["cuda:0", "cuda:1"], dim)
|
|
89
|
+
|
|
90
|
+
# Gather tensors to a single device
|
|
91
|
+
TorchDL.gather(tensors, "cuda:0", dim)
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
## How It Works
|
|
95
|
+
|
|
96
|
+
1. **Scatter**: Input batch is split across GPUs
|
|
97
|
+
2. **Replicate**: Model is copied to each GPU (replicas are cached, and their weights are re-synced from the original model on every forward pass)
|
|
98
|
+
3. **Parallel Apply**: Forward pass runs on each GPU in parallel using threads
|
|
99
|
+
4. **Gather**: Outputs are collected back to the output device
|
|
100
|
+
|
|
101
|
+
## Requirements
|
|
102
|
+
|
|
103
|
+
- Ruby 3.1+
|
|
104
|
+
- torch.rb 0.17+
|
|
105
|
+
- Multiple CUDA GPUs
|
|
106
|
+
|
|
107
|
+
## Notes
|
|
108
|
+
|
|
109
|
+
- Works with stock torch.rb from RubyGems
|
|
110
|
+
- Uses FFI (fiddle) to call CUDA runtime directly for `synchronize`, `current_device`, and `set_device`
|
|
111
|
+
- No C extension required - pure Ruby gem
|
|
112
|
+
|
|
113
|
+
## License
|
|
114
|
+
|
|
115
|
+
MIT
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
require "fiddle"
|
|
2
|
+
|
|
3
|
+
module TorchDL
  # Minimal CUDA runtime bindings called through Fiddle (no C extension).
  #
  # Every public method first checks Torch::CUDA.available? and degrades to a
  # harmless return value (nil or -1) when CUDA or the libcudart shared
  # library cannot be used.
  module CUDA
    # Shared-library names/paths tried, in order, when loading the CUDA runtime.
    LIBCUDART_PATHS = [
      "libcudart.so",
      "libcudart.so.12",
      "libcudart.so.11",
      "/usr/local/cuda/lib64/libcudart.so",
      "/usr/lib/x86_64-linux-gnu/libcudart.so"
    ].freeze
    private_constant :LIBCUDART_PATHS

    class << self
      # Blocks until all queued work on the current CUDA device has finished
      # (cudaDeviceSynchronize).
      #
      # @return [Integer, nil] cudaError_t code (0 = success), or nil when CUDA
      #   or libcudart is unavailable
      def synchronize
        lib = cudart
        return unless lib

        @cuda_device_synchronize ||= Fiddle::Function.new(
          lib["cudaDeviceSynchronize"],
          [],
          Fiddle::TYPE_INT
        )
        @cuda_device_synchronize.call
      end

      # Returns the index of the calling thread's current CUDA device
      # (cudaGetDevice).
      #
      # @return [Integer] device index, or -1 when CUDA/libcudart is unavailable
      #   or the runtime call reports an error
      def current_device
        lib = cudart
        return -1 unless lib

        @cuda_get_device ||= Fiddle::Function.new(
          lib["cudaGetDevice"],
          [Fiddle::TYPE_VOIDP],
          Fiddle::TYPE_INT
        )

        # Without a free function, Fiddle::Pointer.malloc memory is never
        # released; RUBY_FREE lets the GC reclaim the buffer.
        device_ptr = Fiddle::Pointer.malloc(Fiddle::SIZEOF_INT, Fiddle::RUBY_FREE)
        status = @cuda_get_device.call(device_ptr)
        # Non-zero cudaError_t: the output buffer may not have been written, so
        # report the same "no device" value instead of reading it.
        return -1 unless status.zero?

        device_ptr[0, Fiddle::SIZEOF_INT].unpack1("i")
      end

      # Selects the CUDA device used by the calling thread (cudaSetDevice).
      #
      # @param device_id [Integer] device index
      # @return [Integer, nil] cudaError_t code (0 = success), or nil when CUDA
      #   or libcudart is unavailable
      def set_device(device_id)
        lib = cudart
        return unless lib

        @cuda_set_device ||= Fiddle::Function.new(
          lib["cudaSetDevice"],
          [Fiddle::TYPE_INT],
          Fiddle::TYPE_INT
        )
        @cuda_set_device.call(device_id)
      end

      private

      # Returns the loaded libcudart handle, or nil when CUDA is unavailable or
      # the library cannot be found. The load result, including a failed load
      # (nil), is cached so a missing library is not re-probed on every call.
      def cudart
        return nil unless Torch::CUDA.available?
        return @cudart if defined?(@cudart)

        @cudart = load_cudart
      end

      # Tries each candidate path and returns the first handle that opens.
      #
      # @return [Fiddle::Handle, nil]
      def load_cudart
        LIBCUDART_PATHS.each do |path|
          return Fiddle.dlopen(path)
        rescue Fiddle::DLError
          next
        end

        nil
      end
    end
  end
end
|
data/lib/torch_dl/ext.rb
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
module TorchDL
  # Pure-Ruby scatter/gather helpers built on tensor #chunk, #to and Torch.cat.
  module Ext
    class << self
      # Chunks +input+ into devices.size pieces along +dim+ and moves the
      # i-th piece to devices[i]. If #chunk yields fewer pieces than there
      # are devices, only that many results are returned.
      #
      # @param input [Torch::Tensor] tensor to split
      # @param devices [Array<String>] target devices, e.g. ["cuda:0", "cuda:1"]
      # @param dim [Integer] dimension to split along
      # @return [Array<Torch::Tensor>] one chunk per used device
      def scatter(input, devices, dim = 0)
        pieces = input.chunk(devices.size, dim)
        pieces.zip(devices).map { |piece, device| piece.to(device) }
      end

      # Moves every tensor to +target_device+ and concatenates them along +dim+.
      #
      # @param inputs [Array<Torch::Tensor>] tensors, possibly on different devices
      # @param target_device [String] device for the result
      # @param dim [Integer] concatenation dimension
      # @return [Torch::Tensor]
      def gather(inputs, target_device, dim = 0)
        moved = inputs.map { |tensor| tensor.to(target_device) }
        Torch.cat(moved, dim)
      end
    end
  end
end
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
module TorchDL
  module NN
    # Implements data parallelism at the module level.
    #
    # This container parallelizes the application of the given module by
    # splitting the input across the specified devices by chunking in the
    # batch dimension. In the forward pass, the module is replicated on each
    # device (the wrapped module itself serves the first device), replica
    # parameters are re-synced from the wrapped module, and each replica
    # handles a portion of the input. Replica gradients are summed into the
    # wrapped module by {#reduce_gradients}, which {#backward} calls for you.
    #
    # @note Backward Pass for Models Returning Loss
    # When your model returns a scalar loss (e.g., GPT models that return
    # [logits, loss]), you must use the {#backward} method instead of
    # calling loss.backward directly. With more than one replica the returned
    # loss is a detached mean of the per-replica losses, meant for logging only.
    #
    # @note Backward Pass for Models Returning Outputs Only
    # When the loss is computed from the gathered output, call loss.backward
    # and then {#reduce_gradients}. Replicas on other devices have their own
    # parameters, so without it the optimizer only sees the gradients of the
    # first device's chunk.
    #
    # @example Training loop with loss-returning model
    #   dp_model = TorchDL::NN::DataParallel.new(model, device_ids: [0, 1])
    #   optimizer.zero_grad
    #   logits, loss = dp_model.call(input, targets: targets)
    #   dp_model.backward(scale: 1.0 / gradient_accumulation_steps)
    #   optimizer.step
    #
    # @example Basic usage (model returns output only)
    #   model = MyModel.new.to("cuda:0")
    #   dp_model = TorchDL::NN::DataParallel.new(model, device_ids: [0, 1])
    #   output = dp_model.call(input)
    #   loss = criterion.call(output, target)
    #   loss.backward
    #   dp_model.reduce_gradients # fold replica gradients into the wrapped module
    #
    class DataParallel < Torch::NN::Module
      attr_reader :module, :device_ids, :output_device, :dim
      alias_method :wrapped_module, :module

      # @param mod [Torch::NN::Module] Module to parallelize. It is used as-is
      #   for the first device, so it should already live on device_ids[0].
      # @param device_ids [Array<Integer>, nil] CUDA devices (default: all available)
      # @param output_device [Integer, nil] Device for output (default: device_ids[0])
      # @param dim [Integer] Dimension to scatter inputs along (default: 0)
      # @raise [ArgumentError] if device_ids is empty
      def initialize(mod, device_ids: nil, output_device: nil, dim: 0)
        super()
        @module = mod
        @device_ids = device_ids || (0...Torch::CUDA.device_count).to_a
        @output_device = output_device || @device_ids.first
        @dim = dim
        @replica_losses = nil

        raise ArgumentError, "device_ids cannot be empty" if @device_ids.empty?

        # Convert to device strings for internal use
        @device_strings = @device_ids.map { |id| "cuda:#{id}" }
        @output_device_string = "cuda:#{@output_device}"
      end

      # Runs the wrapped module with inputs split across devices.
      #
      # Tensor positional inputs and Tensor-valued kwargs are chunked along
      # #dim. Other values are passed unchanged to every replica.
      #
      # @return [Object] the gathered output. With more than one replica,
      #   0-dim tensors come back as a detached mean; use {#backward}.
      def forward(*inputs, **kwargs)
        # Drop losses from a previous call so #backward never replays a stale graph.
        @replica_losses = nil

        # No positional inputs: nothing to scatter, run the wrapped module directly.
        if inputs.empty?
          output = @module.call(**kwargs)
          capture_local_loss(output)
          return output
        end

        # Single GPU fast path. The loss is still recorded so #backward works here too.
        if @device_ids.size == 1
          output = @module.call(*inputs.map { |i| i.to(@device_strings.first) }, **kwargs)
          capture_local_loss(output)
          return output
        end

        # Scatter inputs across devices
        scattered_inputs = scatter(inputs, @device_strings, @dim)
        scattered_kwargs = scatter_kwargs(kwargs, @device_strings, @dim)

        # A batch smaller than the device count yields fewer chunks than
        # devices. Use only as many replicas as there are complete
        # (inputs, kwargs) pairs, so the lists given to ParallelApply match.
        num_replicas = [scattered_inputs.size, scattered_kwargs.size].min
        scattered_inputs = scattered_inputs.first(num_replicas)
        scattered_kwargs = scattered_kwargs.first(num_replicas)

        # Get or create replicas, sync weights
        devices = @device_strings.first(num_replicas)
        @replicas = get_replicas(devices)
        sync_replica_weights

        # Apply in parallel
        outputs = parallel_apply(@replicas, scattered_inputs, scattered_kwargs)

        # Ensure all CUDA operations complete before gathering
        TorchDL::CUDA.synchronize if Torch::CUDA.available?

        # Gather outputs back to output device
        gather(outputs, @output_device_string, @dim)
      end

      # Performs the backward pass on the losses recorded by the last #forward
      # and reduces replica gradients into the wrapped module.
      #
      # The multi-replica loss returned by #forward is detached, so calling
      # backward on each replica's own loss is what lets gradients flow.
      #
      # @param scale [Numeric] Scale factor for gradients (e.g., 1.0/gradient_accumulation_steps)
      def backward(scale: 1.0)
        if @replica_losses && !@replica_losses.empty?
          # Each replica's loss contributes equally; dividing by the count
          # averages gradients. to_f prevents integer division (scale: 1 with
          # two replicas would otherwise become 0 and zero every gradient).
          replica_scale = scale.to_f / @replica_losses.size

          @replica_losses.each do |replica_loss|
            # Scale the loss before backward to get properly scaled gradients
            (replica_loss * replica_scale).backward
          end

          @replica_losses = nil
        end

        # Reduce gradients from all replicas to the original module
        reduce_gradients
      end

      # Adds the gradients of replicas 1..N into the wrapped module's parameters
      # and then zeroes them on the replicas.
      #
      # Called automatically by #backward. Call it manually after loss.backward
      # when the loss was computed from the gathered output.
      def reduce_gradients
        return if @replicas.nil? || @replicas.size <= 1

        # The first replica is the wrapped module itself.
        original_params = @module.parameters.to_a

        @replicas[1..].each do |replica|
          original_params.zip(replica.parameters.to_a).each do |orig, repl|
            next unless repl&.grad

            # Move replica gradient to original device and add
            replica_grad = repl.grad.to(orig.device)
            if orig.grad
              orig.grad.add!(replica_grad)
            else
              orig.grad = replica_grad.clone
            end

            # Replicas are cached across steps and the optimizer only zeroes the
            # wrapped module's gradients. Clear this one so it is not re-added
            # by later reductions (double counting during gradient accumulation).
            repl.grad.zero!
          end
        end
      end

      # Sets training mode on this wrapper, the wrapped module and any cached
      # replicas, so e.g. dropout is disabled on every device after #eval.
      #
      # @param mode [Boolean]
      # @return [self]
      def train(mode = true)
        super
        @module.train(mode)
        # Replicas live in a plain Array, so set their mode explicitly.
        @replicas&.each { |replica| replica.train(mode) }
        self
      end

      # Switches to evaluation mode (see #train).
      def eval
        train(false)
      end

      # Delegate parameter access to wrapped module
      def parameters
        @module.parameters
      end

      def named_parameters(prefix: "", recurse: true)
        @module.named_parameters(prefix: prefix, recurse: recurse)
      end

      def state_dict(destination: nil, prefix: "")
        @module.state_dict(destination: destination, prefix: prefix)
      end

      def load_state_dict(state_dict, strict: true)
        @module.load_state_dict(state_dict, strict: strict)
      end

      def extra_inspect
        format("device_ids: %s, output_device: %s, dim: %d", @device_ids, @output_device, @dim)
      end

      private

      # Records the last 0-dim tensor in a single-device output (a bare tensor,
      # or an element of an Array/Hash output) so #backward also works when
      # nothing was scattered or gathered. As in gather_tensor, the last
      # scalar wins. The output itself is returned unchanged (not detached).
      def capture_local_loss(output)
        candidates =
          case output
          when Torch::Tensor then [output]
          when Array then output
          when Hash then output.values
          else []
          end
        loss = candidates.reverse_each.find { |value| value.is_a?(Torch::Tensor) && value.dim == 0 }
        @replica_losses = loss ? [loss] : nil
      end

      # Scatter a single value across devices. Tensors are chunked; anything
      # else is repeated once per device.
      def scatter_value(value, devices, dim)
        if value.is_a?(Torch::Tensor)
          TorchDL::Ext.scatter(value, devices, dim).map(&:contiguous)
        else
          devices.map { value }
        end
      end

      # Scatters each positional input, then transposes to one argument list per
      # replica. Truncates to the shortest scatter so no replica gets a nil chunk.
      def scatter(inputs, devices, dim)
        scattered = inputs.map { |input| scatter_value(input, devices, dim) }
        count = scattered.map(&:size).min
        Array.new(count) { |i| scattered.map { |s| s[i] } }
      end

      # Same as #scatter for keyword arguments: one kwargs Hash per replica.
      def scatter_kwargs(kwargs, devices, dim)
        return devices.map { {} } if kwargs.empty?

        scattered = kwargs.transform_values { |v| scatter_value(v, devices, dim) }
        count = scattered.values.map(&:size).min
        Array.new(count) { |i| scattered.transform_values { |v| v[i] } }
      end

      # Returns cached replicas when the device list is unchanged, otherwise
      # builds new ones. The cache key is only updated after replication
      # succeeds, so a failed attempt cannot pair old replicas with new devices.
      def get_replicas(devices)
        return @replicas if @replicas && @replica_devices == devices

        replicas = Replicate.replicate(@module, devices)
        @replica_devices = devices
        replicas
      end

      # Copies the wrapped module's parameter values into replicas 1..N in place.
      # NOTE(review): only parameters are synced; buffers (e.g. BatchNorm running
      # statistics) are not re-copied from the wrapped module. Confirm this is intended.
      def sync_replica_weights
        return if @replicas.nil? || @replicas.size <= 1

        # Sync parameters in-place to avoid memory allocation
        original_params = @module.parameters.to_a
        @replicas[1..].each do |replica|
          replica.parameters.to_a.each_with_index do |param, i|
            param.data.copy!(original_params[i].data)
          end
        end
      end

      def parallel_apply(replicas, inputs, kwargs_list)
        ParallelApply.parallel_apply(replicas, inputs, kwargs_list)
      end

      # Gathers per-replica outputs. Tensors are concatenated along +dim+.
      # Arrays and Hashes are gathered element-wise. Other values come from the
      # first replica.
      def gather(outputs, target_device, dim)
        # Handle different output types
        first = outputs.first

        case first
        when Torch::Tensor
          gather_tensor(outputs, target_device, dim)
        when Array
          # Tuple/array of tensors - gather each element
          first.size.times.map do |i|
            tensors = outputs.map { |o| o[i] }
            if tensors.first.is_a?(Torch::Tensor)
              gather_tensor(tensors, target_device, dim)
            elsif tensors.first.nil?
              nil
            else
              tensors.first # Non-tensor, just return first
            end
          end
        when Hash
          # Dict of tensors - gather each value
          first.keys.map do |key|
            tensors = outputs.map { |o| o[key] }
            if tensors.first.is_a?(Torch::Tensor)
              [key, gather_tensor(tensors, target_device, dim)]
            else
              [key, tensors.first]
            end
          end.to_h
        else
          # Scalar or other - return first output
          first
        end
      end

      # Concatenates tensors onto +target_device+. 0-dim tensors (losses) are
      # stored for #backward, and their detached mean is returned for display.
      def gather_tensor(tensors, target_device, dim)
        first = tensors.first
        if first.dim == 0
          # Store individual losses for backward() method
          @replica_losses = tensors

          # Note: This returned tensor should NOT be used for backward() - use backward() instead
          sum = tensors.reduce(Torch.tensor(0.0, device: target_device)) do |acc, t|
            acc + t.to(target_device).detach
          end
          sum / tensors.size
        else
          TorchDL::Ext.gather(tensors, target_device, dim)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
module TorchDL
  module NN
    # Runs module replicas concurrently, one Ruby thread per replica.
    module ParallelApply
      class << self
        # Calls each module with its matching input and keyword arguments.
        #
        # @param modules [Array<Torch::NN::Module>] module replicas
        # @param inputs [Array] one input per module; an Array input is splatted
        #   into positional arguments
        # @param kwargs_list [Array<Hash>, nil] one kwargs Hash per module
        #   (defaults to empty hashes)
        # @return [Array] outputs, in the same order as +modules+
        # @raise [ArgumentError] when the three lists differ in length
        # @raise the first (by index) StandardError raised by a replica, after
        #   every thread has finished
        def parallel_apply(modules, inputs, kwargs_list = nil)
          kwargs_list ||= Array.new(modules.size) { {} }

          lengths_match = [inputs.size, kwargs_list.size].all? { |len| len == modules.size }
          raise ArgumentError, "modules, inputs, and kwargs_list must have the same length" unless lengths_match

          # One replica gains nothing from a thread.
          return [apply_module(modules.first, inputs.first, kwargs_list.first)] if modules.size == 1

          parallel_apply_threads(modules, inputs, kwargs_list)
        end

        private

        # Starts one thread per replica. Each thread reports [status, payload],
        # so failures are captured and surfaced only after all threads are joined.
        def parallel_apply_threads(modules, inputs, kwargs_list)
          workers = modules.each_index.map do |index|
            Thread.new(index) do |i|
              [:ok, apply_module(modules[i], inputs[i], kwargs_list[i])]
            rescue => e
              [:error, e]
            end
          end

          outcomes = workers.map(&:value)
          failure = outcomes.find { |status, _| status == :error }
          raise failure.last if failure

          outcomes.map(&:last)
        end

        # Invokes a single replica; an Array input becomes positional arguments.
        def apply_module(mod, input, kwargs)
          args = input.is_a?(Array) ? input : [input]
          mod.call(*args, **kwargs)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
module TorchDL
  module NN
    # Builds per-device copies of a module for DataParallel.
    module Replicate
      class << self
        # Replicates a module on multiple devices.
        #
        # @param network [Torch::NN::Module] The module to replicate
        # @param devices [Array<String, Integer>] Device strings (e.g., ["cuda:0", "cuda:1"]);
        #   Integer entries are converted to "cuda:N"
        # @return [Array<Torch::NN::Module>] List of module replicas, one per device
        #
        # Note: The first device uses the original network (not a copy) to ensure
        # gradients flow back to the original parameters during backward pass.
        def replicate(network, devices)
          devices = devices.map { |d| d.is_a?(String) ? d : "cuda:#{d}" }

          # Single device - just return the network (assumed to already be there)
          return [network] if devices.size == 1

          # Get the state dict once for creating replicas
          state_dict = network.state_dict

          # Create replicas - first device uses original network for gradient flow
          devices.each_with_index.map do |device, idx|
            if idx == 0
              # First device: use the original network to maintain gradient connection
              network
            else
              # Other devices: create independent replicas
              replica = deep_copy_module(network)

              # Copy state dict tensors to the target device, keeping only keys
              # the replica knows about. Hash#key? keeps this linear; the former
              # Array#include? scan was O(n^2) in the number of entries.
              replica_state = replica.state_dict
              device_state = state_dict.select { |k, _| replica_state.key?(k) }
                                       .transform_values { |t| t.to(device) }
              replica.load_state_dict(device_state)
              replica.to(device)
              replica.train(network.instance_variable_get(:@training))
              replica
            end
          end
        end

        private

        # Creates a structural copy of a module (weights are loaded separately).
        # Order of precedence: a class-level _replicate hook, container modules,
        # built-in layer copiers, then a generic fallback for custom modules.
        def deep_copy_module(mod)
          # Check for custom replication hook first
          return mod.class._replicate(mod) if mod.class.respond_to?(:_replicate)

          # Handle container modules
          return copy_sequential(mod) if mod.is_a?(Torch::NN::Sequential)
          return copy_module_list(mod) if mod.is_a?(Torch::NN::ModuleList)

          # Try built-in module copiers
          copier = module_copiers[mod.class]
          return copier.call(mod) if copier

          # Fallback to generic copy for custom modules
          copy_custom_module(mod)
        end

        # Lazily initialized registry of module copiers, keyed by exact class.
        def module_copiers
          @module_copiers ||= {
            Torch::NN::Linear => ->(m) { Torch::NN::Linear.new(m.in_features, m.out_features, bias: has_bias?(m)) },
            Torch::NN::Conv1d => ->(m) { copy_conv(Torch::NN::Conv1d, m) },
            Torch::NN::Conv2d => ->(m) { copy_conv(Torch::NN::Conv2d, m) },
            Torch::NN::Conv3d => ->(m) { copy_conv(Torch::NN::Conv3d, m) },
            Torch::NN::BatchNorm1d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm1d, m) },
            Torch::NN::BatchNorm2d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm2d, m) },
            Torch::NN::BatchNorm3d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm3d, m) },
            Torch::NN::LayerNorm => ->(m) {
              Torch::NN::LayerNorm.new(m.instance_variable_get(:@normalized_shape),
                                       eps: m.instance_variable_get(:@eps),
                                       elementwise_affine: m.instance_variable_get(:@elementwise_affine))
            },
            Torch::NN::Embedding => ->(m) {
              Torch::NN::Embedding.new(m.instance_variable_get(:@num_embeddings),
                                       m.instance_variable_get(:@embedding_dim),
                                       padding_idx: m.instance_variable_get(:@padding_idx))
            },
            Torch::NN::Dropout => ->(m) { Torch::NN::Dropout.new(p: m.instance_variable_get(:@p)) },
            Torch::NN::Dropout2d => ->(m) { Torch::NN::Dropout2d.new(p: m.instance_variable_get(:@p)) },
            Torch::NN::Dropout3d => ->(m) { Torch::NN::Dropout3d.new(p: m.instance_variable_get(:@p)) },
            Torch::NN::LSTM => ->(m) { copy_rnn(Torch::NN::LSTM, m) },
            Torch::NN::GRU => ->(m) { copy_rnn(Torch::NN::GRU, m) },
            Torch::NN::RNN => ->(m) { copy_rnn(Torch::NN::RNN, m) },
            Torch::NN::ReLU => ->(_) { Torch::NN::ReLU.new },
            Torch::NN::GELU => ->(_) { Torch::NN::GELU.new },
            Torch::NN::Tanh => ->(_) { Torch::NN::Tanh.new },
            Torch::NN::Sigmoid => ->(_) { Torch::NN::Sigmoid.new },
            Torch::NN::Identity => ->(_) { Torch::NN::Identity.new },
            Torch::NN::Softmax => ->(m) { Torch::NN::Softmax.new(dim: m.instance_variable_get(:@dim)) },
            Torch::NN::LogSoftmax => ->(m) { Torch::NN::LogSoftmax.new(dim: m.instance_variable_get(:@dim)) },
          }
        end

        # True when the module's @bias ivar is set.
        def has_bias?(mod)
          !mod.instance_variable_get(:@bias).nil?
        end

        # Rebuilds a convolution layer with the same hyper-parameters.
        def copy_conv(klass, mod)
          klass.new(mod.in_channels, mod.out_channels, mod.kernel_size,
                    stride: mod.stride, padding: mod.padding, dilation: mod.dilation,
                    groups: mod.groups, bias: has_bias?(mod), padding_mode: mod.padding_mode)
        end

        # Rebuilds a batch-norm layer with the same hyper-parameters.
        def copy_batch_norm(klass, mod)
          klass.new(mod.num_features, eps: mod.eps, momentum: mod.momentum,
                    affine: mod.affine, track_running_stats: mod.track_running_stats)
        end

        # Rebuilds an RNN/LSTM/GRU layer with the same hyper-parameters.
        def copy_rnn(klass, mod)
          klass.new(mod.input_size, mod.hidden_size, num_layers: mod.num_layers,
                    bias: mod.bias, batch_first: mod.batch_first,
                    dropout: mod.dropout, bidirectional: mod.bidirectional)
        end

        def copy_sequential(mod)
          Torch::NN::Sequential.new(*mod.children.map { |child| deep_copy_module(child) })
        end

        def copy_module_list(mod)
          Torch::NN::ModuleList.new(mod.map { |child| deep_copy_module(child) })
        end

        # Generic copy for user-defined modules.
        #
        # A module with children is re-created via allocate (skipping #initialize).
        # Its non-tensor, non-module ivars are copied over and each child is
        # deep-copied. A leaf module falls back to Object#clone.
        # NOTE(review): Object#clone is shallow, so a leaf module's ivars (any
        # parameter tensors or hashes) stay shared with the original. Confirm
        # this is safe for parameterized leaf modules.
        #
        # @raise [ArgumentError] when copying fails, suggesting a _replicate hook
        def copy_custom_module(mod)
          klass = mod.class
          children = mod.named_children.to_h

          if children.any?
            # Module has submodules - create structural copy
            replica = klass.allocate
            replica.send(:initialize_module_state)
            copy_instance_state(mod, replica)
            children.each do |name, child|
              replica.instance_variable_set("@#{name}", deep_copy_module(child))
            end
            replica
          else
            # Leaf module - try clone
            mod.clone
          end
        rescue => e
          raise ArgumentError, "Cannot replicate #{klass}. " \
                               "Implement #{klass}._replicate(mod) class method. Error: #{e.message}"
        end

        # Copies plain ivars (not tensors, not submodules, not Module bookkeeping)
        # from +src+ to +dst+ by reference.
        def copy_instance_state(src, dst)
          src.instance_variables.each do |ivar|
            next if %i[@parameters @buffers @modules @training @non_persistent_buffers_set].include?(ivar)
            val = src.instance_variable_get(ivar)
            next if val.is_a?(Torch::Tensor) || val.is_a?(Torch::NN::Module)
            dst.instance_variable_set(ivar, val)
          end
        end
      end
    end
  end
end
|
|
162
|
+
|
|
163
|
+
# Reopens Torch::NN::Module to add a private initialize_module_state helper,
# unless a private method with that name already exists. Replicate calls it on
# instances built with Class#allocate (which skips #initialize), so the
# bookkeeping ivars exist before submodules are attached.
module Torch
  module NN
    class Module
      unless private_method_defined?(:initialize_module_state)
        # Resets the per-module bookkeeping state.
        # NOTE(review): presumably mirrors torch.rb's Module#initialize; keep in
        # sync if torch.rb adds more state there. Confirm.
        private def initialize_module_state
          @training = true
          @parameters = {}
          @buffers = {}
          @modules = {}
          @non_persistent_buffers_set = Set.new
        end
      end
    end
  end
end
|
data/lib/torch_dl.rb
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
require "torch"
|
|
2
|
+
require_relative "torch_dl/version"
|
|
3
|
+
require_relative "torch_dl/cuda"
|
|
4
|
+
require_relative "torch_dl/ext"
|
|
5
|
+
require_relative "torch_dl/nn/parallel/replicate"
|
|
6
|
+
require_relative "torch_dl/nn/parallel/parallel_apply"
|
|
7
|
+
require_relative "torch_dl/nn/parallel/data_parallel"
|
|
8
|
+
|
|
9
|
+
module TorchDL
  # Namespace-level error base class (nothing in the visible code raises it yet).
  class Error < StandardError; end

  # Splits +input+ along +dim+ and places one chunk on each device.
  # Thin wrapper around TorchDL::Ext.scatter.
  #
  # @param input [Torch::Tensor] tensor to split
  # @param devices [Array<String>] target devices, e.g. ["cuda:0", "cuda:1"]
  # @param dim [Integer] dimension to split along (default 0)
  # @return [Array<Torch::Tensor>]
  def self.scatter(input, devices, dim = 0)
    Ext.scatter(input, devices, dim)
  end

  # Moves tensors to +target_device+ and concatenates them along +dim+.
  # Thin wrapper around TorchDL::Ext.gather.
  #
  # @param inputs [Array<Torch::Tensor>]
  # @param target_device [String]
  # @param dim [Integer] concatenation dimension (default 0)
  # @return [Torch::Tensor]
  def self.gather(inputs, target_device, dim = 0)
    Ext.gather(inputs, target_device, dim)
  end
end
|
|
22
|
+
|
|
23
|
+
# Convenience alias: lets callers write `DataParallel.new(model, ...)` (as the
# README examples do) without the TorchDL::NN namespace. Note that this
# defines a top-level constant on Object.
DataParallel = TorchDL::NN::DataParallel
|
metadata
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
|
2
|
+
name: torch-dl
|
|
3
|
+
version: !ruby/object:Gem::Version
|
|
4
|
+
version: 0.1.0
|
|
5
|
+
platform: ruby
|
|
6
|
+
authors:
|
|
7
|
+
- Chris Hasinski
|
|
8
|
+
bindir: bin
|
|
9
|
+
cert_chain: []
|
|
10
|
+
date: 1980-01-02 00:00:00.000000000 Z
|
|
11
|
+
dependencies:
|
|
12
|
+
- !ruby/object:Gem::Dependency
|
|
13
|
+
name: torch-rb
|
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
|
15
|
+
requirements:
|
|
16
|
+
- - ">="
|
|
17
|
+
- !ruby/object:Gem::Version
|
|
18
|
+
version: '0.17'
|
|
19
|
+
type: :runtime
|
|
20
|
+
prerelease: false
|
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
22
|
+
requirements:
|
|
23
|
+
- - ">="
|
|
24
|
+
- !ruby/object:Gem::Version
|
|
25
|
+
version: '0.17'
|
|
26
|
+
- !ruby/object:Gem::Dependency
|
|
27
|
+
name: fiddle
|
|
28
|
+
requirement: !ruby/object:Gem::Requirement
|
|
29
|
+
requirements:
|
|
30
|
+
- - ">="
|
|
31
|
+
- !ruby/object:Gem::Version
|
|
32
|
+
version: '0'
|
|
33
|
+
type: :runtime
|
|
34
|
+
prerelease: false
|
|
35
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
36
|
+
requirements:
|
|
37
|
+
- - ">="
|
|
38
|
+
- !ruby/object:Gem::Version
|
|
39
|
+
version: '0'
|
|
40
|
+
description: DataParallel and distributed training utilities for torch.rb. Split batches
|
|
41
|
+
across multiple GPUs automatically.
|
|
42
|
+
email:
|
|
43
|
+
- krzysztof.hasinski@gmail.com
|
|
44
|
+
executables: []
|
|
45
|
+
extensions: []
|
|
46
|
+
extra_rdoc_files: []
|
|
47
|
+
files:
|
|
48
|
+
- CHANGELOG.md
|
|
49
|
+
- LICENSE.txt
|
|
50
|
+
- README.md
|
|
51
|
+
- lib/torch_dl.rb
|
|
52
|
+
- lib/torch_dl/cuda.rb
|
|
53
|
+
- lib/torch_dl/ext.rb
|
|
54
|
+
- lib/torch_dl/nn/parallel/data_parallel.rb
|
|
55
|
+
- lib/torch_dl/nn/parallel/parallel_apply.rb
|
|
56
|
+
- lib/torch_dl/nn/parallel/replicate.rb
|
|
57
|
+
- lib/torch_dl/version.rb
|
|
58
|
+
homepage: https://github.com/khasinski/torch-rb-dl
|
|
59
|
+
licenses:
|
|
60
|
+
- MIT
|
|
61
|
+
metadata:
|
|
62
|
+
allowed_push_host: https://rubygems.org
|
|
63
|
+
homepage_uri: https://github.com/khasinski/torch-rb-dl
|
|
64
|
+
source_code_uri: https://github.com/khasinski/torch-rb-dl
|
|
65
|
+
changelog_uri: https://github.com/khasinski/torch-rb-dl/blob/master/CHANGELOG.md
|
|
66
|
+
rubygems_mfa_required: 'true'
|
|
67
|
+
rdoc_options: []
|
|
68
|
+
require_paths:
|
|
69
|
+
- lib
|
|
70
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
|
71
|
+
requirements:
|
|
72
|
+
- - ">="
|
|
73
|
+
- !ruby/object:Gem::Version
|
|
74
|
+
version: '3.1'
|
|
75
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
76
|
+
requirements:
|
|
77
|
+
- - ">="
|
|
78
|
+
- !ruby/object:Gem::Version
|
|
79
|
+
version: '0'
|
|
80
|
+
requirements: []
|
|
81
|
+
rubygems_version: 4.0.2
|
|
82
|
+
specification_version: 4
|
|
83
|
+
summary: Multi-GPU training for torch.rb
|
|
84
|
+
test_files: []
|