torch-dl 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: 395b034959ba88dc532f1b9903dfcbf5404902117d5b833c112ee7e3851b20e5
+   data.tar.gz: f7dce1d7003eafc711e172f0877d7929b01055d879a4057314ba22b26cbbb53e
+ SHA512:
+   metadata.gz: 2bb333d0eb7bc3d2e9785d7a66e82d1340cc25547d760296cb0ae47f20ba5517f4fa2c03709c81868612d1313ac6ed4e4e4808041d20de6a66a3e4e00c9f8454
+   data.tar.gz: 1c78b7ada3a7f669db54ea24fe385c0e2811418ce1546836db6fa31549979e14fee9c1466b5dd34b79eae0541705a4db7cf4de59d66ccc19a4a221ad17cf26f6
data/CHANGELOG.md ADDED
@@ -0,0 +1,9 @@
+ # Changelog
+
+ ## 0.1.0 (2025-01-01)
+
+ - Initial release
+ - DataParallel for multi-GPU training
+ - Scatter/gather operations
+ - Thread-based parallel execution
+ - Replica caching for performance
data/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Chris Hasinski
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,115 @@
+ # torch-dl
+
+ Multi-GPU training for [torch.rb](https://github.com/ankane/torch.rb). Split batches across multiple GPUs automatically.
+
+ ## Installation
+
+ Add to your Gemfile:
+
+ ```ruby
+ gem "torch-dl"
+ ```
+
+ ## Usage
+
+ ### Basic Usage
+
+ ```ruby
+ require "torch_dl"
+
+ # Create your model on the first GPU
+ model = MyModel.new.to("cuda:0")
+
+ # Wrap with DataParallel
+ dp_model = DataParallel.new(model, device_ids: [0, 1])
+
+ # Training loop
+ optimizer.zero_grad
+ output = dp_model.call(input)
+ loss = criterion.call(output, target)
+ loss.backward
+ optimizer.step
+ ```
+
+ ### Models That Return Loss
+
+ If your model returns a scalar loss (e.g., GPT models returning `[logits, loss]`), use `dp_model.backward` instead of `loss.backward`:
+
+ ```ruby
+ optimizer.zero_grad
+ logits, loss = dp_model.call(input, targets: targets)
+ dp_model.backward(scale: 1.0)
+ optimizer.step
+ ```
+
+ This is necessary because gathering scalar tensors across devices breaks the autograd graph. The `backward` method calls backward on each replica's loss separately, then reduces gradients to the original module.
+
+ ### Gradient Accumulation
+
+ For gradient accumulation, scale the backward pass:
+
+ ```ruby
+ gradient_accumulation_steps = 4
+
+ (0...gradient_accumulation_steps).each do |step|
+   logits, loss = dp_model.call(input_batch, targets: targets_batch)
+   dp_model.backward(scale: 1.0 / gradient_accumulation_steps)
+ end
+ optimizer.step
+ optimizer.zero_grad
+ ```
+
+ ## API Reference
+
+ ### DataParallel
+
+ ```ruby
+ DataParallel.new(model, device_ids: nil, output_device: nil, dim: 0)
+ ```
+
+ - `model` - The module to parallelize (must be on `cuda:0`)
+ - `device_ids` - Array of GPU indices to use (default: all available)
+ - `output_device` - GPU index for output (default: first device)
+ - `dim` - Dimension to scatter inputs along (default: 0, batch dimension)
+
+ #### Methods
+
+ - `call(*inputs, **kwargs)` - Forward pass with automatic scattering/gathering
+ - `backward(scale: 1.0)` - Backward pass for loss-returning models
+ - `module` / `wrapped_module` - Access the underlying model
+ - `parameters` - Access model parameters (for optimizer)
+ - `state_dict` / `load_state_dict` - Save/load model state
+ - `train` / `eval` - Set training/evaluation mode
+
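A quick sketch of how these delegating methods fit together (an editor's illustration, not part of the gem's files). `MyModel`, the learning rate, and the checkpoint path are placeholders, the optimizer call follows torch.rb's `Torch::Optim` API, and at least two CUDA GPUs are assumed:

```ruby
require "torch_dl"

model = MyModel.new.to("cuda:0")                 # MyModel is a placeholder module
dp_model = DataParallel.new(model, device_ids: [0, 1])

# `parameters` delegates to the wrapped module, so the optimizer updates
# the original weights rather than a replica's copy
optimizer = Torch::Optim::SGD.new(dp_model.parameters, lr: 0.01)

dp_model.train   # training mode is forwarded to the wrapped module
# ... training steps ...
dp_model.eval

# `state_dict` also delegates, so the checkpoint holds plain model keys
# and can be reloaded into an unwrapped MyModel later
Torch.save(dp_model.state_dict, "checkpoint.pth")
dp_model.load_state_dict(Torch.load("checkpoint.pth"))
```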
+ ### Low-level Functions
+
+ ```ruby
+ # Split tensor across devices
+ TorchDL.scatter(tensor, ["cuda:0", "cuda:1"], dim)
+
+ # Gather tensors to a single device
+ TorchDL.gather(tensors, "cuda:0", dim)
+ ```
+
+ ## How It Works
+
+ 1. **Scatter**: Input batch is split across GPUs
+ 2. **Replicate**: Model is copied to each GPU (cached for performance)
+ 3. **Parallel Apply**: Forward pass runs on each GPU in parallel using threads
+ 4. **Gather**: Outputs are collected back to the output device
+
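The four steps above, spelled out by hand with the gem's own helpers (an editor's sketch, not part of the README); `model` and `batch` are placeholders and two CUDA GPUs are assumed. `DataParallel` runs the apply step in threads and caches the replicas; here it is written serially for clarity:

```ruby
devices = ["cuda:0", "cuda:1"]

# 1. Scatter: split the batch into per-device chunks along dim 0
chunks = TorchDL.scatter(batch, devices, 0)

# 2. Replicate: one copy of the model per device (the first entry is the original)
replicas = TorchDL::NN::Replicate.replicate(model, devices)

# 3. Parallel apply: DataParallel does this step in threads; shown serially here
outputs = replicas.zip(chunks).map { |replica, chunk| replica.call(chunk) }

# 4. Gather: concatenate the partial outputs back on the output device
result = TorchDL.gather(outputs, "cuda:0", 0)
```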
+ ## Requirements
+
+ - Ruby 3.1+
+ - torch.rb 0.17+
+ - Multiple CUDA GPUs
+
+ ## Notes
+
+ - Works with stock torch.rb from RubyGems
+ - Uses FFI (fiddle) to call the CUDA runtime directly for `synchronize`, `current_device`, and `set_device`
+ - No C extension required - pure Ruby gem
+
+ ## License
+
+ MIT
data/lib/torch_dl/cuda.rb ADDED
@@ -0,0 +1,76 @@
+ require "fiddle"
+
+ module TorchDL
+   module CUDA
+     class << self
+       def synchronize
+         return unless Torch::CUDA.available?
+
+         @cudart ||= load_cudart
+         return unless @cudart
+
+         # cudaDeviceSynchronize returns cudaError_t (0 = success)
+         @cuda_device_synchronize ||= Fiddle::Function.new(
+           @cudart["cudaDeviceSynchronize"],
+           [],
+           Fiddle::TYPE_INT
+         )
+         @cuda_device_synchronize.call
+       end
+
+       def current_device
+         return -1 unless Torch::CUDA.available?
+
+         @cudart ||= load_cudart
+         return -1 unless @cudart
+
+         @cuda_get_device ||= Fiddle::Function.new(
+           @cudart["cudaGetDevice"],
+           [Fiddle::TYPE_VOIDP],
+           Fiddle::TYPE_INT
+         )
+
+         device_ptr = Fiddle::Pointer.malloc(Fiddle::SIZEOF_INT)
+         @cuda_get_device.call(device_ptr)
+         device_ptr[0, Fiddle::SIZEOF_INT].unpack1("i")
+       end
+
+       def set_device(device_id)
+         return unless Torch::CUDA.available?
+
+         @cudart ||= load_cudart
+         return unless @cudart
+
+         @cuda_set_device ||= Fiddle::Function.new(
+           @cudart["cudaSetDevice"],
+           [Fiddle::TYPE_INT],
+           Fiddle::TYPE_INT
+         )
+         @cuda_set_device.call(device_id)
+       end
+
+       private
+
+       def load_cudart
+         # Try common CUDA runtime library paths
+         paths = [
+           "libcudart.so",
+           "libcudart.so.12",
+           "libcudart.so.11",
+           "/usr/local/cuda/lib64/libcudart.so",
+           "/usr/lib/x86_64-linux-gnu/libcudart.so"
+         ]
+
+         paths.each do |path|
+           begin
+             return Fiddle.dlopen(path)
+           rescue Fiddle::DLError
+             next
+           end
+         end
+
+         nil
+       end
+     end
+   end
+ end
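A brief sketch of exercising these helpers (editor's illustration, assuming a machine where one of the `libcudart` paths above resolves and at least two GPUs are present):

```ruby
require "torch_dl"

if Torch::CUDA.available?
  puts TorchDL::CUDA.current_device   # index reported by cudaGetDevice, e.g. 0

  TorchDL::CUDA.set_device(1)         # make cuda:1 the runtime's active device
  # ... queue work on cuda:1 ...

  TorchDL::CUDA.synchronize           # block until all queued CUDA work has finished
end
```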
data/lib/torch_dl/ext.rb ADDED
@@ -0,0 +1,20 @@
+ module TorchDL
+   # Pure Ruby implementation of scatter/gather operations
+   module Ext
+     class << self
+       # Splits a tensor across devices
+       def scatter(input, devices, dim = 0)
+         chunks = input.chunk(devices.size, dim)
+         chunks.each_with_index.map do |chunk, i|
+           chunk.to(devices[i])
+         end
+       end
+
+       # Gathers tensors from multiple devices onto a single device
+       def gather(inputs, target_device, dim = 0)
+         on_target = inputs.map { |t| t.to(target_device) }
+         Torch.cat(on_target, dim)
+       end
+     end
+   end
+ end
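A small round-trip sketch (editor's illustration, assuming two CUDA GPUs; the tensor sizes are arbitrary):

```ruby
x = Torch.randn(8, 16, device: "cuda:0")

parts = TorchDL::Ext.scatter(x, ["cuda:0", "cuda:1"])   # dim defaults to 0
parts.map(&:shape)   # => [[4, 16], [4, 16]], one chunk per device

y = TorchDL::Ext.gather(parts, "cuda:0")
y.shape              # => [8, 16], chunks concatenated back on the target device
```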
data/lib/torch_dl/nn/parallel/data_parallel.rb ADDED
@@ -0,0 +1,273 @@
+ module TorchDL
+   module NN
+     # Implements data parallelism at the module level.
+     #
+     # This container parallelizes the application of the given module by
+     # splitting the input across the specified devices by chunking in the
+     # batch dimension. In the forward pass, the module is replicated on each
+     # device, and each replica handles a portion of the input. During the
+     # backwards pass, gradients from each replica are summed into the
+     # original module.
+     #
+     # @note Backward Pass for Models Returning Loss
+     #   When your model returns a scalar loss (e.g., GPT models that return
+     #   [logits, loss]), you must use the {#backward} method instead of
+     #   calling loss.backward directly. This is because gathering scalar
+     #   tensors across devices breaks the autograd graph in torch.rb.
+     #
+     # @example Training loop with loss-returning model
+     #   dp_model = TorchDL::NN::DataParallel.new(model, device_ids: [0, 1])
+     #   optimizer.zero_grad
+     #   logits, loss = dp_model.call(input, targets: targets)
+     #   dp_model.backward(scale: 1.0 / gradient_accumulation_steps)
+     #   optimizer.step
+     #
+     # @example Basic usage (model returns output only)
+     #   model = MyModel.new.to("cuda:0")
+     #   dp_model = TorchDL::NN::DataParallel.new(model, device_ids: [0, 1])
+     #   output = dp_model.call(input)
+     #   loss = criterion.call(output, target)
+     #   loss.backward # Standard backward works when loss computed after gather
+     #
+     class DataParallel < Torch::NN::Module
+       attr_reader :module, :device_ids, :output_device, :dim
+       alias_method :wrapped_module, :module
+
+       # @param mod [Torch::NN::Module] Module to parallelize
+       # @param device_ids [Array<Integer>, nil] CUDA devices (default: all available)
+       # @param output_device [Integer, nil] Device for output (default: device_ids[0])
+       # @param dim [Integer] Dimension to scatter inputs along (default: 0)
+       def initialize(mod, device_ids: nil, output_device: nil, dim: 0)
+         super()
+         @module = mod
+         @device_ids = device_ids || (0...Torch::CUDA.device_count).to_a
+         @output_device = output_device || @device_ids.first
+         @dim = dim
+         @replica_losses = nil
+
+         if @device_ids.empty?
+           raise ArgumentError, "device_ids cannot be empty"
+         end
+
+         # Convert to device strings for internal use
+         @device_strings = @device_ids.map { |id| "cuda:#{id}" }
+         @output_device_string = "cuda:#{@output_device}"
+       end
+
+       def forward(*inputs, **kwargs)
+         # Empty input check
+         if inputs.empty?
+           return @module.call(**kwargs)
+         end
+
+         # Single GPU fast path
+         if @device_ids.size == 1
+           return @module.call(*inputs.map { |i| i.to(@device_strings.first) }, **kwargs)
+         end
+
+         # Scatter inputs across devices
+         scattered_inputs = scatter(inputs, @device_strings, @dim)
+         scattered_kwargs = scatter_kwargs(kwargs, @device_strings, @dim)
+
+         # Get or create replicas, sync weights
+         num_replicas = scattered_inputs.size
+         devices = @device_strings[0...num_replicas]
+         @replicas = get_replicas(devices)
+         sync_replica_weights
+
+         # Apply in parallel
+         outputs = parallel_apply(@replicas, scattered_inputs, scattered_kwargs)
+
+         # Ensure all CUDA operations complete before gathering
+         TorchDL::CUDA.synchronize if Torch::CUDA.available?
+
+         # Gather outputs back to output device
+         gather(outputs, @output_device_string, @dim)
+       end
+
+       # Performs backward pass on all replica losses and reduces gradients.
+       # This is needed because gather creates a new tensor that breaks the
+       # autograd connection across devices. By calling backward on each
+       # replica's loss separately, gradients flow properly.
+       #
+       # @param scale [Float] Scale factor for gradients (e.g., 1.0/gradient_accumulation_steps)
+       def backward(scale: 1.0)
+         if @replica_losses && @replica_losses.size > 1
+           # Each replica's loss contributes equally to total loss
+           # Scale by 1/N to average gradients across replicas
+           replica_scale = scale / @replica_losses.size
+
+           @replica_losses.each do |replica_loss|
+             # Scale the loss before backward to get properly scaled gradients
+             scaled_loss = replica_loss * replica_scale
+             scaled_loss.backward
+           end
+
+           @replica_losses = nil
+         end
+
+         # Reduce gradients from all replicas to the original module
+         reduce_gradients
+       end
+
+       # Reduce gradients from replicas back to the original module.
+       # Called automatically by backward(), but can be called manually if needed.
+       def reduce_gradients
+         return if @replicas.nil? || @replicas.size <= 1
+
+         # Get original model's parameters (first replica is the original)
+         original_params = @module.parameters.to_a
+
+         # Accumulate gradients from other replicas
+         @replicas[1..].each do |replica|
+           replica_params = replica.parameters.to_a
+           original_params.zip(replica_params).each do |orig, repl|
+             next unless repl.grad
+
+             # Move replica gradient to original device and add
+             if orig.grad
+               orig.grad.add!(repl.grad.to(orig.device))
+             else
+               orig.grad = repl.grad.to(orig.device).clone
+             end
+           end
+         end
+       end
+
+       # Delegate training mode to wrapped module
+       def train(mode = true)
+         super
+         @module.train(mode)
+         self
+       end
+
+       def eval
+         train(false)
+       end
+
+       # Delegate parameter access to wrapped module
+       def parameters
+         @module.parameters
+       end
+
+       def named_parameters(prefix: "", recurse: true)
+         @module.named_parameters(prefix: prefix, recurse: recurse)
+       end
+
+       def state_dict(destination: nil, prefix: "")
+         @module.state_dict(destination: destination, prefix: prefix)
+       end
+
+       def load_state_dict(state_dict, strict: true)
+         @module.load_state_dict(state_dict, strict: strict)
+       end
+
+       def extra_inspect
+         format("device_ids: %s, output_device: %s, dim: %d", @device_ids, @output_device, @dim)
+       end
+
+       private
+
+       # Scatter a single value across devices
+       def scatter_value(value, devices, dim)
+         if value.is_a?(Torch::Tensor)
+           TorchDL::Ext.scatter(value, devices, dim).map(&:contiguous)
+         else
+           devices.map { value }
+         end
+       end
+
+       def scatter(inputs, devices, dim)
+         # Scatter each input, then transpose to group by device
+         scattered = inputs.map { |input| scatter_value(input, devices, dim) }
+         scattered.first.size.times.map { |i| scattered.map { |s| s[i] } }
+       end
+
+       def scatter_kwargs(kwargs, devices, dim)
+         return devices.map { {} } if kwargs.empty?
+         scattered = kwargs.transform_values { |v| scatter_value(v, devices, dim) }
+         devices.size.times.map { |i| scattered.transform_values { |v| v[i] } }
+       end
+
+       def get_replicas(devices)
+         # Return cached replicas if they match the requested devices
+         if @replicas && @replica_devices == devices
+           return @replicas
+         end
+
+         # Create new replicas and cache them
+         @replica_devices = devices
+         Replicate.replicate(@module, devices)
+       end
+
+       def sync_replica_weights
+         return if @replicas.nil? || @replicas.size <= 1
+
+         # Sync parameters in-place to avoid memory allocation
+         original_params = @module.parameters.to_a
+         @replicas[1..].each do |replica|
+           replica.parameters.to_a.each_with_index do |param, i|
+             param.data.copy!(original_params[i].data)
+           end
+         end
+       end
+
+       def parallel_apply(replicas, inputs, kwargs_list)
+         ParallelApply.parallel_apply(replicas, inputs, kwargs_list)
+       end
+
+       def gather(outputs, target_device, dim)
+         # Handle different output types
+         first = outputs.first
+
+         case first
+         when Torch::Tensor
+           gather_tensor(outputs, target_device, dim)
+         when Array
+           # Tuple/array of tensors - gather each element
+           first.size.times.map do |i|
+             tensors = outputs.map { |o| o[i] }
+             if tensors.first.is_a?(Torch::Tensor)
+               gather_tensor(tensors, target_device, dim)
+             elsif tensors.first.nil?
+               nil
+             else
+               tensors.first # Non-tensor, just return first
+             end
+           end
+         when Hash
+           # Dict of tensors - gather each value
+           first.keys.map do |key|
+             tensors = outputs.map { |o| o[key] }
+             if tensors.first.is_a?(Torch::Tensor)
+               [key, gather_tensor(tensors, target_device, dim)]
+             else
+               [key, tensors.first]
+             end
+           end.to_h
+         else
+           # Scalar or other - return first output
+           first
+         end
+       end
+
+       def gather_tensor(tensors, target_device, dim)
+         first = tensors.first
+         # Handle scalar tensors (0-dim) - store for backward, return average for display
+         if first.dim == 0
+           # Store individual losses for backward() method
+           @replica_losses = tensors
+
+           # Return mean of losses (moved to same device) for logging/display
+           # Note: This returned tensor should NOT be used for backward() - use backward() instead
+           sum = tensors.reduce(Torch.tensor(0.0, device: target_device)) do |acc, t|
+             acc + t.to(target_device).detach
+           end
+           sum / tensors.size
+         else
+           TorchDL::Ext.gather(tensors, target_device, dim)
+         end
+       end
+     end
+   end
+ end
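A sketch of a forward pass where a keyword tensor rides along with the positional input (editor's illustration; `GPTModel`, `tokens`, `targets`, and `optimizer` are placeholders, and two CUDA GPUs are assumed):

```ruby
# Placeholder model whose forward returns [logits, loss], GPT-style
model = GPTModel.new.to("cuda:0")
dp_model = TorchDL::NN::DataParallel.new(model, device_ids: [0, 1])

# Positional inputs and tensor kwargs are both chunked along dim 0, so each
# replica sees matching slices of `tokens` and `targets`
logits, loss = dp_model.call(tokens, targets: targets)

# `loss` is the detached per-replica mean, useful only for logging; gradients
# come from backward, which replays each replica's stored loss
dp_model.backward(scale: 1.0)
optimizer.step
```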
data/lib/torch_dl/nn/parallel/parallel_apply.rb ADDED
@@ -0,0 +1,60 @@
+ module TorchDL
+   module NN
+     module ParallelApply
+       class << self
+         # Applies modules to inputs in parallel across devices.
+         #
+         # @param modules [Array<Torch::NN::Module>] List of module replicas
+         # @param inputs [Array] List of inputs, one per module
+         # @param kwargs_list [Array<Hash>, nil] Optional list of kwargs, one per module
+         # @return [Array] List of outputs, one per module
+         def parallel_apply(modules, inputs, kwargs_list = nil)
+           kwargs_list ||= modules.map { {} }
+
+           unless modules.size == inputs.size && modules.size == kwargs_list.size
+             raise ArgumentError, "modules, inputs, and kwargs_list must have the same length"
+           end
+
+           # Single module - no parallelism needed
+           if modules.size == 1
+             return [apply_module(modules[0], inputs[0], kwargs_list[0])]
+           end
+
+           parallel_apply_threads(modules, inputs, kwargs_list)
+         end
+
+         private
+
+         def parallel_apply_threads(modules, inputs, kwargs_list)
+           results = Array.new(modules.size)
+           errors = Array.new(modules.size)
+
+           threads = modules.each_with_index.map do |mod, i|
+             Thread.new(mod, inputs[i], kwargs_list[i], i) do |m, inp, kw, idx|
+               results[idx] = apply_module(m, inp, kw)
+             rescue => e
+               errors[idx] = e
+             end
+           end
+
+           threads.each(&:join)
+
+           # Re-raise first error if any
+           errors.each_with_index do |err, i|
+             raise err if err
+           end
+
+           results
+         end
+
+         def apply_module(mod, input, kwargs)
+           if input.is_a?(Array)
+             mod.call(*input, **kwargs)
+           else
+             mod.call(input, **kwargs)
+           end
+         end
+       end
+     end
+   end
+ end
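A sketch of calling the helper directly (editor's illustration; `model` and `batch` are placeholders, two CUDA GPUs are assumed, and the replicas and chunks come from the gem's Replicate and scatter helpers):

```ruby
devices  = ["cuda:0", "cuda:1"]
replicas = TorchDL::NN::Replicate.replicate(model, devices)
chunks   = TorchDL.scatter(batch, devices)

# One thread per replica; an exception raised in any thread is re-raised here
outputs = TorchDL::NN::ParallelApply.parallel_apply(replicas, chunks)
outputs.size   # => 2, one output per replica, each still on its own device
```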
data/lib/torch_dl/nn/parallel/replicate.rb ADDED
@@ -0,0 +1,178 @@
+ module TorchDL
+   module NN
+     module Replicate
+       class << self
+         # Replicates a module on multiple devices.
+         #
+         # @param network [Torch::NN::Module] The module to replicate
+         # @param devices [Array<String>] List of device strings (e.g., ["cuda:0", "cuda:1"])
+         # @return [Array<Torch::NN::Module>] List of module replicas, one per device
+         #
+         # Note: The first device uses the original network (not a copy) to ensure
+         # gradients flow back to the original parameters during backward pass.
+         def replicate(network, devices)
+           devices = devices.map { |d| d.is_a?(String) ? d : "cuda:#{d}" }
+
+           # Single device - just return the network (already on correct device)
+           return [network] if devices.size == 1
+
+           # Get the state dict once for creating replicas
+           state_dict = network.state_dict
+
+           # Create replicas - first device uses original network for gradient flow
+           devices.each_with_index.map do |device, idx|
+             if idx == 0
+               # First device: use the original network to maintain gradient connection
+               network
+             else
+               # Other devices: create independent replicas
+               replica = deep_copy_module(network)
+
+               # Copy state dict tensors to the target device
+               # Filter to only include keys that exist in the replica
+               replica_keys = replica.state_dict.keys
+               device_state = state_dict.select { |k, _| replica_keys.include?(k) }
+                                        .transform_values { |t| t.to(device) }
+               replica.load_state_dict(device_state)
+               replica.to(device)
+               replica.train(network.instance_variable_get(:@training))
+               replica
+             end
+           end
+         end
+
+         private
+
+         # Creates a deep copy of a module structure
+         def deep_copy_module(mod)
+           # Check for custom replication hook first
+           return mod.class._replicate(mod) if mod.class.respond_to?(:_replicate)
+
+           # Handle container modules
+           return copy_sequential(mod) if mod.is_a?(Torch::NN::Sequential)
+           return copy_module_list(mod) if mod.is_a?(Torch::NN::ModuleList)
+
+           # Try built-in module copiers
+           copier = module_copiers[mod.class]
+           return copier.call(mod) if copier
+
+           # Fallback to generic copy for custom modules
+           copy_custom_module(mod)
+         end
+
+         # Lazily initialized registry of module copiers
+         def module_copiers
+           @module_copiers ||= {
+             Torch::NN::Linear => ->(m) { Torch::NN::Linear.new(m.in_features, m.out_features, bias: has_bias?(m)) },
+             Torch::NN::Conv1d => ->(m) { copy_conv(Torch::NN::Conv1d, m) },
+             Torch::NN::Conv2d => ->(m) { copy_conv(Torch::NN::Conv2d, m) },
+             Torch::NN::Conv3d => ->(m) { copy_conv(Torch::NN::Conv3d, m) },
+             Torch::NN::BatchNorm1d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm1d, m) },
+             Torch::NN::BatchNorm2d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm2d, m) },
+             Torch::NN::BatchNorm3d => ->(m) { copy_batch_norm(Torch::NN::BatchNorm3d, m) },
+             Torch::NN::LayerNorm => ->(m) {
+               Torch::NN::LayerNorm.new(m.instance_variable_get(:@normalized_shape),
+                                        eps: m.instance_variable_get(:@eps),
+                                        elementwise_affine: m.instance_variable_get(:@elementwise_affine))
+             },
+             Torch::NN::Embedding => ->(m) {
+               Torch::NN::Embedding.new(m.instance_variable_get(:@num_embeddings),
+                                        m.instance_variable_get(:@embedding_dim),
+                                        padding_idx: m.instance_variable_get(:@padding_idx))
+             },
+             Torch::NN::Dropout => ->(m) { Torch::NN::Dropout.new(p: m.instance_variable_get(:@p)) },
+             Torch::NN::Dropout2d => ->(m) { Torch::NN::Dropout2d.new(p: m.instance_variable_get(:@p)) },
+             Torch::NN::Dropout3d => ->(m) { Torch::NN::Dropout3d.new(p: m.instance_variable_get(:@p)) },
+             Torch::NN::LSTM => ->(m) { copy_rnn(Torch::NN::LSTM, m) },
+             Torch::NN::GRU => ->(m) { copy_rnn(Torch::NN::GRU, m) },
+             Torch::NN::RNN => ->(m) { copy_rnn(Torch::NN::RNN, m) },
+             Torch::NN::ReLU => ->(_) { Torch::NN::ReLU.new },
+             Torch::NN::GELU => ->(_) { Torch::NN::GELU.new },
+             Torch::NN::Tanh => ->(_) { Torch::NN::Tanh.new },
+             Torch::NN::Sigmoid => ->(_) { Torch::NN::Sigmoid.new },
+             Torch::NN::Identity => ->(_) { Torch::NN::Identity.new },
+             Torch::NN::Softmax => ->(m) { Torch::NN::Softmax.new(dim: m.instance_variable_get(:@dim)) },
+             Torch::NN::LogSoftmax => ->(m) { Torch::NN::LogSoftmax.new(dim: m.instance_variable_get(:@dim)) },
+           }
+         end
+
+         def has_bias?(mod)
+           !mod.instance_variable_get(:@bias).nil?
+         end
+
+         def copy_conv(klass, mod)
+           klass.new(mod.in_channels, mod.out_channels, mod.kernel_size,
+                     stride: mod.stride, padding: mod.padding, dilation: mod.dilation,
+                     groups: mod.groups, bias: has_bias?(mod), padding_mode: mod.padding_mode)
+         end
+
+         def copy_batch_norm(klass, mod)
+           klass.new(mod.num_features, eps: mod.eps, momentum: mod.momentum,
+                     affine: mod.affine, track_running_stats: mod.track_running_stats)
+         end
+
+         def copy_rnn(klass, mod)
+           klass.new(mod.input_size, mod.hidden_size, num_layers: mod.num_layers,
+                     bias: mod.bias, batch_first: mod.batch_first,
+                     dropout: mod.dropout, bidirectional: mod.bidirectional)
+         end
+
+         def copy_sequential(mod)
+           Torch::NN::Sequential.new(*mod.children.map { |child| deep_copy_module(child) })
+         end
+
+         def copy_module_list(mod)
+           Torch::NN::ModuleList.new(mod.map { |child| deep_copy_module(child) })
+         end
+
+         def copy_custom_module(mod)
+           klass = mod.class
+           children = mod.named_children.to_h
+
+           if children.any?
+             # Module has submodules - create structural copy
+             replica = klass.allocate
+             replica.send(:initialize_module_state)
+             copy_instance_state(mod, replica)
+             children.each do |name, child|
+               replica.instance_variable_set("@#{name}", deep_copy_module(child))
+             end
+             replica
+           else
+             # Leaf module - try clone
+             mod.clone
+           end
+         rescue => e
+           raise ArgumentError, "Cannot replicate #{klass}. " \
+                                "Implement #{klass}._replicate(mod) class method. Error: #{e.message}"
+         end
+
+         def copy_instance_state(src, dst)
+           src.instance_variables.each do |ivar|
+             next if %i[@parameters @buffers @modules @training @non_persistent_buffers_set].include?(ivar)
+             val = src.instance_variable_get(ivar)
+             next if val.is_a?(Torch::Tensor) || val.is_a?(Torch::NN::Module)
+             dst.instance_variable_set(ivar, val)
+           end
+         end
+       end
+     end
+   end
+ end
+
+ # Add helper method to Module for initializing state (if not already defined)
+ module Torch
+   module NN
+     class Module
+       private
+
+       def initialize_module_state
+         @training = true
+         @parameters = {}
+         @buffers = {}
+         @modules = {}
+         @non_persistent_buffers_set = Set.new
+       end unless private_method_defined?(:initialize_module_state)
+     end
+   end
+ end
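The fallback error above points custom modules at a `_replicate` class-method hook. A minimal sketch of what such a hook could look like (editor's illustration; `MyBlock` and its fields are hypothetical, and replica weights are synced from the state dict after copying):

```ruby
class MyBlock < Torch::NN::Module
  def initialize(dim)
    super()
    @dim = dim
    @proj = Torch::NN::Linear.new(dim, dim)
  end

  def forward(x)
    @proj.call(x)
  end

  # Hook consumed by Replicate.deep_copy_module: return a structurally
  # identical, freshly initialized copy; weights are copied in afterwards
  def self._replicate(mod)
    new(mod.instance_variable_get(:@dim))
  end
end
```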
data/lib/torch_dl/version.rb ADDED
@@ -0,0 +1,3 @@
+ module TorchDL
+   VERSION = "0.1.0"
+ end
data/lib/torch_dl.rb ADDED
@@ -0,0 +1,24 @@
+ require "torch"
+ require_relative "torch_dl/version"
+ require_relative "torch_dl/cuda"
+ require_relative "torch_dl/ext"
+ require_relative "torch_dl/nn/parallel/replicate"
+ require_relative "torch_dl/nn/parallel/parallel_apply"
+ require_relative "torch_dl/nn/parallel/data_parallel"
+
+ module TorchDL
+   class Error < StandardError; end
+
+   class << self
+     def scatter(input, devices, dim = 0)
+       TorchDL::Ext.scatter(input, devices, dim)
+     end
+
+     def gather(inputs, target_device, dim = 0)
+       TorchDL::Ext.gather(inputs, target_device, dim)
+     end
+   end
+ end
+
+ # Convenience alias
+ DataParallel = TorchDL::NN::DataParallel
metadata ADDED
@@ -0,0 +1,84 @@
+ --- !ruby/object:Gem::Specification
+ name: torch-dl
+ version: !ruby/object:Gem::Version
+   version: 0.1.0
+ platform: ruby
+ authors:
+ - Chris Hasinski
+ bindir: bin
+ cert_chain: []
+ date: 1980-01-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: torch-rb
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0.17'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0.17'
+ - !ruby/object:Gem::Dependency
+   name: fiddle
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ description: DataParallel and distributed training utilities for torch.rb. Split batches
+   across multiple GPUs automatically.
+ email:
+ - krzysztof.hasinski@gmail.com
+ executables: []
+ extensions: []
+ extra_rdoc_files: []
+ files:
+ - CHANGELOG.md
+ - LICENSE.txt
+ - README.md
+ - lib/torch_dl.rb
+ - lib/torch_dl/cuda.rb
+ - lib/torch_dl/ext.rb
+ - lib/torch_dl/nn/parallel/data_parallel.rb
+ - lib/torch_dl/nn/parallel/parallel_apply.rb
+ - lib/torch_dl/nn/parallel/replicate.rb
+ - lib/torch_dl/version.rb
+ homepage: https://github.com/khasinski/torch-rb-dl
+ licenses:
+ - MIT
+ metadata:
+   allowed_push_host: https://rubygems.org
+   homepage_uri: https://github.com/khasinski/torch-rb-dl
+   source_code_uri: https://github.com/khasinski/torch-rb-dl
+   changelog_uri: https://github.com/khasinski/torch-rb-dl/blob/master/CHANGELOG.md
+   rubygems_mfa_required: 'true'
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '3.1'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 4.0.2
+ specification_version: 4
+ summary: Multi-GPU training for torch.rb
+ test_files: []