ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,456 @@
1
+ # frozen_string_literal: true
2
+
3
module Ignis
  module Sparse
    # Sparse matrix in CSR or COO format.
    #
    # The index and value arrays are always kept on the host
    # (+@values_host+, +@row_ptr_host+ / +@row_indices_host+,
    # +@col_indices_host+). {#to_device} uploads them into CUDA::Memory
    # buffers exposed through {#values}, {#row_ptr}, {#row_indices} and
    # {#col_indices}; {#free!} releases those device buffers again.
    # Indices are zero-based and uploaded as int32.
    class SparseMatrix
      # Formats this class can build and operate on.
      SUPPORTED_FORMATS = %i[csr coo].freeze

      # FFI bulk writers used to fill host buffers, keyed by FFI element type.
      HOST_WRITERS = {
        float: :put_array_of_float,
        double: :put_array_of_double,
        int32: :put_array_of_int32
      }.freeze

      # @return [Symbol] Sparse format (:csr or :coo)
      attr_reader :format

      # @return [Array<Integer>] Matrix shape [rows, cols]
      attr_reader :shape

      # @return [Integer] Number of stored (non-zero) entries
      attr_reader :nnz

      # @return [Symbol] Data type
      attr_reader :dtype

      # @return [Integer] Device index
      attr_reader :device_index

      # @return [CUDA::Memory, nil] Device row pointer array (CSR only; nil until {#to_device})
      attr_reader :row_ptr

      # @return [CUDA::Memory, nil] Device column indices (nil until {#to_device})
      attr_reader :col_indices

      # @return [CUDA::Memory, nil] Device row indices (COO only; nil until {#to_device})
      attr_reader :row_indices

      # @return [CUDA::Memory, nil] Device values array (nil until {#to_device})
      attr_reader :values

      # Create a sparse matrix in CSR format.
      # @param values [Array, NvArray] Non-zero values
      # @param row_ptr [Array, NvArray] Row pointer array (rows + 1 offsets, starting at 0, ending at nnz)
      # @param col_indices [Array, NvArray] Zero-based column index of each value
      # @param shape [Array<Integer>] Matrix shape [rows, cols]
      # @param dtype [Symbol] Data type
      # @param device [Integer, nil] Device index (defaults to Ignis.configuration.default_device)
      # @return [SparseMatrix]
      # @raise [ArgumentError] if array lengths or index ranges are inconsistent with +shape+
      def self.csr(values:, row_ptr:, col_indices:, shape:, dtype: :float32, device: nil)
        matrix = new(format: :csr, shape: shape, nnz: values.size, dtype: dtype, device: device)
        matrix.send(:initialize_csr, values, row_ptr, col_indices)
        matrix
      end

      # Create a sparse matrix in COO format.
      # @param values [Array, NvArray] Non-zero values
      # @param row_indices [Array, NvArray] Zero-based row index of each value
      # @param col_indices [Array, NvArray] Zero-based column index of each value
      # @param shape [Array<Integer>] Matrix shape [rows, cols]
      # @param dtype [Symbol] Data type
      # @param device [Integer, nil] Device index (defaults to Ignis.configuration.default_device)
      # @return [SparseMatrix]
      # @raise [ArgumentError] if array lengths or index ranges are inconsistent with +shape+
      def self.coo(values:, row_indices:, col_indices:, shape:, dtype: :float32, device: nil)
        matrix = new(format: :coo, shape: shape, nnz: values.size, dtype: dtype, device: device)
        matrix.send(:initialize_coo, values, row_indices, col_indices)
        matrix
      end

      # Create a square identity matrix in CSR format.
      # @param size [Integer] Matrix size
      # @param dtype [Symbol] Data type
      # @param device [Integer, nil] Device index
      # @return [SparseMatrix]
      def self.identity(size, dtype: :float32, device: nil)
        values = Array.new(size, 1.0)
        row_ptr = (0..size).to_a
        col_indices = (0...size).to_a

        csr(values: values, row_ptr: row_ptr, col_indices: col_indices,
            shape: [size, size], dtype: dtype, device: device)
      end

      # Create a sparse matrix from a dense 2D NvArray.
      # @param dense [NvArray] Dense matrix
      # @param format [Symbol] Output format (:csr or :coo)
      # @param threshold [Float] Entries with |value| <= threshold are dropped
      # @return [SparseMatrix]
      # @raise [ArgumentError] if +dense+ is not an NvArray or +format+ is unsupported
      # @raise [DimensionError] if +dense+ is not 2D
      def self.from_dense(dense, format: :csr, threshold: 0.0)
        raise ArgumentError, "Expected NvArray, got #{dense.class}" unless dense.is_a?(NvArray)
        raise DimensionError, "Expected 2D array, got #{dense.ndim}D" unless dense.ndim == 2
        raise ArgumentError, "Unsupported sparse format: #{format.inspect}" unless SUPPORTED_FORMATS.include?(format)

        dense = dense.to_host unless dense.on_host?
        rows, cols = dense.shape
        row_indices = []
        col_indices = []
        values = []

        # Row-major scan: entries come out sorted by row, then column, which is
        # the order CSR requires.
        dense.flatten.each_with_index do |val, idx|
          next if val.abs <= threshold

          row_indices << idx / cols
          col_indices << idx % cols
          values << val
        end

        if format == :csr
          # Count entries per row, then prefix-sum the counts into row offsets.
          row_ptr = Array.new(rows + 1, 0)
          row_indices.each { |r| row_ptr[r + 1] += 1 }
          (1..rows).each { |i| row_ptr[i] += row_ptr[i - 1] }

          csr(values: values, row_ptr: row_ptr, col_indices: col_indices,
              shape: [rows, cols], dtype: dense.dtype, device: dense.device_index)
        else
          coo(values: values, row_indices: row_indices, col_indices: col_indices,
              shape: [rows, cols], dtype: dense.dtype, device: dense.device_index)
        end
      end

      # @return [Integer] Number of rows
      def rows
        @shape[0]
      end

      # @return [Integer] Number of columns
      def cols
        @shape[1]
      end

      # Sparsity ratio.
      # @return [Float] Fraction of elements that are zero (1.0 for an empty matrix)
      def sparsity
        1.0 - density
      end

      # Density ratio.
      # @return [Float] Fraction of elements that are stored (0.0 for an empty matrix)
      def density
        total = rows * cols
        # Avoid 0/0 => NaN for matrices with a zero-length dimension.
        return 0.0 if total.zero?

        @nnz.to_f / total
      end

      # Upload the host arrays to GPU memory.
      #
      # A no-op when already resident on the requested device. When resident on
      # a different device, the old buffers are released before re-uploading.
      # @param device [Integer, nil] Target device (defaults to the current device_index)
      # @return [self]
      def to_device(device: nil)
        target_device = device || @device_index
        return self if on_device? && target_device == @device_index

        # Release buffers from a previous upload so repeated calls do not leak.
        free! if on_device?
        begin
          case @format
          when :csr then transfer_csr_to_device(target_device)
          when :coo then transfer_coo_to_device(target_device)
          end
        rescue StandardError
          free! # drop buffers allocated before the failure
          raise
        end

        @device_index = target_device
        @on_device = true
        self
      end

      # Release device memory; the host copies are kept.
      # @return [self]
      def to_host
        return self unless on_device?

        # Sparse structures are never mutated on the device, so the host copy is
        # already up to date and only device memory needs freeing.
        free!
        self
      end

      # @return [Boolean] whether device buffers are currently allocated
      def on_device?
        @on_device
      end

      # Convert to a dense NvArray built from the host copies.
      # Duplicate (row, col) entries are summed, matching cuSPARSE SpMV semantics.
      # @return [NvArray] Dense matrix
      # @raise [ArgumentError] if the dtype's FFI type has no host writer
      def to_dense
        result = NvArray.zeros(@shape, dtype: @dtype, device: nil)

        case @format
        when :csr then expand_csr_to_dense(result)
        when :coo then expand_coo_to_dense(result)
        end

        result
      end

      # Sparse matrix-vector multiplication: y = alpha * op(A) * x + beta * y.
      # @param x [NvArray] Input vector
      # @param y [NvArray, nil] Output vector (created zero-filled if nil)
      # @param alpha [Float] Scaling factor for A*x
      # @param beta [Float] Scaling factor for y
      # @param transpose [Boolean] Use A^T instead of A
      # @return [NvArray] Result vector
      # @raise [ArgumentError] if x is not an NvArray or dtypes do not match the matrix
      # @raise [DimensionError] if x or y has the wrong length
      def spmv(x, y: nil, alpha: 1.0, beta: 0.0, transpose: false)
        raise ArgumentError, "Expected NvArray, got #{x.class}" unless x.is_a?(NvArray)

        out_rows = transpose ? cols : rows
        in_cols = transpose ? rows : cols

        raise DimensionError, "Vector size #{x.size} != matrix cols #{in_cols}" unless x.size == in_cols
        # cuSPARSE reads every operand with the matrix's value type, so a mismatched
        # dtype would silently produce garbage.
        raise ArgumentError, "Vector dtype #{x.dtype} != matrix dtype #{@dtype}" unless x.dtype == @dtype
        if y
          raise DimensionError, "Output vector size #{y.size} != #{out_rows}" unless y.size == out_rows
          raise ArgumentError, "Output dtype #{y.dtype} != matrix dtype #{@dtype}" unless y.dtype == @dtype
        end

        to_device unless on_device?
        x = x.to_device unless x.on_device?

        y ||= NvArray.zeros([out_rows], dtype: @dtype, device: @device_index)
        y = y.to_device unless y.on_device?

        execute_spmv(x, y, alpha, beta, transpose)

        y
      end

      # Free all device memory (host copies are kept).
      # @return [void]
      def free!
        @row_ptr&.free!
        @col_indices&.free!
        @row_indices&.free!
        @values&.free!

        @row_ptr = nil
        @col_indices = nil
        @row_indices = nil
        @values = nil
        @on_device = false
      end

      # @return [String]
      def to_s
        "SparseMatrix(shape=#{@shape}, nnz=#{@nnz}, format=#{@format}, density=#{(density * 100).round(2)}%)"
      end

      private

      def initialize(format:, shape:, nnz:, dtype:, device:)
        @format = format
        @shape = Array(shape)
        @nnz = nnz
        @dtype = DType.validate!(dtype)
        @device_index = device || Ignis.configuration.default_device
        @on_device = false
      end

      # Store and validate CSR host arrays. nnz is taken from the flattened values.
      def initialize_csr(values, row_ptr, col_indices)
        @values_host = to_flat_array(values)
        @row_ptr_host = to_int_array(row_ptr)
        @col_indices_host = to_int_array(col_indices)
        @nnz = @values_host.size

        check_length!(:row_ptr, @row_ptr_host, rows + 1)
        check_length!(:col_indices, @col_indices_host, @nnz)
        unless @row_ptr_host.first.zero? && @row_ptr_host.last == @nnz &&
               @row_ptr_host.each_cons(2).all? { |lo, hi| lo <= hi }
          raise ArgumentError, "row_ptr must start at 0, be non-decreasing and end at nnz (#{@nnz})"
        end
        check_index_range!(:col_indices, @col_indices_host, cols)
      end

      # Store and validate COO host arrays. nnz is taken from the flattened values.
      def initialize_coo(values, row_indices, col_indices)
        @values_host = to_flat_array(values)
        @row_indices_host = to_int_array(row_indices)
        @col_indices_host = to_int_array(col_indices)
        @nnz = @values_host.size

        check_length!(:row_indices, @row_indices_host, @nnz)
        check_length!(:col_indices, @col_indices_host, @nnz)
        check_index_range!(:row_indices, @row_indices_host, rows)
        check_index_range!(:col_indices, @col_indices_host, cols)
      end

      # Raise unless +array+ has exactly +expected+ entries.
      def check_length!(name, array, expected)
        return if array.size == expected

        raise ArgumentError, "#{name} must have #{expected} entries, got #{array.size}"
      end

      # Raise unless every index lies in 0...limit (zero-based, as uploaded to cuSPARSE).
      def check_index_range!(name, indices, limit)
        return if indices.all? { |i| i >= 0 && i < limit }

        raise ArgumentError, "#{name} entries must lie in 0...#{limit}"
      end

      # Convert to a flat Ruby array.
      def to_flat_array(data)
        case data
        when NvArray
          data.flatten
        when Array
          data.flatten
        else
          Array(data)
        end
      end

      # Convert to a flat integer array.
      def to_int_array(data)
        to_flat_array(data).map(&:to_i)
      end

      # Upload CSR arrays; index arrays go up as int32.
      def transfer_csr_to_device(device)
        @values = upload(@values_host, DType.ffi_type(@dtype), DType.byte_size(@dtype), device)
        @row_ptr = upload(@row_ptr_host, :int32, 4, device)
        @col_indices = upload(@col_indices_host, :int32, 4, device)
      end

      # Upload COO arrays; index arrays go up as int32.
      def transfer_coo_to_device(device)
        @values = upload(@values_host, DType.ffi_type(@dtype), DType.byte_size(@dtype), device)
        @row_indices = upload(@row_indices_host, :int32, 4, device)
        @col_indices = upload(@col_indices_host, :int32, 4, device)
      end

      # Copy +host_data+ into a fresh device buffer of +element_bytes+ per entry.
      # Frees the device buffer again if the copy fails.
      def upload(host_data, ffi_type, element_bytes, device)
        staging = create_host_pointer(host_data, ffi_type)
        memory = CUDA::Memory.new(host_data.size * element_bytes, device: device)
        memory.copy_from_host(staging)
        memory
      rescue StandardError
        memory&.free!
        raise
      end

      # Create an FFI host buffer holding +data+ as +ffi_type+ elements.
      # @raise [ArgumentError] for element types without a bulk writer
      def create_host_pointer(data, ffi_type)
        writer = HOST_WRITERS.fetch(ffi_type) do
          raise ArgumentError, "Unsupported element type for device transfer: #{ffi_type}"
        end
        ptr = FFI::MemoryPointer.new(ffi_type, data.size)
        ptr.public_send(writer, 0, data)
        ptr
      end

      # Expand CSR host arrays into +result+'s host buffer.
      def expand_csr_to_dense(result)
        dense = Array.new(rows * cols, 0.0)

        rows.times do |row|
          (@row_ptr_host[row]...@row_ptr_host[row + 1]).each do |idx|
            dense[(row * cols) + @col_indices_host[idx]] += @values_host[idx]
          end
        end

        write_dense_host(result, dense)
      end

      # Expand COO host arrays into +result+'s host buffer.
      def expand_coo_to_dense(result)
        dense = Array.new(rows * cols, 0.0)

        @nnz.times do |idx|
          dense[(@row_indices_host[idx] * cols) + @col_indices_host[idx]] += @values_host[idx]
        end

        write_dense_host(result, dense)
      end

      # Write +data+ into the NvArray's host buffer using the element width of @dtype.
      # NOTE(review): relies on NvArray keeping its host buffer in @host_memory.
      def write_dense_host(result, data)
        writer = HOST_WRITERS.fetch(DType.ffi_type(@dtype)) do
          raise ArgumentError, "Unsupported dtype for dense expansion: #{@dtype}"
        end
        result.instance_variable_get(:@host_memory).public_send(writer, 0, data)
      end

      # cudaDataType for @dtype (CUDA_R_32F = 0, CUDA_R_64F = 1).
      def cuda_value_type
        case DType.ffi_type(@dtype)
        when :float then 0
        when :double then 1
        else raise ArgumentError, "SpMV supports only float32/float64 matrices, got #{@dtype}"
        end
      end

      # Host scalar pointer holding +value+ in the matrix's floating-point width.
      def scalar_pointer(value)
        ffi_type = DType.ffi_type(@dtype)
        ptr = FFI::MemoryPointer.new(ffi_type)
        ffi_type == :float ? ptr.put_float(0, value) : ptr.put_double(0, value)
        ptr
      end

      # Create the cuSPARSE sparse-matrix descriptor over the device buffers.
      def create_sparse_descriptor(value_type)
        descr_ptr = FFI::MemoryPointer.new(:pointer)
        status =
          if @format == :csr
            CuSPARSEBindings.cusparseCreateCsr(
              descr_ptr,
              rows, cols, @nnz,
              @row_ptr.device_ptr,
              @col_indices.device_ptr,
              @values.device_ptr,
              CuSPARSEBindings::CUSPARSE_INDEX_32I,
              CuSPARSEBindings::CUSPARSE_INDEX_32I,
              CuSPARSEBindings::CUSPARSE_INDEX_BASE_ZERO,
              value_type
            )
          else
            CuSPARSEBindings.cusparseCreateCoo(
              descr_ptr,
              rows, cols, @nnz,
              @row_indices.device_ptr,
              @col_indices.device_ptr,
              @values.device_ptr,
              CuSPARSEBindings::CUSPARSE_INDEX_32I,
              CuSPARSEBindings::CUSPARSE_INDEX_BASE_ZERO,
              value_type
            )
          end
        CuSPARSEBindings.check_status!(status, "Create sparse matrix descriptor")
        descr_ptr.read_pointer
      end

      # Create a cuSPARSE dense-vector descriptor for +vector+.
      def create_dense_vector(vector, value_type, label)
        descr_ptr = FFI::MemoryPointer.new(:pointer)
        status = CuSPARSEBindings.cusparseCreateDnVec(descr_ptr, vector.size, vector.device_ptr, value_type)
        CuSPARSEBindings.check_status!(status, label)
        descr_ptr.read_pointer
      end

      # Run cusparseSpMV. Descriptors and the work buffer are released in
      # +ensure+, so they are cleaned up even when a status check raises.
      # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
      def execute_spmv(x, y, alpha, beta, transpose)
        CuSPARSEBindings.ensure_loaded!
        handle = CuSPARSEBindings.get_handle
        value_type = cuda_value_type
        op =
          if transpose
            CuSPARSEBindings::CUSPARSE_OPERATION_TRANSPOSE
          else
            CuSPARSEBindings::CUSPARSE_OPERATION_NON_TRANSPOSE
          end
        alpha_ptr = scalar_pointer(alpha)
        beta_ptr = scalar_pointer(beta)

        sp_mat = vec_x = vec_y = buffer = nil
        begin
          sp_mat = create_sparse_descriptor(value_type)
          vec_x = create_dense_vector(x, value_type, "Create dense vector X")
          vec_y = create_dense_vector(y, value_type, "Create dense vector Y")

          buffer_size_ptr = FFI::MemoryPointer.new(:size_t)
          status = CuSPARSEBindings.cusparseSpMV_bufferSize(
            handle, op, alpha_ptr, sp_mat, vec_x, beta_ptr, vec_y,
            value_type, CuSPARSEBindings::CUSPARSE_SPMV_ALG_DEFAULT, buffer_size_ptr
          )
          CuSPARSEBindings.check_status!(status, "Get SpMV buffer size")

          buffer_size = buffer_size_ptr.read(:size_t)
          buffer = CUDA::Memory.new(buffer_size, device: @device_index) if buffer_size.positive?

          status = CuSPARSEBindings.cusparseSpMV(
            handle, op, alpha_ptr, sp_mat, vec_x, beta_ptr, vec_y,
            value_type, CuSPARSEBindings::CUSPARSE_SPMV_ALG_DEFAULT,
            buffer&.device_ptr || FFI::Pointer::NULL
          )
          CuSPARSEBindings.check_status!(status, "Execute SpMV")
        ensure
          CuSPARSEBindings.cusparseDestroySpMat(sp_mat) if sp_mat
          CuSPARSEBindings.cusparseDestroyDnVec(vec_x) if vec_x
          CuSPARSEBindings.cusparseDestroyDnVec(vec_y) if vec_y
          buffer&.free!
        end
      end
      # rubocop:enable Metrics/AbcSize, Metrics/MethodLength
    end
  end
