ignis-collective 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +7 -0
  3. data/lib/ignis-collective.rb +9 -0
  4. data/lib/nvruby/collective/algorithms/double_binary_tree.rb +364 -0
  5. data/lib/nvruby/collective/algorithms/pipeliner.rb +222 -0
  6. data/lib/nvruby/collective/algorithms/reduction_ops.rb +168 -0
  7. data/lib/nvruby/collective/algorithms/ring.rb +421 -0
  8. data/lib/nvruby/collective/algorithms/topology_router.rb +284 -0
  9. data/lib/nvruby/collective/algorithms/tree.rb +291 -0
  10. data/lib/nvruby/collective/array_ops.rb +240 -0
  11. data/lib/nvruby/collective/communicator.rb +633 -0
  12. data/lib/nvruby/collective/communicator_healer.rb +276 -0
  13. data/lib/nvruby/collective/device_manager.rb +216 -0
  14. data/lib/nvruby/collective/dynamic_optimizer.rb +308 -0
  15. data/lib/nvruby/collective/health_monitor.rb +333 -0
  16. data/lib/nvruby/collective/net/nd_adapter.rb +450 -0
  17. data/lib/nvruby/collective/net/nd_bindings.rb +166 -0
  18. data/lib/nvruby/collective/net/rdma_transport.rb +366 -0
  19. data/lib/nvruby/collective/nvarray_adapter.rb +230 -0
  20. data/lib/nvruby/collective/p2p_bindings.rb +121 -0
  21. data/lib/nvruby/collective/resilient_transport.rb +296 -0
  22. data/lib/nvruby/collective/topology.rb +347 -0
  23. data/lib/nvruby/collective/transport/base.rb +138 -0
  24. data/lib/nvruby/collective/transport/host_staged_transport.rb +217 -0
  25. data/lib/nvruby/collective/transport/ipc_transport.rb +187 -0
  26. data/lib/nvruby/collective/transport/p2p_transport.rb +157 -0
  27. data/lib/nvruby/collective/transport/rdma_transports.rb +213 -0
  28. data/lib/nvruby/collective/transport/rio_transport.rb +405 -0
  29. data/lib/nvruby/collective/transport/tcp_transport.rb +290 -0
  30. data/lib/nvruby/collective/transport/vmm_ipc_structs.rb +189 -0
  31. data/lib/nvruby/collective/transport/vmm_ipc_transport.rb +266 -0
  32. data/lib/nvruby/collective/transport_selector.rb +200 -0
  33. data/lib/nvruby/collective/vmm_bindings.rb +212 -0
  34. data/lib/nvruby/collective.rb +156 -0
  35. metadata +92 -0
