ignis 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. checksums.yaml +7 -0
  2. data/README.md +15 -0
  3. data/lib/ignis.rb +94 -0
  4. data/lib/nnw/platform.rb +304 -0
  5. data/lib/nnw/shared/event_bus.rb +240 -0
  6. data/lib/nnw/shared/ffi_loader.rb +63 -0
  7. data/lib/nnw/shared/memory_contract.rb +204 -0
  8. data/lib/nnw/shared/nv_array.rb +710 -0
  9. data/lib/nnw/shared/recovery_protocol.rb +307 -0
  10. data/lib/nvruby/configuration.rb +217 -0
  11. data/lib/nvruby/cuda/device.rb +275 -0
  12. data/lib/nvruby/cuda/device_props.rb +202 -0
  13. data/lib/nvruby/cuda/graph.rb +265 -0
  14. data/lib/nvruby/cuda/graph_bindings.rb +119 -0
  15. data/lib/nvruby/cuda/library_loader.rb +285 -0
  16. data/lib/nvruby/cuda/memory.rb +410 -0
  17. data/lib/nvruby/cuda/runtime_api.rb +804 -0
  18. data/lib/nvruby/cuda/stream.rb +234 -0
  19. data/lib/nvruby/dtype.rb +139 -0
  20. data/lib/nvruby/epilogues.rb +438 -0
  21. data/lib/nvruby/errors.rb +303 -0
  22. data/lib/nvruby/half.rb +97 -0
  23. data/lib/nvruby/jit/compiled_kernel.rb +80 -0
  24. data/lib/nvruby/jit/compiler.rb +231 -0
  25. data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
  26. data/lib/nvruby/jit/kernel.rb +240 -0
  27. data/lib/nvruby/jit/kernel_module.rb +133 -0
  28. data/lib/nvruby/jit/kernels/activations.rb +179 -0
  29. data/lib/nvruby/jit/kernels/attention.rb +504 -0
  30. data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
  31. data/lib/nvruby/jit/kernels/loss.rb +213 -0
  32. data/lib/nvruby/jit/kernels/normalization.rb +200 -0
  33. data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
  34. data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
  35. data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
  36. data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
  37. data/lib/nvruby/linalg/epilog.rb +67 -0
  38. data/lib/nvruby/linalg/matmul.rb +247 -0
  39. data/lib/nvruby/linalg/matmul_plan.rb +229 -0
  40. data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
  41. data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
  42. data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
  43. data/lib/nvruby/memory/device_memory_resource.rb +106 -0
  44. data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
  45. data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
  46. data/lib/nvruby/memory/stats.rb +107 -0
  47. data/lib/nvruby/memory.rb +124 -0
  48. data/lib/nvruby/version.rb +5 -0
  49. metadata +108 -0
@@ -0,0 +1,247 @@
1
# frozen_string_literal: true