end
@@ -0,0 +1,218 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../linalg/cutensor_bindings"
4
+
5
module Ignis
  module Tensor
    # Tensor contraction expressed in Einstein notation.
    #
    # Only plain 2D matrix products ("ij,jk->ik", with any choice of distinct
    # mode letters) are executed today, via LinAlg::Matmul. Every other pattern
    # raises NotImplementedError; no cuTENSOR call is made by this class yet.
    #
    # @example Matrix multiplication (ij,jk->ik)
    #   contraction = Contraction.new("ij,jk->ik", a, b)
    #   result = contraction.execute
    #   contraction.destroy!
    class Contraction
      # @return [String] Einstein notation expression (as given)
      attr_reader :expression

      # @return [NvArray] First input tensor
      attr_reader :tensor_a

      # @return [NvArray] Second input tensor
      attr_reader :tensor_b

      # @return [Array<Integer>] Output shape
      attr_reader :output_shape

      # Parse a two-operand Einstein notation expression.
      # Spaces are ignored; an empty output ("ij,ij->") is allowed.
      # @param expression [String] Expression like "ij,jk->ik"
      # @return [Hash] :input_a_modes, :input_b_modes, :output_modes (arrays of
      #   single-character modes) and :contracted_modes (shared by both inputs,
      #   absent from the output)
      # @raise [ArgumentError] unless there are exactly two operands and one "->"
      def self.parse_expression(expression)
        # A negative limit keeps a trailing empty field, so "ij,ij->" yields "".
        input_part, output_part, *extra = expression.to_s.delete(" ").split("->", -1)
        operands = input_part.to_s.split(",", -1)
        unless extra.empty? && output_part && operands.size == 2
          raise ArgumentError, "Invalid Einstein notation: #{expression}. Expected format: 'ij,jk->ik'"
        end

        a_modes, b_modes = operands.map(&:chars)
        out_modes = output_part.chars
        {
          input_a_modes: a_modes,
          input_b_modes: b_modes,
          output_modes: out_modes,
          contracted_modes: (a_modes & b_modes) - out_modes
        }
      end

      # Initialize a tensor contraction and compute its output shape.
      # @param expression [String] Einstein notation expression
      # @param tensor_a [NvArray] First input tensor
      # @param tensor_b [NvArray] Second input tensor
      # @param alpha [Float] Scaling factor
      # @param beta [Float] Scaling for the output; note that {#execute} always
      #   starts from a freshly zero-filled output, so beta presumably has no effect today
      # @raise [ArgumentError] on malformed expressions, non-NvArray inputs or dtype mismatch
      # @raise [DimensionError] on rank or extent mismatches
      def initialize(expression, tensor_a, tensor_b, alpha: 1.0, beta: 0.0)
        @expression = expression
        @tensor_a = tensor_a
        @tensor_b = tensor_b
        @alpha = alpha
        @beta = beta
        @parsed = self.class.parse_expression(expression)
        @planned = false
        @handle = nil
        @plan = nil

        validate_inputs!
        compute_output_shape!
      end

      # Execute the tensor contraction.
      # @return [NvArray] Result tensor (on device)
      # @raise [NotImplementedError] for patterns other than a 2D matrix product
      def execute
        # Ensure tensors are on device
        a = @tensor_a.on_device? ? @tensor_a : @tensor_a.to_device
        b = @tensor_b.on_device? ? @tensor_b : @tensor_b.to_device

        output = NvArray.zeros(@output_shape, dtype: a.dtype, device: a.device_index)
        output = output.to_device unless output.on_device?

        if matrix_multiply?
          execute_as_matmul(a, b, output)
        else
          execute_general_contraction(a, b, output)
        end

        output
      end

      # Reset planning state. No cuTENSOR resources are allocated by this class yet,
      # so this only clears the handle/plan references.
      # @return [void]
      def destroy!
        @handle = nil
        @plan = nil
        @planned = false
      end

      private

      # Validate input types, dtypes and ranks against the expression.
      def validate_inputs!
        unless @tensor_a.is_a?(NvArray) && @tensor_b.is_a?(NvArray)
          raise ArgumentError, "Expected NvArray tensors"
        end

        unless @tensor_a.dtype == @tensor_b.dtype
          raise ArgumentError, "Tensor dtypes must match: #{@tensor_a.dtype} vs #{@tensor_b.dtype}"
        end

        a_modes = @parsed[:input_a_modes]
        b_modes = @parsed[:input_b_modes]

        unless @tensor_a.ndim == a_modes.size
          raise DimensionError, "Tensor A has #{@tensor_a.ndim} dims but expression specifies #{a_modes.size} modes"
        end

        unless @tensor_b.ndim == b_modes.size
          raise DimensionError, "Tensor B has #{@tensor_b.ndim} dims but expression specifies #{b_modes.size} modes"
        end
      end

      # Compute the output shape from the extents of each mode in the inputs.
      def compute_output_shape!
        mode_extents = {}

        @parsed[:input_a_modes].each_with_index do |mode, idx|
          mode_extents[mode] = @tensor_a.shape[idx]
        end

        @parsed[:input_b_modes].each_with_index do |mode, idx|
          if mode_extents[mode] && mode_extents[mode] != @tensor_b.shape[idx]
            raise DimensionError, "Mode '#{mode}' has inconsistent extents: #{mode_extents[mode]} vs #{@tensor_b.shape[idx]}"
          end
          mode_extents[mode] = @tensor_b.shape[idx]
        end

        @output_shape = @parsed[:output_modes].map do |mode|
          extent = mode_extents[mode]
          raise ArgumentError, "Output mode '#{mode}' not found in inputs" unless extent

          extent
        end
      end

      # True when the expression is A[m,k] * B[k,n] -> C[m,n] for any three
      # distinct mode letters (e.g. "ij,jk->ik", "ik,kj->ij", "ab,bc->ac").
      # @return [Boolean]
      def matrix_multiply?
        a_modes, b_modes, out_modes = @parsed.values_at(:input_a_modes, :input_b_modes, :output_modes)
        return false unless [a_modes, b_modes, out_modes].all? { |modes| modes.size == 2 }

        row, inner = a_modes
        inner_b, col = b_modes
        inner == inner_b && [row, inner, col].uniq.size == 3 && out_modes == [row, col]
      end

      # Execute as a matrix product via LinAlg::Matmul, writing into +output+.
      def execute_as_matmul(a, b, output)
        c = LinAlg::Matmul.call(a, b, c: output, alpha: @alpha, beta: @beta)
        # If Matmul returned a different buffer, copy its contents rather than
        # aliasing device memory between two array objects.
        return if c.equal?(output) || c.device_ptr == output.device_ptr

        copy_to_output(c, output)
      end

      # Execute a general contraction (TTGT path).
      def execute_general_contraction(a, b, output)
        perform_ttgt_contraction(a, b, output)
      end

      # Transpose-Transpose-GEMM-Transpose contraction.
      #
      # Without explicit transposes, reshape + GEMM is only correct when A is
      # laid out [free, contracted], B is [contracted, free] and the output is
      # [free_a, free_b]; any other order would silently reinterpret memory.
      # That layout is exactly {#matrix_multiply?}, so everything else is rejected.
      # @raise [NotImplementedError] for layouts needing transposes or multi-index contractions
      def perform_ttgt_contraction(a, b, output)
        return execute_as_matmul(a, b, output) if matrix_multiply?

        raise NotImplementedError, "Complex contraction '#{@expression}' requires full cuTENSOR integration. " \
                                   "Supported patterns: 2D matrix products of the form 'ij,jk->ik' (any mode letters)."
      end

      # Device-to-device copy of +source+ into +dest+.
      # @raise [DimensionError] if the element counts differ
      def copy_to_output(source, dest)
        unless source.size == dest.size
          raise DimensionError, "Cannot copy #{source.size} elements into output of size #{dest.size}"
        end

        byte_size = source.size * DType.byte_size(source.dtype)
        CUDA::RuntimeAPI.memcpy_device_to_device(dest.device_ptr, source.device_ptr, byte_size)
      end
    end
  end
end