@@ -0,0 +1,168 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ignis"
4
+
5
+ module Ignis
6
+ module Collective
7
+ module Algorithms
8
# Reduction operations for collective primitives.
#
# Every entry point combines two buffers element by element during
# reduce/allreduce: result[i] = op(a[i], b[i]) for `count` elements.
#
# float32 buffers go through the fused JIT elementwise kernels; every other
# dtype is staged through host memory and reduced in Ruby.
module ReductionOps
  # Valid reduction operations.
  OPS = %i[sum prod min max avg].freeze

  # Dtypes the host path cannot decode: FFI has no 16-bit float accessors,
  # so these are rejected instead of being misread as float32.
  HOST_UNSUPPORTED_DTYPES = %i[float16 bfloat16].freeze
  private_constant :HOST_UNSUPPORTED_DTYPES

  # Sum all elements (a + b)
  def self.sum(a, b, result, count, dtype, stream = nil)
    execute(:sum, a, b, result, count, dtype, stream)
  end

  # Multiply all elements (a * b)
  def self.prod(a, b, result, count, dtype, stream = nil)
    execute(:prod, a, b, result, count, dtype, stream)
  end

  # Element-wise minimum
  def self.min(a, b, result, count, dtype, stream = nil)
    execute(:min, a, b, result, count, dtype, stream)
  end

  # Element-wise maximum
  def self.max(a, b, result, count, dtype, stream = nil)
    execute(:max, a, b, result, count, dtype, stream)
  end

  # Average step. NOTE: averaging is "sum across all ranks, then divide by
  # the participant count ONCE at the end". The per-pair reduction step is
  # therefore a plain sum; the caller performs the final divide-by-N.
  # `_n_participants` is accepted for interface compatibility and ignored.
  def self.avg(a, b, result, count, dtype, stream = nil, _n_participants = nil)
    execute(:sum, a, b, result, count, dtype, stream)
  end

  # Execute reduction operation by name: result = op(a, b), elementwise.
  #
  # @param op [Symbol] :sum, :prod, :min, :max, or :avg (avg == sum per step)
  # @param a [FFI::Pointer] First operand (device pointer)
  # @param b [FFI::Pointer] Second operand (device pointer)
  # @param result [FFI::Pointer] Result buffer (may alias a for in-place)
  # @param count [Integer] Element count
  # @param dtype [Symbol] Data type
  # @param stream [FFI::Pointer, nil] CUDA stream. Accepted for interface
  #   compatibility; neither execution path currently uses it.
  # @raise [ArgumentError] for an unknown op, a negative count, or a dtype
  #   the host path cannot decode (:float16, :bfloat16)
  # @return [void]
  def self.execute(op, a, b, result, count, dtype, stream = nil)
    reduce = (op == :avg ? :sum : op)
    raise ArgumentError, "Unknown reduction operation: #{op}" unless %i[sum prod min max].include?(reduce)
    raise ArgumentError, "count must be non-negative, got #{count}" if count.negative?
    return if count.zero?

    if dtype == :float32
      gpu_elementwise(reduce, a, b, result, count)
    else
      # Non-fp32 dtypes use the slower host path: the fused JIT kernels are
      # typed `float`, so reinterpreting other buffers through them would
      # be wrong.
      host_elementwise_fallback(host_op(reduce), a, b, result, count, dtype)
    end
  end

  class << self
    private

    # GPU elementwise reduction for float32 via the fused JIT kernels.
    # Launches one thread per element in blocks of 256, then synchronizes.
    def gpu_elementwise(op, a, b, result, count)
      kernel = case op
               when :sum then Ignis::JIT::Kernels::Elementwise.add_forward
               when :prod then Ignis::JIT::Kernels::Elementwise.mul_forward
               when :min then Ignis::JIT::Kernels::Elementwise.min_forward
               when :max then Ignis::JIT::Kernels::Elementwise.max_forward
               end
      kernel.launch(grid: [(count + 255) / 256], block: [256], args: [a, b, result, count])
      Ignis.synchronize
    end

    # Map a reduction op to the host-fallback op name. :prod must become
    # :mul: the host path previously received :prod unchanged, matched no
    # branch and produced nil for every element.
    def host_op(op)
      case op
      when :sum then :add
      when :prod then :mul
      else op
      end
    end

    # Combine two host-side scalars. Accepts both the reduction names
    # (:sum/:prod) and the host names (:add/:mul).
    def combine(op, lhs, rhs)
      case op
      when :add, :sum then lhs + rhs
      when :mul, :prod then lhs * rhs
      when :min then [lhs, rhs].min
      when :max then [lhs, rhs].max
      else raise ArgumentError, "Unknown host reduction operation: #{op}"
      end
    end

    # Reject dtypes the host path cannot decode before any memory is
    # allocated or copied.
    def ensure_host_dtype!(dtype)
      return unless HOST_UNSUPPORTED_DTYPES.include?(dtype)

      raise ArgumentError, "Host reduction does not support dtype #{dtype}"
    end

    # cudaMemcpy with its status checked instead of silently discarded.
    def checked_memcpy(dst, src, n_bytes, kind, what)
      status = CUDA::RuntimeAPI.cudaMemcpy(dst, src, n_bytes, kind)
      CUDA::RuntimeAPI.check_status!(status, what)
    end

    # Host-side elementwise reduction: stage both operands to host memory,
    # combine them in Ruby, and copy the result back to the device.
    def host_elementwise_fallback(op, a, b, result, count, dtype)
      ensure_host_dtype!(dtype)

      elem_size = dtype_size(dtype)
      total_size = count * elem_size

      # Allocate host buffers
      host_a = FFI::MemoryPointer.new(:uint8, total_size)
      host_b = FFI::MemoryPointer.new(:uint8, total_size)
      host_result = FFI::MemoryPointer.new(:uint8, total_size)

      # Copy from device to host
      d2h = CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_HOST
      checked_memcpy(host_a, a, total_size, d2h, "Host reduction: copy operand a to host")
      checked_memcpy(host_b, b, total_size, d2h, "Host reduction: copy operand b to host")

      # Perform operation
      count.times do |i|
        offset = i * elem_size
        val_a = read_element(host_a, offset, dtype)
        val_b = read_element(host_b, offset, dtype)
        write_element(host_result, offset, combine(op, val_a, val_b), dtype)
      end

      # Copy back to device
      checked_memcpy(result, host_result, total_size,
                     CUDA::RuntimeAPI::MEMCPY_HOST_TO_DEVICE,
                     "Host reduction: copy result to device")
    end

    # Get size of dtype in bytes. Unknown dtypes are sized as float32.
    def dtype_size(dtype)
      case dtype
      when :float32, :int32, :uint32 then 4
      when :float64, :int64, :uint64 then 8
      when :float16, :bfloat16, :int16, :uint16 then 2
      when :int8, :uint8 then 1
      else 4 # Default to float32
      end
    end

    # Convert Ruby dtype to C type string
    def dtype_to_ctype(dtype)
      case dtype
      when :float32 then "float"
      when :float64 then "double"
      when :float16 then "__half"
      when :int32 then "int"
      when :int64 then "long long"
      else "float"
      end
    end

    # Read one element from a host buffer using the accessor matching the
    # dtype's width and signedness. Unknown dtypes are read as float32
    # (consistent with dtype_size's default).
    def read_element(buffer, offset, dtype)
      case dtype
      when :float32 then buffer.get_float32(offset)
      when :float64 then buffer.get_float64(offset)
      when :int8 then buffer.get_int8(offset)
      when :uint8 then buffer.get_uint8(offset)
      when :int16 then buffer.get_int16(offset)
      when :uint16 then buffer.get_uint16(offset)
      when :int32 then buffer.get_int32(offset)
      when :uint32 then buffer.get_uint32(offset)
      when :int64 then buffer.get_int64(offset)
      when :uint64 then buffer.get_uint64(offset)
      when :float16, :bfloat16
        raise ArgumentError, "Host reduction does not support dtype #{dtype}"
      else buffer.get_float32(offset)
      end
    end

    # Write one element to a host buffer. Integer dtypes truncate the value
    # with to_i; unknown dtypes are written as float32.
    def write_element(buffer, offset, value, dtype)
      case dtype
      when :float32 then buffer.put_float32(offset, value)
      when :float64 then buffer.put_float64(offset, value)
      when :int8 then buffer.put_int8(offset, value.to_i)
      when :uint8 then buffer.put_uint8(offset, value.to_i)
      when :int16 then buffer.put_int16(offset, value.to_i)
      when :uint16 then buffer.put_uint16(offset, value.to_i)
      when :int32 then buffer.put_int32(offset, value.to_i)
      when :uint32 then buffer.put_uint32(offset, value.to_i)
      when :int64 then buffer.put_int64(offset, value.to_i)
      when :uint64 then buffer.put_uint64(offset, value.to_i)
      when :float16, :bfloat16
        raise ArgumentError, "Host reduction does not support dtype #{dtype}"
      else buffer.put_float32(offset, value)
      end
    end
  end