module Ignis
  module LinAlg
    # Matrix multiplication operations using cuBLAS.
    #
    # NvArray buffers are row-major while cuBLAS is column-major; a row-major
    # [r, c] buffer read as column-major is its transpose. Every GEMM here
    # therefore asks cuBLAS for C^T = op(B)^T @ op(A)^T: B is passed as the
    # first cuBLAS operand, A as the second, and the m/n extents are swapped.
    # The column-major C^T cuBLAS writes is the row-major C the caller expects.
    module Matmul
      # cublasGemmAlgo_t values passed to cublasGemmEx. Defined on the module
      # itself (not inside `class << self`) so callers of .call_with_algo can
      # reference them as e.g. Matmul::CUBLAS_GEMM_DEFAULT.
      CUBLAS_GEMM_DEFAULT = -1
      CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99
      CUBLAS_GEMM_FP16_OPTIMIZED = 1 # Peak for RTX 3060 FP16
      CUBLAS_GEMM_FP32_OPTIMIZED = 101 # Peak for RTX 3060 FP32 TF32

      # Element types accepted by .call.
      SUPPORTED_DTYPES = %i[float16 bfloat16 float32 float64 complex64 complex128].freeze

      # Element types .call_with_algo can send through cublasGemmEx.
      # execute_gemmex passes alpha/beta as single FP32 scalars with FP32
      # compute types, which does not match float64 or complex operands
      # (those need 64-bit or complex scale types).
      GEMMEX_DTYPES = %i[float16 bfloat16 float32].freeze

      class << self
        # Perform matrix multiplication: C = alpha * op(A) @ op(B) + beta * C
        # @param a [NvArray] Left matrix (any array-like responding to #device_ffi_ptr)
        # @param b [NvArray] Right matrix
        # @param c [NvArray, nil] Output matrix (created if nil). When nil, beta is
        #   treated as 0 because a freshly allocated buffer has no prior contents.
        # @param alpha [Float, Complex] Scaling factor for A @ B
        # @param beta [Float, Complex] Scaling factor for C
        # @param transpose_a [Boolean] Transpose A
        # @param transpose_b [Boolean] Transpose B
        # @param stream [CUDA::Stream, nil] CUDA stream (nil = default stream)
        # @return [NvArray] Result matrix of shape [m, n]
        # @raise [ArgumentError] non-array inputs, dtype mismatch, or bad output dtype
        # @raise [DimensionError] non-2D inputs, incompatible inner dimensions, or bad output shape
        # @raise [UnsupportedDTypeError] dtype outside SUPPORTED_DTYPES
        def call(a, b, c: nil, alpha: 1.0, beta: 0.0, transpose_a: false, transpose_b: false, stream: nil)
          validate_inputs!(a, b)
          m, k, n = checked_dimensions(a, b, transpose_a, transpose_b)

          # A new output buffer may be uninitialised (see prepare_output), so a
          # non-zero beta would blend garbage into the result.
          beta = 0.0 if c.nil?
          a, b, c = device_operands(a, b, c, m, n)

          execute_gemm(a, b, c, m, n, k, alpha, beta, transpose_a, transpose_b, stream)
          c
        end

        # Shorthand for matrix multiplication: C = A @ B
        # @param a [NvArray] Left matrix
        # @param b [NvArray] Right matrix
        # @return [NvArray] Result matrix
        def matmul(a, b)
          call(a, b)
        end

        # Perform GEMM through cublasGemmEx with a specific algorithm (useful for
        # benchmarking/perf tuning). Keyword arguments behave as in .call.
        # @param algo [Integer] cublasGemmAlgo_t value, e.g. CUBLAS_GEMM_DEFAULT
        # @return [NvArray] Result matrix of shape [m, n]
        # @raise [UnsupportedDTypeError] dtype outside GEMMEX_DTYPES
        def call_with_algo(a, b, algo, c: nil, alpha: 1.0, beta: 0.0, transpose_a: false, transpose_b: false, stream: nil)
          validate_inputs!(a, b)
          m, k, n = checked_dimensions(a, b, transpose_a, transpose_b)
          unless GEMMEX_DTYPES.include?(a.dtype)
            raise UnsupportedDTypeError.new(a.dtype, operation: "GEMM with explicit algorithm")
          end

          beta = 0.0 if c.nil?
          a, b, c = device_operands(a, b, c, m, n)

          CuBLASBindings.ensure_loaded!
          handle = CuBLASBindings.get_handle
          bind_stream(handle, stream)

          op_a, op_b, lda, ldb, ldc = gemm_layout(a, b, n, transpose_a, transpose_b)
          status = execute_gemmex(handle, op_a, op_b, n, m, k, alpha, a, b, c, lda, ldb, ldc, algo, beta)
          CuBLASBindings.check_status!(status, "GEMM execution with algorithm #{algo}")
          c
        end

        private

        # Accept Ignis::NvArray or any array-like exposing the device-pointer API
        # (e.g. Ignis::Shared::NvArray from the AI stack). Duck-typed so the core
        # does not hard-reference the numerics NvArray class.
        def array_like?(x)
          x.respond_to?(:device_ffi_ptr)
        end

        # Validate input arrays: array-like, 2-D, equal and supported dtypes.
        def validate_inputs!(a, b)
          raise ArgumentError, "Expected NvArray, got #{a.class}" unless array_like?(a)
          raise ArgumentError, "Expected NvArray, got #{b.class}" unless array_like?(b)
          raise DimensionError, "Matrix A must be 2D, got #{a.ndim}D" unless a.ndim == 2
          raise DimensionError, "Matrix B must be 2D, got #{b.ndim}D" unless b.ndim == 2
          raise ArgumentError, "Data types must match: #{a.dtype} vs #{b.dtype}" unless a.dtype == b.dtype
          return if SUPPORTED_DTYPES.include?(a.dtype)

          raise UnsupportedDTypeError.new(a.dtype, operation: "matrix multiplication")
        end

        # Compute dimensions considering transposition
        # @return [Array<Integer>] [m, k1, k2, n]
        def compute_dimensions(a, b, transpose_a, transpose_b)
          m = transpose_a ? a.shape[1] : a.shape[0]
          k1 = transpose_a ? a.shape[0] : a.shape[1]
          k2 = transpose_b ? b.shape[1] : b.shape[0]
          n = transpose_b ? b.shape[0] : b.shape[1]
          [m, k1, k2, n]
        end

        # Dimensions with the inner extents checked for agreement.
        # @return [Array<Integer>] [m, k, n]
        # @raise [DimensionError] when op(A) columns != op(B) rows
        def checked_dimensions(a, b, transpose_a, transpose_b)
          m, k1, k2, n = compute_dimensions(a, b, transpose_a, transpose_b)
          raise DimensionError, "Matrix dimensions incompatible: A(#{a.shape}) @ B(#{b.shape})" unless k1 == k2

          [m, k1, n]
        end

        # Move A and B to the device and prepare a device-resident [m, n] output.
        # @return [Array] [a, b, c]
        def device_operands(a, b, c, m, n)
          a = a.to_device unless a.on_device?
          b = b.to_device unless b.on_device?
          c = prepare_output(c, m, n, a)
          c = c.to_device unless c.on_device?
          [a, b, c]
        end

        # Prepare output array. Allocates an output of the SAME class as the
        # input template, so the AI stack (Shared::NvArray) gets a Shared::NvArray
        # back rather than a foreign Ignis::NvArray.
        def prepare_output(c, m, n, template)
          dtype = template.dtype
          if c
            raise DimensionError, "Output shape must be [#{m}, #{n}], got #{c.shape}" unless c.shape == [m, n]
            raise ArgumentError, "Output dtype must be #{dtype}, got #{c.dtype}" unless c.dtype == dtype

            c
          elsif defined?(NvArray) && template.is_a?(NvArray)
            NvArray.zeros([m, n], dtype: dtype, device: template.device_index)
          else
            # Array-like (e.g. Ignis::Shared::NvArray). The public entry points
            # force beta to 0 when no output is supplied, so cuBLAS overwrites C;
            # it only needs to be allocated, not zeroed.
            out = template.class.new(shape: [m, n], dtype: dtype, device_id: template.device_index)
            out.to_device
          end
        end

        # Associate the cuBLAS handle with +stream+, or reset it to the default
        # (NULL) stream when none is given. The reset matters because the handle
        # might still hold a stale reference to a destroyed stream from a
        # previous call or graph capture.
        def bind_stream(handle, stream)
          stream_handle = stream&.handle || FFI::Pointer::NULL
          status = CuBLASBindings.cublasSetStream_v2(handle, stream_handle)
          CuBLASBindings.check_status!(status, "Set cuBLAS stream")
        end

        # cuBLAS operation flags and leading dimensions for the row-major ->
        # column-major operand swap (see the module comment). A row-major
        # buffer's column-major leading dimension is its column count.
        # @return [Array] [op_a, op_b, lda, ldb, ldc]
        def gemm_layout(a, b, n, transpose_a, transpose_b)
          op_a = transpose_b ? CuBLASBindings::CUBLAS_OP_T : CuBLASBindings::CUBLAS_OP_N
          op_b = transpose_a ? CuBLASBindings::CUBLAS_OP_T : CuBLASBindings::CUBLAS_OP_N
          [op_a, op_b, b.shape[1], a.shape[1], n]
        end

        # Execute GEMM operation via cuBLAS, dispatching on dtype.
        # rubocop:disable Metrics/ParameterLists, Metrics/AbcSize, Metrics/MethodLength
        def execute_gemm(a, b, c, m, n, k, alpha, beta, transpose_a, transpose_b, stream)
          CuBLASBindings.ensure_loaded!
          handle = CuBLASBindings.get_handle
          bind_stream(handle, stream)

          op_a, op_b, lda, ldb, ldc = gemm_layout(a, b, n, transpose_a, transpose_b)

          status =
            case a.dtype
            when :float32
              # GemmEx for FP32 enables TF32 compute and a fixed algorithm.
              execute_gemmex(handle, op_a, op_b, n, m, k,
                             alpha, a, b, c, lda, ldb, ldc, CUBLAS_GEMM_FP32_OPTIMIZED, beta)
            when :float16, :bfloat16
              execute_gemmex(handle, op_a, op_b, n, m, k,
                             alpha, a, b, c, lda, ldb, ldc, CUBLAS_GEMM_FP16_OPTIMIZED, beta)
            when :float64
              alpha_ptr = FFI::MemoryPointer.new(:double).tap { |p| p.put_double(0, alpha) }
              beta_ptr = FFI::MemoryPointer.new(:double).tap { |p| p.put_double(0, beta) }
              CuBLASBindings.cublasDgemm_v2(
                handle, op_a, op_b,
                n, m, k,
                alpha_ptr, b.device_ffi_ptr, lda,
                a.device_ffi_ptr, ldb,
                beta_ptr, c.device_ffi_ptr, ldc
              )
            when :complex64
              alpha_ptr = complex_scalar_ptr(:float, alpha)
              beta_ptr = complex_scalar_ptr(:float, beta)
              CuBLASBindings.cublasCgemm_v2(
                handle, op_a, op_b,
                n, m, k,
                alpha_ptr, b.device_ffi_ptr, lda,
                a.device_ffi_ptr, ldb,
                beta_ptr, c.device_ffi_ptr, ldc
              )
            when :complex128
              alpha_ptr = complex_scalar_ptr(:double, alpha)
              beta_ptr = complex_scalar_ptr(:double, beta)
              CuBLASBindings.cublasZgemm_v2(
                handle, op_a, op_b,
                n, m, k,
                alpha_ptr, b.device_ffi_ptr, lda,
                a.device_ffi_ptr, ldb,
                beta_ptr, c.device_ffi_ptr, ldc
              )
            else
              # Unreachable after validate_inputs!, but never hand a nil status
              # to check_status!.
              raise UnsupportedDTypeError.new(a.dtype, operation: "matrix multiplication")
            end

          CuBLASBindings.check_status!(status, "GEMM execution")
        end
        # rubocop:enable Metrics/ParameterLists, Metrics/AbcSize, Metrics/MethodLength

        # Pack a real or Complex scalar into a 2-element host buffer laid out as
        # (real, imag), as the complex cuBLAS entry points expect.
        # @param type [Symbol] :float (complex64) or :double (complex128)
        def complex_scalar_ptr(type, value)
          offset = type == :double ? 8 : 4
          writer = type == :double ? :put_double : :put_float
          real, imag = value.is_a?(Complex) ? [value.real, value.imag] : [value, 0.0]
          FFI::MemoryPointer.new(type, 2).tap do |ptr|
            ptr.public_send(writer, 0, real)
            ptr.public_send(writer, offset, imag)
          end
        end

        # Execute GemmEx for optimized operations (callers pass the already
        # swapped row-major -> column-major operands and extents).
        # rubocop:disable Metrics/ParameterLists
        def execute_gemmex(handle, op_a, op_b, m, n, k, alpha, a, b, c, lda, ldb, ldc, algo, beta = 0.0)
          # Alpha and beta are always FP32 when using COMPUTE_32F modes
          alpha_ptr = FFI::MemoryPointer.new(:float).tap { |p| p.put_float(0, alpha.to_f) }
          beta_ptr = FFI::MemoryPointer.new(:float).tap { |p| p.put_float(0, beta.to_f) }

          a_cuda_type = DType.cublas_type(a.dtype)
          b_cuda_type = DType.cublas_type(b.dtype)
          c_cuda_type = DType.cublas_type(c.dtype)

          compute_type = case a.dtype
                         when :float16 then CuBLASBindings::CUBLAS_COMPUTE_32F_FAST_16F
                         when :bfloat16 then CuBLASBindings::CUBLAS_COMPUTE_32F_FAST_16BF
                         when :float32 then CuBLASBindings::CUBLAS_COMPUTE_32F_FAST_TF32
                         else CuBLASBindings::CUBLAS_COMPUTE_32F
                         end

          CuBLASBindings.cublasGemmEx(
            handle, op_a, op_b, m, n, k,
            alpha_ptr, b.device_ffi_ptr, b_cuda_type, lda,
            a.device_ffi_ptr, a_cuda_type, ldb,
            beta_ptr, c.device_ffi_ptr, c_cuda_type, ldc,
            compute_type, algo
          )
        end
        # rubocop:enable Metrics/ParameterLists
      end
    end
  end
