whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,14 @@
1
+ #pragma once
2
+ #include <stdint.h>
3
+ #include <stdbool.h>
4
+ #ifdef __cplusplus
5
+ extern "C" {
6
+ #endif
7
+
8
+ bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, // declared with C linkage when compiled as C++ (inside the extern "C" guard of this header)
9
+ const void *, int64_t, void *, int64_t, int, int, // NOTE(review): parameters are unnamed here; their meaning must be taken from the definition (presumably sgemm.cpp) — confirm
10
+ int, int, int); // returns bool; what true/false signify is not visible from this declaration
11
+
12
+ #ifdef __cplusplus
13
+ }
14
+ #endif
@@ -0,0 +1,14 @@
1
+ #pragma once
2
+
3
+ #include <cuda_runtime.h>
4
+ #include <cuda.h>
5
+ #include <cublas_v2.h>
6
+ #include <cuda_fp16.h>
7
+
8
+ #if CUDART_VERSION < 11020 // compatibility aliases, active only when CUDART_VERSION is below 11020
9
+ #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED // "MEMORY" spelling is mapped onto the "ADDRESS" spelling
10
+ #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH // TF32 math-mode name falls back to the generic tensor-op mode
11
+ #define CUBLAS_COMPUTE_16F CUDA_R_16F // compute-type constants are expressed as cudaDataType_t values here
12
+ #define CUBLAS_COMPUTE_32F CUDA_R_32F
13
+ #define cublasComputeType_t cudaDataType_t // keeps declarations using cublasComputeType_t consistent with the two aliases above
14
+ #endif // CUDART_VERSION < 11020
@@ -0,0 +1,186 @@
1
+ #pragma once
2
+
3
+ #include <hip/hip_runtime.h>
4
+ #include <hipblas/hipblas.h>
5
+ #include <hip/hip_fp16.h>
6
+ #ifdef __HIP_PLATFORM_AMD__
7
+ // for rocblas_initialize()
8
+ #include "rocblas/rocblas.h"
9
+ #endif // __HIP_PLATFORM_AMD__
10
+ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
11
+ #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
12
+ #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
13
+ #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
14
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
15
+ #define CUBLAS_OP_N HIPBLAS_OP_N
16
+ #define CUBLAS_OP_T HIPBLAS_OP_T
17
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
18
+ #define CUBLAS_TF32_TENSOR_OP_MATH 0
19
+ #define CUDA_R_16F HIPBLAS_R_16F
20
+ #define CUDA_R_32F HIPBLAS_R_32F
21
+ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) // the mask argument is dropped in the expansion; the macro takes exactly four arguments, so width must always be passed
22
+ #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
23
+ #define cublasCreate hipblasCreate
24
+ #define cublasDestroy hipblasDestroy
25
+ #define cublasGemmEx hipblasGemmEx
26
+ #define cublasGemmBatchedEx hipblasGemmBatchedEx
27
+ #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
28
+ #define cublasHandle_t hipblasHandle_t
29
+ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS // no library call is made: expands to the success constant, and neither argument appears in the expansion
30
+ #define cublasSetStream hipblasSetStream
31
+ #define cublasSgemm hipblasSgemm
32
+ #define cublasStatus_t hipblasStatus_t
33
+ #define cublasOperation_t hipblasOperation_t
34
+ #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
35
+ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
36
+ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
37
+ #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
38
+ #define cudaDeviceProp hipDeviceProp_t
39
+ #define cudaDeviceSynchronize hipDeviceSynchronize
40
+ #define cudaError_t hipError_t
41
+ #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
42
+ #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
43
+ #define cudaEventCreateWithFlags hipEventCreateWithFlags
44
+ #define cudaEventDisableTiming hipEventDisableTiming
45
+ #define cudaEventRecord hipEventRecord
46
+ #define cudaEventSynchronize hipEventSynchronize
47
+ #define cudaEvent_t hipEvent_t
48
+ #define cudaEventDestroy hipEventDestroy
49
+ #define cudaFree hipFree
50
+ #define cudaFreeHost hipHostFree
51
+ #define cudaGetDevice hipGetDevice
52
+ #define cudaGetDeviceCount hipGetDeviceCount
53
+ #define cudaGetDeviceProperties hipGetDeviceProperties
54
+ #define cudaGetErrorString hipGetErrorString
55
+ #define cudaGetLastError hipGetLastError
56
+ #define cudaHostRegister hipHostRegister
57
+ #define cudaHostRegisterPortable hipHostRegisterPortable
58
+ #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
59
+ #define cudaHostUnregister hipHostUnregister
60
+ #define cudaLaunchHostFunc hipLaunchHostFunc
61
+ #define cudaMalloc hipMalloc
62
+ #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
63
+ #define cudaMemcpy hipMemcpy
64
+ #define cudaMemcpyAsync hipMemcpyAsync
65
+ #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
66
+ #define cudaMemcpy2DAsync hipMemcpy2DAsync
67
+ #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
68
+ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
69
+ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
70
+ #define cudaMemcpyKind hipMemcpyKind
71
+ #define cudaMemset hipMemset
72
+ #define cudaMemsetAsync hipMemsetAsync
73
+ #define cudaMemGetInfo hipMemGetInfo
74
+ #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
75
+ #define cudaSetDevice hipSetDevice
76
+ #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
77
+ #define cudaStreamDestroy hipStreamDestroy
78
+ #define cudaStreamFireAndForget hipStreamFireAndForget
79
+ #define cudaStreamNonBlocking hipStreamNonBlocking
80
+ #define cudaStreamPerThread hipStreamPerThread
81
+ #define cudaStreamSynchronize hipStreamSynchronize
82
+ #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
83
+ #define cudaStream_t hipStream_t
84
+ #define cudaSuccess hipSuccess
85
+ #define __trap() do { abort(); __builtin_unreachable(); } while(0) // __trap() becomes abort(); the do/while(0) wrapper makes it usable as a single statement, and the unreachable hint tells the compiler control does not continue
86
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS // identical to the CUBLAS_STATUS_SUCCESS mapping already defined earlier in this header; repeated here alongside the other status codes
87
+ #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
88
+ #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
89
+ #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
90
+ #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
91
+ #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
92
+ #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
93
+ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
94
+ #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
95
+
96
+ #define __CUDA_ARCH__ 1300 // NOTE(review): fixed value so that __CUDA_ARCH__ comparisons in shared code resolve on HIP; presumably chosen to exceed any real CUDA architecture number — confirm
97
+
98
+ #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
99
+ #define GCN // family tag: set when any of __gfx803__, __gfx900__, __gfx906__ is defined (presumably compiler-provided target macros — confirm)
100
+ #endif
101
+
102
+ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
103
+ #define CDNA // family tag for the gfx908 / gfx90a / gfx942 targets
104
+ #endif
105
+
106
+ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
107
+ defined(__gfx1150__) || defined(__gfx1151__)
108
+ #define RDNA3 // family tag for the gfx1100-gfx1103 and gfx1150-gfx1151 targets
109
+ #endif
110
+
111
+ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
112
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
113
+ #define RDNA2 // family tag for the gfx1030-gfx1037 targets
114
+ #endif
115
+
116
+ #if defined(__gfx1010__) || defined(__gfx1012__)
117
+ #define RDNA1 // family tag for the gfx1010 / gfx1012 targets; targets not listed in any group get no family tag
118
+ #endif
119
+
120
+ #ifndef __has_builtin
121
+ #define __has_builtin(x) 0 // fallback for compilers without __has_builtin: every query reports "not available", so the #else paths in this header are taken
122
+ #endif
123
+
124
+ typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); // 4-lane vector of int8_t; used below to address the individual bytes of a 32-bit int
125
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); // unsigned counterpart, used by the __vcmpeq4 / __vcmpne4 helpers
126
+ static __device__ __forceinline__ int __vsubss4(const int a, const int b) { // per-byte signed saturating subtract: a and b are each treated as four int8 lanes, result is packed back into an int
126
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); // reinterpret the 32-bit operands as 4 x int8 lanes
127
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
128
+ #if __has_builtin(__builtin_elementwise_sub_sat)
129
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); // preferred path: the compiler builtin performs the saturating subtract on all lanes
130
+ return reinterpret_cast<const int &>(c);
131
+ #else
132
+ int8x4_t c; // fallback path: emulate the saturating subtract one lane at a time
133
+ int16_t tmp; // 16-bit temporary so the int8 - int8 difference cannot overflow before it is clamped
134
+ #pragma unroll
135
+ for (int i = 0; i < 4; i++) {
136
+ tmp = va[i] - vb[i];
137
+ if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max(); // clamp to the int8 upper bound; NOTE(review): std::numeric_limits needs <limits>, which this header does not include directly
138
+ if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min(); // clamp to the int8 lower bound
139
+ c[i] = tmp;
140
+ }
141
+ return reinterpret_cast<int &>(c); // repack the four clamped lanes as a single int
142
+ #endif // __has_builtin(__builtin_elementwise_sub_sat)
143
+ }
145
+
146
+ static __device__ __forceinline__ int __vsub4(const int a, const int b) { // per-byte subtract; implemented here by forwarding to the saturating variant
146
+ return __vsubss4(a, b); // NOTE(review): this saturates each lane, whereas the CUDA intrinsic of the same name is presumably non-saturating — confirm no caller relies on wrap-around
147
+ }
149
+
150
+ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) { // per-byte equality mask: each result byte is 0xff where the bytes of a and b match, 0x00 otherwise
150
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a); // byte-lane views of the two operands
151
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
152
+ unsigned int c; // not initialised here; all four bytes are written through vc in the loop below
153
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c); // vc refers to the storage of c, so lane writes build the return value
154
+ #pragma unroll
155
+ for (int i = 0; i < 4; ++i) {
156
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
157
+ }
158
+ return c;
159
+ }
161
+
162
+ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) { // per-byte inequality mask: each result byte is 0xff where the bytes of a and b differ, 0x00 where they match
162
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a); // byte-lane views of the two operands
163
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
164
+ unsigned int c; // not initialised here; all four bytes are written through vc in the loop below
165
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c); // vc refers to the storage of c, so lane writes build the return value
166
+ #pragma unroll
167
+ for (int i = 0; i < 4; ++i) {
168
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff; // same comparison as __vcmpeq4 with the two result values swapped
169
+ }
170
+ return c;
171
+ }
173
+
174
+ #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
175
+ // __shfl_xor() for half2 was added in ROCm 5.6
176
+ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) { // half2 overload supplied only for the older HIP versions selected by the #if above
177
+ typedef union half2_b32 { // lets the half2 payload be accessed as an int
178
+ half2 val;
179
+ int b32;
180
+ } half2_b32_t;
181
+ half2_b32_t tmp;
182
+ tmp.val = var;
183
+ tmp.b32 = __shfl_xor(tmp.b32, laneMask, width); // resolves to the int overload, so the half2 is shuffled as one int-sized value
184
+ return tmp.val; // read the shuffled bits back as half2
185
+ }
186
+ #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
@@ -0,0 +1,134 @@
1
+ #pragma once
2
+
3
+ #include <musa_runtime.h>
4
+ #include <musa.h>
5
+ #include <mublas.h>
6
+ #include <musa_fp16.h>
7
+ #define CUBLAS_COMPUTE_16F CUDA_R_16F
8
+ #define CUBLAS_COMPUTE_32F CUDA_R_32F
9
+ #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
10
+ #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
11
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT // same target as CUBLAS_GEMM_DEFAULT: the tensor-op algorithm selector is not distinguished in this mapping
12
+ #define CUBLAS_OP_N MUBLAS_OP_N
13
+ #define CUBLAS_OP_T MUBLAS_OP_T
14
+ #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
15
+ #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
16
+ #define CUDA_R_16F MUSA_R_16F
17
+ #define CUDA_R_32F MUSA_R_32F
18
+ #define cublasComputeType_t cudaDataType_t
19
+ #define cublasCreate mublasCreate
20
+ #define cublasDestroy mublasDestroy
21
+ #define cublasGemmEx mublasGemmEx
22
+ #define cublasGemmBatchedEx mublasGemmBatchedEx
23
+ #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
24
+ #define cublasHandle_t mublasHandle_t
25
+ #define cublasSetMathMode mublasSetMathMode
26
+ #define cublasSetStream mublasSetStream
27
+ #define cublasSgemm mublasSgemm
28
+ #define cublasStatus_t mublasStatus_t
29
+ #define cublasOperation_t mublasOperation_t
30
+ #define cublasGetStatusString mublasStatus_to_string
31
+ #define cudaDataType_t musaDataType_t
32
+ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
33
+ #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
34
+ #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
35
+ #define cudaDeviceProp musaDeviceProp
36
+ #define cudaDeviceSynchronize musaDeviceSynchronize
37
+ #define cudaError_t musaError_t
38
+ #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
39
+ #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
40
+ #define cudaEventCreateWithFlags musaEventCreateWithFlags
41
+ #define cudaEventDisableTiming musaEventDisableTiming
42
+ #define cudaEventRecord musaEventRecord
43
+ #define cudaEventSynchronize musaEventSynchronize
44
+ #define cudaEvent_t musaEvent_t
45
+ #define cudaEventDestroy musaEventDestroy
46
+ #define cudaFree musaFree
47
+ #define cudaFreeHost musaFreeHost
48
+ #define cudaGetDevice musaGetDevice
49
+ #define cudaGetDeviceCount musaGetDeviceCount
50
+ #define cudaGetDeviceProperties musaGetDeviceProperties
51
+ #define cudaGetErrorString musaGetErrorString
52
+ #define cudaGetLastError musaGetLastError
53
+ #define cudaHostRegister musaHostRegister
54
+ #define cudaHostRegisterPortable musaHostRegisterPortable
55
+ #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
56
+ #define cudaHostUnregister musaHostUnregister
57
+ #define cudaLaunchHostFunc musaLaunchHostFunc
58
+ #define cudaMalloc musaMalloc
59
+ #define cudaMallocHost musaMallocHost
60
+ #define cudaMallocManaged musaMallocManaged
61
+ #define cudaMemcpy musaMemcpy
62
+ #define cudaMemcpyAsync musaMemcpyAsync
63
+ #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
64
+ #define cudaMemcpy2DAsync musaMemcpy2DAsync
65
+ #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
66
+ #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
67
+ #define cudaMemcpyHostToDevice musaMemcpyHostToDevice
68
+ #define cudaMemcpyKind musaMemcpyKind
69
+ #define cudaMemset musaMemset
70
+ #define cudaMemsetAsync musaMemsetAsync
71
+ #define cudaMemGetInfo musaMemGetInfo
72
+ #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
73
+ #define cudaSetDevice musaSetDevice
74
+ #define cudaStreamCreateWithFlags musaStreamCreateWithFlags
75
+ #define cudaStreamDestroy musaStreamDestroy
76
+ #define cudaStreamFireAndForget musaStreamFireAndForget
77
+ #define cudaStreamNonBlocking musaStreamNonBlocking
78
+ #define cudaStreamPerThread musaStreamPerThread
79
+ #define cudaStreamSynchronize musaStreamSynchronize
80
+ #define cudaStreamWaitEvent musaStreamWaitEvent
81
+ #define cudaStream_t musaStream_t
82
+ #define cudaSuccess musaSuccess
83
+
84
+ // Additional mappings for MUSA virtual memory pool
85
+ #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
86
+ #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
87
+ #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
88
+ #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
89
+ #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
90
+ #define CUdevice MUdevice
91
+ #define CUdeviceptr MUdeviceptr
92
+ #define CUmemAccessDesc MUmemAccessDesc
93
+ #define CUmemAllocationProp MUmemAllocationProp
94
+ #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
95
+ #define cuDeviceGet muDeviceGet
96
+ #define cuDeviceGetAttribute muDeviceGetAttribute
97
+ #define cuMemAddressFree muMemAddressFree
98
+ #define cuMemAddressReserve muMemAddressReserve
99
+ #define cuMemCreate muMemCreate
100
+ #define cuMemGetAllocationGranularity muMemGetAllocationGranularity
101
+ #define cuMemMap muMemMap
102
+ #define cuMemRelease muMemRelease
103
+ #define cuMemSetAccess muMemSetAccess
104
+ #define cuMemUnmap muMemUnmap
105
+ #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
106
+ #define cudaFuncSetAttribute musaFuncSetAttribute
107
+ #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
108
+ #define make_cudaExtent make_musaExtent
109
+ #define make_cudaPitchedPtr make_musaPitchedPtr
110
+
111
+ // Additional mappings for MUSA graphs
112
+ #define CUDA_SUCCESS MUSA_SUCCESS
113
+ #define CUresult MUresult
114
+ #define cuGetErrorString muGetErrorString
115
+ #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
116
+ #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
117
+ #define cudaGraphDestroy musaGraphDestroy
118
+ #define cudaGraphExecDestroy musaGraphExecDestroy
119
+ #define cudaGraphExec_t musaGraphExec_t
120
+ #define cudaGraphExecUpdate musaGraphExecUpdate
121
+ #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult // NOTE(review): the CUDA-side name ends in "ResultInfo" while the MUSA target ends in "Result" — confirm the two types are used compatibly by callers
122
+ #define cudaGraphGetNodes musaGraphGetNodes
123
+ #define cudaGraphInstantiate musaGraphInstantiate
124
+ #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
125
+ #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
126
+ #define cudaGraphLaunch musaGraphLaunch
127
+ #define cudaGraphNodeGetType musaGraphNodeGetType
128
+ #define cudaGraphNode_t musaGraphNode_t
129
+ #define cudaGraphNodeType musaGraphNodeType
130
+ #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
131
+ #define cudaGraph_t musaGraph_t
132
+ #define cudaKernelNodeParams musaKernelNodeParams
133
+ #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
134
+ #define cudaStreamEndCapture musaStreamEndCapture