end
166
+ end
167
+ end
168
+ end
@@ -0,0 +1,421 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "reduction_ops"
4
+
5
+ module Ignis
6
+ module Collective
7
+ module Algorithms
8
# Ring collective algorithms: AllReduce, AllGather and ReduceScatter.
#
# The Ring algorithm performs AllReduce in 2*(N-1) steps where N is number of GPUs:
# 1. Scatter-Reduce phase: N-1 steps, each GPU sends a chunk and receives+reduces another
# 2. AllGather phase: N-1 steps, each GPU sends its reduced chunk and receives another
#
# Bandwidth complexity: 2 * (N-1)/N * data_size (asymptotically optimal)
# Latency complexity: 2 * (N-1) * alpha (linear in N)
#
# Best for: Large messages (>1MB) where bandwidth dominates latency
class Ring
  # Chunk metadata for pipelining
  ChunkInfo = Struct.new(:chunk_id, :offset, :size, :src_rank, :dst_rank, keyword_init: true)

  # @return [Array<Integer>] Ring order (GPU IDs in ring sequence)
  attr_reader :ring_order

  # @return [Integer] Number of participants
  attr_reader :n_gpus

  # @return [TransportSelector] Transport selector for GPU pairs
  attr_reader :transport_selector

  # @param ring_order [Array<Integer>] GPU IDs in ring order
  # @param transport_selector [TransportSelector] Transport selector
  def initialize(ring_order:, transport_selector:)
    @ring_order = ring_order.dup.freeze
    @n_gpus = ring_order.size
    @transport_selector = transport_selector
    @chunk_counts = {} # not read anywhere in this class
  end

  # Perform Ring AllReduce (in place on `buffers`).
  #
  # @param buffers [Array<FFI::Pointer>] Device buffers (one per GPU in ring_order)
  # @param sizes [Array<Integer>] Buffer sizes in bytes (must all be equal)
  # @param dtype [Symbol] Data type (:float32, :float64, etc.)
  # @param op [Symbol] Reduction operation (:sum, :prod, :min, :max)
  # @param streams [Array<CUDA::Stream, FFI::Pointer>] CUDA streams per GPU
  # @raise [ArgumentError] on mismatched argument counts or sizes
  # @return [void]
  def all_reduce(buffers:, sizes:, dtype:, op:, streams:)
    validate_inputs!(buffers, sizes, streams)

    return if @n_gpus == 1 # Single GPU - no-op

    # Even element-wise chunk layout (handles non-divisible element counts
    # without overrunning the buffer on the last chunk).
    total_size = sizes[0]
    layout = chunk_layout(total_size, dtype_elem_size(dtype))

    with_recv_buffers(total_size) do |recv_buffers|
      # Phase 1: Scatter-Reduce
      scatter_reduce!(buffers, recv_buffers, layout, dtype, op, streams)

      # Phase 2: AllGather
      all_gather!(buffers, recv_buffers, layout, streams)
    end
  end

  # Perform only the scatter-reduce phase (for testing/benchmarking).
  # Takes the same arguments, and applies the same validation, as all_reduce.
  #
  # @return [void]
  def scatter_reduce_only(buffers:, sizes:, dtype:, op:, streams:)
    validate_inputs!(buffers, sizes, streams)

    return if @n_gpus == 1

    layout = chunk_layout(sizes[0], dtype_elem_size(dtype))

    with_recv_buffers(sizes[0]) do |recv_buffers|
      scatter_reduce!(buffers, recv_buffers, layout, dtype, op, streams)
    end
  end

  # Perform Ring AllGather - gather all chunks to all GPUs
  #
  # Each GPU starts with a chunk of data. After AllGather, each GPU
  # has all chunks from all GPUs concatenated (chunk k at byte offset
  # k * chunk_size of its receive buffer).
  #
  # @param send_buffers [Array<FFI::Pointer>] Source buffers (one per GPU, each with local chunk)
  # @param recv_buffers [Array<FFI::Pointer>] Dest buffers (one per GPU, sized for all chunks)
  # @param send_sizes [Array<Integer>] Size in bytes of each GPU's local chunk (must all be equal)
  # @param streams [Array<CUDA::Stream, FFI::Pointer>] CUDA streams
  # @raise [ArgumentError] on mismatched argument counts or unequal sizes
  # @return [void]
  def all_gather_standalone(send_buffers:, recv_buffers:, send_sizes:, streams:)
    validate_inputs_gather!(send_buffers, recv_buffers, streams)
    validate_gather_sizes!(send_sizes)

    chunk_size = send_sizes[0]
    return if chunk_size.zero?

    # Copy each GPU's local chunk to its position in the result buffer.
    # This also runs for a single GPU: its result buffer still has to
    # receive the local chunk.
    @n_gpus.times do |rank|
      gpu_id = @ring_order[rank]
      use_device!(gpu_id)

      dst_ptr = ptr_offset(recv_buffers[rank], rank * chunk_size)
      # In-place call (send buffer already sits at its slot): nothing to copy.
      next if dst_ptr.address == send_buffers[rank].address

      status = CUDA::RuntimeAPI.cudaMemcpyAsync(
        dst_ptr,
        send_buffers[rank],
        chunk_size,
        CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_DEVICE,
        get_stream_ptr(streams[rank])
      )
      CUDA::RuntimeAPI.check_status!(status, "AllGather local copy GPU #{gpu_id}")
    end

    synchronize_all_streams!(streams)

    # N-1 ring steps to propagate all chunks
    (@n_gpus - 1).times do |step|
      @n_gpus.times do |rank|
        gpu_id = @ring_order[rank]

        # Send the chunk received in the previous step (own chunk at step 0)
        send_offset = ((rank - step) % @n_gpus) * chunk_size

        next_rank = (rank + 1) % @n_gpus
        next_gpu = @ring_order[next_rank]

        transport = @transport_selector.select_transport(gpu_id, next_gpu)
        stream_ptr = get_stream_ptr(streams[rank])

        src_ptr = ptr_offset(recv_buffers[rank], send_offset)
        dst_ptr = ptr_offset(recv_buffers[next_rank], send_offset)
        move!(transport, dst_ptr, src_ptr, chunk_size, stream_ptr)
      end

      synchronize_all_streams!(streams)
    end

    nil
  end

  # Perform Ring ReduceScatter - reduce and scatter result
  #
  # Each GPU starts with a full buffer. After ReduceScatter, GPU[rank]
  # holds reduced chunk[(rank + 1) % N] in its result buffer. `buffers`
  # are used as scratch and are modified by the scatter-reduce phase.
  #
  # @param buffers [Array<FFI::Pointer>] Buffers (one per GPU, full size)
  # @param result_buffers [Array<FFI::Pointer>] Result buffers (one per GPU, chunk size)
  # @param sizes [Array<Integer>] Full buffer sizes in bytes (must all be equal)
  # @param dtype [Symbol] Data type
  # @param op [Symbol] Reduction operation
  # @param streams [Array<CUDA::Stream, FFI::Pointer>] CUDA streams
  # @raise [ArgumentError] on mismatched argument counts or sizes
  # @return [void]
  def reduce_scatter(buffers:, result_buffers:, sizes:, dtype:, op:, streams:)
    validate_inputs!(buffers, sizes, streams)
    unless result_buffers.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} result buffers, got #{result_buffers.size}"
    end

    total_size = sizes[0]
    layout = chunk_layout(total_size, dtype_elem_size(dtype))

    # A single GPU has nothing to reduce, but its result buffer must still
    # receive the (only) chunk, so only the ring phase is skipped.
    if @n_gpus > 1
      with_recv_buffers(total_size) do |temp_buffers|
        # Scatter-Reduce phase only (same as first half of AllReduce)
        scatter_reduce!(buffers, temp_buffers, layout, dtype, op, streams)
      end
    end

    copy_reduced_chunks!(buffers, result_buffers, layout, streams)
    nil
  end

  # Largest chunk size in bytes (ceil division) — used by callers only for
  # allocating result buffers big enough to hold any single chunk.
  # @param total_size [Integer] Total buffer size in bytes
  # @return [Integer] Max chunk size in bytes
  def calculate_chunk_size(total_size)
    (total_size + @n_gpus - 1) / @n_gpus
  end

  # Even element-wise chunk layout. Distributes elements as evenly as
  # possible across the N chunks (the first `total_elems % N` chunks get
  # one extra element) so they tile the whole buffer with no overrun.
  #
  # @param total_bytes [Integer] Buffer size in bytes; must be a multiple
  #   of elem_size, otherwise trailing bytes would never be reduced
  # @param elem_size [Integer] Element size in bytes (positive)
  # @raise [ArgumentError] if elem_size is not positive or total_bytes is
  #   not a multiple of elem_size
  # @return [Array<Array(Integer,Integer,Integer,Integer)>]
  #   one [offset_bytes, n_bytes, offset_elems, n_elems] per chunk
  def chunk_layout(total_bytes, elem_size)
    raise ArgumentError, "elem_size must be positive, got #{elem_size}" unless elem_size.positive?

    unless (total_bytes % elem_size).zero?
      raise ArgumentError,
            "Buffer size #{total_bytes} is not a multiple of element size #{elem_size}"
    end

    base, rem = (total_bytes / elem_size).divmod(@n_gpus)
    off_e = 0
    Array.new(@n_gpus) do |k|
      n_e = base + (k < rem ? 1 : 0)
      entry = [off_e * elem_size, n_e * elem_size, off_e, n_e]
      off_e += n_e
      entry
    end
  end

  private

  # Allocate per-GPU temporary buffers, yield them, and always free them.
  def with_recv_buffers(size)
    recv_buffers = allocate_recv_buffers(size)
    begin
      yield recv_buffers
    ensure
      free_recv_buffers(recv_buffers)
    end
    nil
  end

  # Move bytes via the selected transport, failing LOUDLY rather than
  # silently skipping (a non-P2P transport without copy_async would
  # otherwise drop the chunk and corrupt the reduction with no error).
  def move!(transport, dst, src, n_bytes, stream_ptr)
    if transport.respond_to?(:copy_async)
      transport.copy_async(dst, src, n_bytes, stream_ptr)
    else
      raise NotImplementedError,
            "Transport #{transport.class} has no copy_async; non-P2P ring " \
            "movement is not wired yet (refusing to silently drop data)"
    end
  end

  # Check that there is one buffer, size and stream per GPU and that all
  # sizes agree.
  def validate_inputs!(buffers, sizes, streams)
    unless buffers.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} buffers, got #{buffers.size}"
    end

    unless sizes.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} sizes, got #{sizes.size}"
    end

    unless streams.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} streams, got #{streams.size}"
    end

    # All sizes should be equal for basic ring
    unless sizes.uniq.size == 1
      raise ArgumentError, "All buffer sizes must be equal for Ring AllReduce"
    end
  end

  # Check that there is one send buffer, receive buffer and stream per GPU.
  def validate_inputs_gather!(send_buffers, recv_buffers, streams)
    unless send_buffers.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} send buffers, got #{send_buffers.size}"
    end

    unless recv_buffers.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} recv buffers, got #{recv_buffers.size}"
    end

    unless streams.size == @n_gpus
      raise ArgumentError, "Expected #{@n_gpus} streams, got #{streams.size}"
    end
  end

  # AllGather uses send_sizes[0] as the chunk size for every rank, so the
  # list must be non-empty and uniform.
  def validate_gather_sizes!(send_sizes)
    raise ArgumentError, "send_sizes must not be empty" if send_sizes.nil? || send_sizes.empty?
    return if send_sizes.uniq.size == 1

    raise ArgumentError, "All send sizes must be equal for Ring AllGather"
  end

  # Allocate one temporary receive buffer of `size` bytes on each GPU, in
  # ring order. If any step fails, the buffers allocated so far are freed
  # before the error propagates, so a partial allocation does not leak.
  def allocate_recv_buffers(size)
    CUDA::RuntimeAPI.ensure_loaded!

    allocated = []
    complete = false
    begin
      @ring_order.each do |gpu_id|
        # Set device context
        status = CUDA::RuntimeAPI.cudaSetDevice(gpu_id)
        CUDA::RuntimeAPI.check_status!(status, "Set device #{gpu_id}")

        # Allocate buffer
        ptr_ptr = FFI::MemoryPointer.new(:pointer)
        status = CUDA::RuntimeAPI.cudaMalloc(ptr_ptr, size)
        CUDA::RuntimeAPI.check_status!(status, "Alloc recv buffer GPU #{gpu_id}")

        allocated << ptr_ptr.read_pointer
      end
      complete = true
    ensure
      free_recv_buffers(allocated) unless complete
    end

    allocated
  end

  # Free temporary receive buffers (index i belongs to ring_order[i]).
  # Best effort: cleanup errors are deliberately ignored so they cannot
  # mask the error that triggered the cleanup.
  def free_recv_buffers(recv_buffers)
    recv_buffers.each_with_index do |buf, i|
      next unless buf && !buf.null?

      begin
        CUDA::RuntimeAPI.cudaSetDevice(@ring_order[i])
        CUDA::RuntimeAPI.cudaFree(buf)
      rescue StandardError
        # Ignore cleanup errors
      end
    end
  end

  # Scatter-Reduce phase: N-1 steps
  # In each step, GPU[i] sends chunk[(i-step) % N] to GPU[(i+1) % N]
  # and receives chunk[(i-step-1) % N] from GPU[(i-1) % N], reducing it
  # into its local buffer. Afterwards GPU[i] holds fully reduced
  # chunk[(i+1) % N].
  def scatter_reduce!(buffers, recv_buffers, layout, dtype, op, streams)
    (@n_gpus - 1).times do |step|
      # Each GPU sends one chunk to its successor
      @n_gpus.times do |rank|
        gpu_id = @ring_order[rank]
        send_offset, n_bytes, = layout[(rank - step) % @n_gpus]
        next if n_bytes.zero?

        next_rank = (rank + 1) % @n_gpus
        next_gpu = @ring_order[next_rank]
        send_transport = @transport_selector.select_transport(gpu_id, next_gpu)
        stream_ptr = get_stream_ptr(streams[rank])

        src_ptr = ptr_offset(buffers[rank], send_offset)
        dst_ptr = ptr_offset(recv_buffers[next_rank], send_offset)
        move!(send_transport, dst_ptr, src_ptr, n_bytes, stream_ptr)
      end

      # Synchronize all GPUs after send
      synchronize_all_streams!(streams)

      # Now each GPU reduces the received chunk with its local chunk
      @n_gpus.times do |rank|
        recv_offset, _n_bytes, _off_e, elem_count = layout[(rank - step - 1) % @n_gpus]
        next if elem_count.zero?

        use_device!(@ring_order[rank])
        local_ptr = ptr_offset(buffers[rank], recv_offset)
        recv_ptr = ptr_offset(recv_buffers[rank], recv_offset)
        stream_ptr = get_stream_ptr(streams[rank])

        # Reduce: local = reduce(local, recv)
        ReductionOps.execute(op, local_ptr, recv_ptr, local_ptr, elem_count, dtype, stream_ptr)
      end

      # Synchronize before next step
      synchronize_all_streams!(streams)
    end
  end

  # AllGather phase: N-1 steps
  # In each step, GPU[i] forwards chunk[(i - step + 1) % N] to
  # GPU[(i+1) % N], starting from the chunk it fully reduced.
  # recv_buffers is unused here: chunks are written straight into the
  # successor's main buffer.
  def all_gather!(buffers, recv_buffers, layout, streams)
    (@n_gpus - 1).times do |step|
      @n_gpus.times do |rank|
        gpu_id = @ring_order[rank]

        # After scatter-reduce, GPU[i] has fully reduced chunk[(i+1) % N]
        send_offset, n_bytes, = layout[(rank - step + 1) % @n_gpus]
        next if n_bytes.zero?

        next_rank = (rank + 1) % @n_gpus
        next_gpu = @ring_order[next_rank]
        send_transport = @transport_selector.select_transport(gpu_id, next_gpu)
        stream_ptr = get_stream_ptr(streams[rank])

        src_ptr = ptr_offset(buffers[rank], send_offset)
        dst_ptr = ptr_offset(buffers[next_rank], send_offset)
        move!(send_transport, dst_ptr, src_ptr, n_bytes, stream_ptr)
      end

      synchronize_all_streams!(streams)
    end
  end

  # Copy each GPU's fully reduced chunk — chunk[(rank+1) % N] — from its
  # working buffer to the start of its result buffer, then synchronize.
  def copy_reduced_chunks!(buffers, result_buffers, layout, streams)
    @n_gpus.times do |rank|
      gpu_id = @ring_order[rank]
      use_device!(gpu_id)

      src_offset, n_bytes, = layout[(rank + 1) % @n_gpus]
      next if n_bytes.zero?

      src_ptr = ptr_offset(buffers[rank], src_offset)
      # Result buffer aliases the chunk already: nothing to copy.
      next if src_ptr.address == result_buffers[rank].address

      status = CUDA::RuntimeAPI.cudaMemcpyAsync(
        result_buffers[rank],
        src_ptr,
        n_bytes,
        CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_DEVICE,
        get_stream_ptr(streams[rank])
      )
      CUDA::RuntimeAPI.check_status!(status, "ReduceScatter result copy GPU #{gpu_id}")
    end

    synchronize_all_streams!(streams)
  end

  # Make `gpu_id` the current CUDA device, raising if that fails.
  def use_device!(gpu_id)
    status = CUDA::RuntimeAPI.cudaSetDevice(gpu_id)
    CUDA::RuntimeAPI.check_status!(status, "Set device #{gpu_id}")
  end

  # Get stream pointer for FFI. Anything that is neither an FFI::Pointer
  # nor a CUDA::Stream (e.g. nil) maps to the NULL stream.
  def get_stream_ptr(stream)
    case stream
    when FFI::Pointer
      stream
    when CUDA::Stream
      stream.ptr
    else
      FFI::Pointer::NULL
    end
  end

  # Offset a pointer by bytes
  def ptr_offset(ptr, offset)
    FFI::Pointer.new(:uint8, ptr.address + offset)
  end

  # Get element size in bytes for dtype. Covers the same dtypes as
  # ReductionOps' size table so chunk element counts agree with what the
  # reduction actually processes; unknown dtypes are sized as float32.
  def dtype_elem_size(dtype)
    case dtype
    when :float32, :int32, :uint32 then 4
    when :float64, :int64, :uint64 then 8
    when :float16, :bfloat16, :int16, :uint16 then 2
    when :int8, :uint8 then 1
    else 4
    end
  end

  # Synchronize every stream on its GPU (streams[i] belongs to
  # ring_order[i]); a NULL stream synchronizes the whole device. The
  # returned statuses are checked so failures of previously queued
  # asynchronous work are not silently lost.
  def synchronize_all_streams!(streams)
    streams.each_with_index do |stream, i|
      gpu_id = @ring_order[i]
      use_device!(gpu_id)

      stream_ptr = get_stream_ptr(stream)
      status = if stream_ptr.null?
                 CUDA::RuntimeAPI.cudaDeviceSynchronize
               else
                 CUDA::RuntimeAPI.cudaStreamSynchronize(stream_ptr)
               end
      CUDA::RuntimeAPI.check_status!(status, "Synchronize GPU #{gpu_id}")
    end
  end
end
419
+ end
420
+ end
421
+ end