end
@@ -0,0 +1,229 @@
1
# frozen_string_literal: true

module Ignis
  module LinAlg
    # Stateful matrix multiplication plan, reusable for repeated operations
    # with the same operand shapes, dtype and transposition flags. Execution is
    # delegated to Matmul.call.
    #
    # "Autotuning" currently benchmarks the planned GEMM and records its
    # average time (stats[:avg_time_ms]); it does not yet choose an algorithm.
    #
    # NOTE: the +epilog+ option is recorded and reported by #stats but is not
    # applied by #execute.
    class MatmulPlan
      # @return [Array<Integer>] Shape of matrix A
      attr_reader :shape_a

      # @return [Array<Integer>] Shape of matrix B
      attr_reader :shape_b

      # @return [Symbol] Data type
      attr_reader :dtype

      # @return [Hash] Plan options
      attr_reader :options

      # @return [Boolean] Whether plan has been autotuned
      attr_reader :autotuned
      alias autotuned? autotuned

      # @param shape_a [Array<Integer>] Shape of matrix A [m, k]
      # @param shape_b [Array<Integer>] Shape of matrix B [k, n]
      # @param dtype [Symbol] Data type
      # @param transpose_a [Boolean] Transpose A
      # @param transpose_b [Boolean] Transpose B
      # @param epilog [Symbol, nil] Epilog operation (recorded only; see class note)
      # @param device [Integer, nil] Target device (defaults to configuration.default_device)
      # @raise [DimensionError] non-2D shapes or incompatible inner dimensions
      def initialize(shape_a:, shape_b:, dtype: :float32, transpose_a: false, transpose_b: false,
                     epilog: nil, device: nil)
        @shape_a = Array(shape_a)
        @shape_b = Array(shape_b)
        @dtype = DType.validate!(dtype)
        @transpose_a = transpose_a
        @transpose_b = transpose_b
        @epilog = epilog
        @device_index = device || Ignis.configuration.default_device

        validate_shapes!

        @options = {}
        @autotuned = false
        @best_algorithm = nil
        @workspace = nil

        @m, @k, @n = compute_dimensions
        @execution_count = 0
      end

      # Output shape of the matmul operation
      # @return [Array<Integer>]
      def output_shape
        [@m, @n]
      end

      # Plan the operation (find algorithms)
      # @param workspace_size [Integer] Maximum workspace size in bytes
      # @return [self]
      def plan!(workspace_size: nil)
        workspace_size ||= Ignis.configuration.default_workspace_size

        CuBLASBindings.ensure_loaded!

        Ignis.logger.debug { "Planning MatmulPlan for #{@shape_a} @ #{@shape_b} -> #{output_shape}" }

        # For basic cuBLAS GEMM, planning is straightforward
        # Advanced planning with cuBLASLt would involve algorithm selection
        @options[:workspace_size] = workspace_size
        @options[:planned] = true

        Ignis.logger.info { "MatmulPlan planned: workspace=#{workspace_size} bytes" }

        self
      end

      # Benchmark the planned GEMM on zero-filled scratch arrays and record the
      # average time per call. Scratch arrays and CUDA events are released even
      # if a GEMM raises. Warmup and benchmark runs go through #execute_internal,
      # so they are included in stats[:execution_count].
      # @param iterations [Integer] Number of benchmark iterations (must be >= 1)
      # @param warmup [Integer] Number of warmup iterations
      # @return [self]
      # @raise [ArgumentError] when iterations is not a positive Integer
      def autotune!(iterations: nil, warmup: 3)
        iterations ||= Ignis.configuration.autotuning_iterations
        unless iterations.is_a?(Integer) && iterations.positive?
          # Zero or negative counts would make the average NaN or negative.
          raise ArgumentError, "iterations must be a positive Integer, got #{iterations.inspect}"
        end

        plan! unless @options[:planned]

        Ignis.logger.info { "Autotuning MatmulPlan with #{iterations} iterations" }

        avg_time = with_scratch_arrays do |a, b, c|
          warmup.times { execute_internal(a, b, c) }
          CUDA::Device.current.synchronize

          elapsed_ms = time_iterations(iterations) { execute_internal(a, b, c) }
          elapsed_ms / iterations
        end

        @options[:avg_time_ms] = avg_time
        @autotuned = true

        Ignis.logger.info { "MatmulPlan autotuned: avg_time=#{avg_time.round(3)}ms" }

        self
      end

      # Execute the planned matrix multiplication
      # @param a [NvArray] Left matrix
      # @param b [NvArray] Right matrix
      # @param c [NvArray, nil] Output matrix (created if nil)
      # @param alpha [Float] Scaling factor for A @ B
      # @param beta [Float] Scaling factor for C
      # @param stream [CUDA::Stream, nil] CUDA stream
      # @return [NvArray] Result matrix
      def execute(a, b, c: nil, alpha: 1.0, beta: 0.0, stream: nil)
        validate_execution_inputs!(a, b)

        # Ensure on device
        a = a.to_device(device: @device_index) unless a.on_device?
        b = b.to_device(device: @device_index) unless b.on_device?

        # Prepare output
        if c
          validate_output!(c)
          c = c.to_device(device: @device_index) unless c.on_device?
        else
          c = NvArray.zeros(output_shape, dtype: @dtype, device: @device_index)
        end

        execute_internal(a, b, c, alpha, beta, stream)

        c
      end

      # Get statistics about the plan
      # @return [Hash]
      def stats
        {
          shape_a: @shape_a,
          shape_b: @shape_b,
          output_shape: output_shape,
          dtype: @dtype,
          transpose_a: @transpose_a,
          transpose_b: @transpose_b,
          epilog: @epilog,
          autotuned: @autotuned,
          avg_time_ms: @options[:avg_time_ms],
          execution_count: @execution_count
        }
      end

      # @return [String]
      def to_s
        tuned = @autotuned ? "autotuned" : "not tuned"
        "MatmulPlan(#{@shape_a} @ #{@shape_b} -> #{output_shape}, #{@dtype}, #{tuned})"
      end

      private

      # Validate input shapes
      def validate_shapes!
        raise DimensionError, "Matrix A must be 2D" unless @shape_a.size == 2
        raise DimensionError, "Matrix B must be 2D" unless @shape_b.size == 2

        # Check K dimension compatibility
        k_a = @transpose_a ? @shape_a[0] : @shape_a[1]
        k_b = @transpose_b ? @shape_b[1] : @shape_b[0]

        raise DimensionError, "Matrix dimensions incompatible: #{@shape_a} @ #{@shape_b}" unless k_a == k_b
      end

      # Compute output dimensions
      # @return [Array<Integer>] [m, k, n]
      def compute_dimensions
        m = @transpose_a ? @shape_a[1] : @shape_a[0]
        k = @transpose_a ? @shape_a[0] : @shape_a[1]
        n = @transpose_b ? @shape_b[0] : @shape_b[1]
        [m, k, n]
      end

      # Validate execution inputs
      def validate_execution_inputs!(a, b)
        raise ArgumentError, "Expected NvArray, got #{a.class}" unless a.is_a?(NvArray)
        raise ArgumentError, "Expected NvArray, got #{b.class}" unless b.is_a?(NvArray)
        raise DimensionError, "A shape mismatch: expected #{@shape_a}, got #{a.shape}" unless a.shape == @shape_a
        raise DimensionError, "B shape mismatch: expected #{@shape_b}, got #{b.shape}" unless b.shape == @shape_b
        raise ArgumentError, "A dtype mismatch: expected #{@dtype}, got #{a.dtype}" unless a.dtype == @dtype
        raise ArgumentError, "B dtype mismatch: expected #{@dtype}, got #{b.dtype}" unless b.dtype == @dtype
      end

      # Validate output array
      def validate_output!(c)
        expected = output_shape
        raise DimensionError, "Output shape mismatch: expected #{expected}, got #{c.shape}" unless c.shape == expected
        raise ArgumentError, "Output dtype mismatch: expected #{@dtype}, got #{c.dtype}" unless c.dtype == @dtype
      end

      # Allocate zero-filled A, B and C arrays matching this plan, yield them,
      # and free whichever were allocated even if allocation or the block raises.
      # @return the block's value
      def with_scratch_arrays
        arrays = []
        [@shape_a, @shape_b, output_shape].each do |shape|
          arrays << NvArray.zeros(shape, dtype: @dtype, device: @device_index)
        end
        yield(*arrays)
      ensure
        arrays&.each(&:free!)
      end

      # Run the block +count+ times between two recorded CUDA events and return
      # CUDA::Event.elapsed_time for that span. Both events are destroyed even
      # if the block raises.
      def time_iterations(count)
        start_event = CUDA::Event.new
        end_event = CUDA::Event.new

        start_event.record
        count.times { yield }
        end_event.record
        end_event.synchronize

        CUDA::Event.elapsed_time(start_event, end_event)
      ensure
        start_event&.destroy!
        end_event&.destroy!
      end

      # Execute GEMM internally (counts every call, including autotune runs).
      def execute_internal(a, b, c, alpha = 1.0, beta = 0.0, stream = nil)
        @execution_count += 1
        Matmul.call(
          a, b,
          c: c,
          alpha: alpha,
          beta: beta,
          transpose_a: @transpose_a,
          transpose_b: @transpose_b,
          stream: stream
        )
      end
    end
  end
end