whispercpp 1.3.0 → 1.3.1

Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ggml/src/ggml-amx/common.h (new file, +94 -0)
@@ -0,0 +1,94 @@
+ #pragma once
+
+ #include "ggml.h"
+ // hack until AMX is moved into the CPU backend
+ #include "../ggml-cpu/ggml-cpu-impl.h" // <immintrin.h>
+
+ #include <algorithm>
+ #include <memory>
+ #include <type_traits>
+
+ #if defined(_OPENMP)
+ #include <omp.h>
+ #endif
+
+ #define TILE_M 16
+ #define TILE_N 16
+ #define TILE_K 32
+ #define VNNI_BLK 4
+
+ #define AMX_BLK_SIZE 32
+
+ #define TMM0 0
+ #define TMM1 1
+ #define TMM2 2
+ #define TMM3 3
+ #define TMM4 4
+ #define TMM5 5
+ #define TMM6 6
+ #define TMM7 7
+
+ // parallel routines
+ template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
+ inline T div_up(T x, T y) { return (x + y - 1) / y; }
+
+ template <typename T>
+ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
+ #if 0
+     // onednn partition pattern
+     T& n_my = n_end;
+     if (nth <= 1 || n == 0) {
+         n_start = 0;
+         n_my = n;
+     } else {
+         T n1 = div_up(n, nth);
+         T n2 = n1 - 1;
+         T T1 = n - n2 * nth;
+         n_my = ith < T1 ? n1 : n2;
+         n_start = ith <= T1 ? ith * n1 : T1 * n1 + (ith - T1) * n2;
+     }
+     n_end += n_start;
+ #else
+     // pytorch aten partition pattern
+     T n_my = div_up(n, nth);
+     n_start = ith * n_my;
+     n_end = std::min(n_start + n_my, n);
+ #endif
+ }
+
+ template <typename func_t>
+ inline void parallel_for(int nth, int n, const func_t& f) {
+ #if defined(_OPENMP)
+ #pragma omp parallel num_threads(nth)
+     {
+         //int nth = omp_get_num_threads();
+         int ith = omp_get_thread_num();
+         int tbegin, tend;
+         balance211(n, nth, ith, tbegin, tend);
+         f(tbegin, tend);
+     }
+ #else
+     f(0, n);
+
+     GGML_UNUSED(nth);
+ #endif
+ }
+
+ // quantized types that have AMX support
+ inline bool qtype_has_amx_kernels(const enum ggml_type type) {
+     // TODO: fix padding for vnni format
+     return (type == GGML_TYPE_Q4_0) ||
+            (type == GGML_TYPE_Q4_1);
+           //(type == GGML_TYPE_Q8_0) ||
+           //(type == GGML_TYPE_Q4_K) ||
+           //(type == GGML_TYPE_Q5_K) ||
+           //(type == GGML_TYPE_Q6_K) ||
+           //(type == GGML_TYPE_IQ4_XS);
+ }
+
+ // ggml backend context
+ struct ggml_backend_amx_context {
+     int n_threads = GGML_DEFAULT_N_THREADS;
+     std::unique_ptr<char[]> work_data;
+     size_t work_size = 0;
+ };
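
The live branch of balance211 above (the "pytorch aten partition pattern") hands each thread a contiguous block of div_up(n, nth) items, clamped to n. A minimal standalone sketch of that behavior follows; it is not part of the gem, and the demo values in main() are illustrative only.

// Sketch: reproduces the aten-style branch of balance211 from common.h
// and prints the per-thread ranges it produces. Demo values (n = 10 work
// items, nth = 4 threads) are placeholders.
#include <algorithm>
#include <cstdio>

template <typename T>
inline T div_up(T x, T y) { return (x + y - 1) / y; }

template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
    T n_my  = div_up(n, nth);               // ceil(n / nth) items per thread
    n_start = ith * n_my;                   // this thread's first item
    n_end   = std::min(n_start + n_my, n);  // one past its last item
}

int main() {
    const int n = 10, nth = 4;
    for (int ith = 0; ith < nth; ith++) {
        int tbegin, tend;
        balance211(n, nth, ith, tbegin, tend);
        std::printf("thread %d: [%d, %d)\n", ith, tbegin, tend);
    }
    // prints [0, 3), [3, 6), [6, 9), [9, 10): the last thread gets the remainder
    return 0;
}
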
data/ext/ggml/src/ggml-amx/ggml-amx.cpp (new file, +446 -0)
@@ -0,0 +1,446 @@
+ #include "ggml-amx.h"
+ #include "ggml-amx/common.h"
+ #include "ggml-amx/mmq.h"
+ #include "ggml-backend-impl.h"
+ #include "ggml-impl.h"
+
+ #if defined(__gnu_linux__)
+ #include <sys/syscall.h>
+ #include <unistd.h>
+ #endif
+
+ #include <cstdlib>
+ #include <cstring>
+ #include <memory>
+
+ #if defined(__AMX_INT8__)
+
+ // AMX buffer interface
+ static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+     free(buffer->context);
+ }
+
+ static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
+     return (void *)(buffer->context);
+ }
+
+ static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+     memset((char *)tensor->data + offset, value, size);
+
+     GGML_UNUSED(buffer);
+ }
+
+ static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+     if (qtype_has_amx_kernels(tensor->type)) {
+         ggml_backend_amx_convert_weight(tensor, data, offset, size);
+     } else {
+         memcpy((char *)tensor->data + offset, data, size);
+     }
+
+     GGML_UNUSED(buffer);
+ }
+
+ static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+     GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+     memcpy(data, (const char *)tensor->data + offset, size);
+
+     GGML_UNUSED(buffer);
+ }
+
+ static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+     if (ggml_backend_buffer_is_host(src->buffer)) {
+         if (qtype_has_amx_kernels(src->type)) {
+             ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
+         } else {
+             memcpy(dst->data, src->data, ggml_nbytes(src));
+         }
+         return true;
+     }
+     return false;
+
+     GGML_UNUSED(buffer);
+ }
+
+ static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+     memset(buffer->context, value, buffer->size);
+ }
+
+ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
+     /* .free_buffer   = */ ggml_backend_amx_buffer_free_buffer,
+     /* .get_base      = */ ggml_backend_amx_buffer_get_base,
+     /* .init_tensor   = */ NULL, // no initialization required
+     /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
+     /* .set_tensor    = */ ggml_backend_amx_buffer_set_tensor,
+     /* .get_tensor    = */ ggml_backend_amx_buffer_get_tensor,
+     /* .cpy_tensor    = */ ggml_backend_amx_buffer_cpy_tensor,
+     /* .clear         = */ ggml_backend_amx_buffer_clear,
+     /* .reset         = */ NULL,
+ };
+
+ static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+     return "AMX";
+
+     GGML_UNUSED(buft);
+ }
+
+ static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+     void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
+     if (data == NULL) {
+         fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+         return NULL;
+     }
+
+     return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
+ }
+
+ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+     return TENSOR_ALIGNMENT;
+
+     GGML_UNUSED(buft);
+ }
+
+ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+     return ggml_backend_amx_get_alloc_size(tensor);
+
+     GGML_UNUSED(buft);
+ }
+
+ static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+     return false;
+
+     GGML_UNUSED(buft);
+ }
+
+ ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
+     static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
+         /* .iface = */ {
+             /* .get_name       = */ ggml_backend_amx_buffer_type_get_name,
+             /* .alloc_buffer   = */ ggml_backend_amx_buffer_type_alloc_buffer,
+             /* .get_alignment  = */ ggml_backend_amx_buffer_type_get_alignment,
+             /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
+             /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
+             /* .is_host        = */ ggml_backend_amx_buffer_type_is_host,
+         },
+         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
+         /* .context = */ NULL,
+     };
+
+     return &ggml_backend_buffer_type_amx;
+ }
+
+ // backend interface
+
+ static const char * ggml_backend_amx_name(ggml_backend_t backend) {
+     return "AMX";
+
+     GGML_UNUSED(backend);
+ }
+
+ static void ggml_backend_amx_free(ggml_backend_t backend) {
+     ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+     delete ctx;
+     delete backend;
+ }
+
+ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+     ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+         struct ggml_tensor * node = cgraph->nodes[i];
+
+         switch (node->op) {
+             case GGML_OP_MUL_MAT:
+                 ggml_backend_amx_mul_mat(ctx, node);
+                 break;
+
+             case GGML_OP_NONE:
+             case GGML_OP_RESHAPE:
+             case GGML_OP_VIEW:
+             case GGML_OP_PERMUTE:
+             case GGML_OP_TRANSPOSE:
+                 break;
+
+             default:
+                 fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+                 GGML_ASSERT(false);
+         }
+     }
+
+     return GGML_STATUS_SUCCESS;
+
+     GGML_UNUSED(backend);
+ }
+
+ static struct ggml_backend_i ggml_backend_amx_i = {
+     /* .get_name           = */ ggml_backend_amx_name,
+     /* .free               = */ ggml_backend_amx_free,
+     /* .set_tensor_async   = */ NULL,
+     /* .get_tensor_async   = */ NULL,
+     /* .cpy_tensor_async   = */ NULL,
+     /* .synchronize        = */ NULL,
+     /* .graph_plan_create  = */ NULL,
+     /* .graph_plan_free    = */ NULL,
+     /* .graph_plan_update  = */ NULL,
+     /* .graph_plan_compute = */ NULL,
+     /* .graph_compute      = */ ggml_backend_amx_graph_compute,
+     /* .event_record       = */ NULL,
+     /* .event_wait         = */ NULL,
+ };
+
+ static ggml_guid_t ggml_backend_amx_guid() {
+     static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
+     return &guid;
+ }
+
+ #define ARCH_GET_XCOMP_PERM 0x1022
+ #define ARCH_REQ_XCOMP_PERM 0x1023
+ #define XFEATURE_XTILECFG   17
+ #define XFEATURE_XTILEDATA  18
+
+ static bool ggml_amx_init() {
+ #if defined(__gnu_linux__)
+     if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+         fprintf(stderr, "AMX is not ready to be used!\n");
+         return false;
+     }
+     return true;
+ #elif defined(_WIN32)
+     return true;
+ #endif
+ }
+
+ ggml_backend_t ggml_backend_amx_init() {
+
+     // invoke a Linux system call to request access to AMX features
+     ggml_amx_init();
+
+     // backend context
+     ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
+
+     // ggml amx backend
+     ggml_backend_t backend = new ggml_backend {
+         /* .guid      = */ ggml_backend_amx_guid(),
+         /* .interface = */ ggml_backend_amx_i,
+         /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
+         /* .context   = */ ctx,
+     };
+
+     return backend;
+ }
+
+ bool ggml_backend_is_amx(ggml_backend_t backend) {
+     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
+ }
+
+ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+     GGML_ASSERT(ggml_backend_is_amx(backend_amx));
+
+     ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
+     ctx->n_threads = n_threads;
+ }
+
+ // device interface
+
+ static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
+     return "AMX";
+
+     GGML_UNUSED(dev);
+ }
+
+ static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
+     return "Intel Advanced Matrix Extensions";
+
+     GGML_UNUSED(dev);
+ }
+
+ static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+     // TODO
+     *free = 0;
+     *total = 0;
+
+     GGML_UNUSED(dev);
+ }
+
+ static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
+     return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+     GGML_UNUSED(dev);
+ }
+
+ static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+     props->name        = ggml_backend_amx_device_get_name(dev);
+     props->description = ggml_backend_amx_device_get_description(dev);
+     props->type        = ggml_backend_amx_device_get_type(dev);
+     ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+     // `buffer_from_host_ptr` is intended to be used with mmap, when the memory layout is unchanged
+     props->caps = {
+         /* .async                = */ false,
+         /* .host_buffer          = */ false,
+         /* .buffer_from_host_ptr = */ false,
+         /* .events               = */ false,
+     };
+ }
+
+ static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
+     return ggml_backend_amx_init();
+
+     GGML_UNUSED(dev);
+     GGML_UNUSED(params);
+ }
+
+ static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
+     return ggml_backend_amx_buffer_type();
+
+     GGML_UNUSED(dev);
+ }
+
+ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+
+     // handle only 2d gemm for now
+     auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+         return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+     };
+
+     switch (op->op) {
+         case GGML_OP_NONE:
+         case GGML_OP_RESHAPE:
+         case GGML_OP_VIEW:
+         case GGML_OP_PERMUTE:
+         case GGML_OP_TRANSPOSE:
+             return true;
+
+         case GGML_OP_MUL_MAT: {
+             const struct ggml_tensor * src0 = op->src[0];
+             const struct ggml_tensor * src1 = op->src[1];
+
+             const enum ggml_type type = src0->type;
+             const int64_t ne0 = op->ne[0];
+
+             // amx kernels are enabled for Q4_0, Q4_1, Q8_0, F16
+             // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+             bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+
+             bool can_use_amx =
+                 is_contiguous_2d(src0) &&      // src0 must be contiguous
+                 is_contiguous_2d(src1) &&      // src1 must be contiguous
+                 src1->type == GGML_TYPE_F32 && // src1 must be float32
+                 has_amx_kernels &&             // with amx kernel impls
+                 ne0 % (TILE_N * 2) == 0;       // out_features is 32x
+
+             return can_use_amx;
+         }
+         default:
+             return false;
+     }
+
+     GGML_UNUSED(dev);
+ }
+
+ static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+     return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
+
+     GGML_UNUSED(dev);
+ }
+
+ static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
+     /* .get_name             = */ ggml_backend_amx_device_get_name,
+     /* .get_description      = */ ggml_backend_amx_device_get_description,
+     /* .get_memory           = */ ggml_backend_amx_device_get_memory,
+     /* .get_type             = */ ggml_backend_amx_device_get_type,
+     /* .get_props            = */ ggml_backend_amx_device_get_props,
+     /* .init_backend         = */ ggml_backend_amx_device_init,
+     /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
+     /* .get_host_buffer_type = */ NULL,
+     /* .buffer_from_host_ptr = */ NULL,
+     /* .supports_op          = */ ggml_backend_amx_device_supports_op,
+     /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
+     /* .offload_op           = */ NULL,
+     /* .event_new            = */ NULL,
+     /* .event_free           = */ NULL,
+     /* .event_synchronize    = */ NULL,
+ };
+
+ // backend reg interface
+
+ static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
+     return "AMX";
+
+     GGML_UNUSED(reg);
+ }
+
+ static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
+     return 1;
+
+     GGML_UNUSED(reg);
+ }
+
+ static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+     GGML_ASSERT(index == 0);
+
+     static ggml_backend_device ggml_backend_amx_device = {
+         /* .iface   = */ ggml_backend_amx_device_i,
+         /* .reg     = */ reg,
+         /* .context = */ nullptr,
+     };
+
+     return &ggml_backend_amx_device;
+
+     GGML_UNUSED(reg);
+     GGML_UNUSED(index);
+ }
+
+ static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+     if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+         return (void *)ggml_backend_amx_set_n_threads;
+     }
+     return NULL;
+
+     GGML_UNUSED(reg);
+     GGML_UNUSED(name);
+ }
+
+ static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
+     /* .get_name         = */ ggml_backend_amx_reg_get_name,
+     /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
+     /* .get_device       = */ ggml_backend_amx_reg_get_device,
+     /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
+ };
+
+ ggml_backend_reg_t ggml_backend_amx_reg(void) {
+     static struct ggml_backend_reg ggml_backend_amx_reg = {
+         /* .iface   = */ ggml_backend_amx_reg_i,
+         /* .context = */ NULL,
+     };
+
+     return &ggml_backend_amx_reg;
+ }
+
+ #else // if defined(__AMX_INT8__)
+
+ ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void) {
+     return nullptr;
+ }
+
+ bool ggml_backend_is_amx(ggml_backend_t backend) {
+     GGML_UNUSED(backend);
+     return false;
+ }
+
+ ggml_backend_t ggml_backend_amx_init(void) {
+     fprintf(stderr, "GGML is not compiled with AMX support!\n");
+     return nullptr;
+ }
+
+ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+     fprintf(stderr, "GGML is not compiled with AMX support!\n");
+
+     GGML_UNUSED(backend_amx);
+     GGML_UNUSED(n_threads);
+ }
+
+ ggml_backend_reg_t ggml_backend_amx_reg(void) {
+     return nullptr;
+ }
+
+ #endif
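
For orientation, a hedged usage sketch of the API this file adds. It relies only on the functions declared above (ggml_backend_amx_init, ggml_backend_is_amx, ggml_backend_amx_set_n_threads) plus the generic ggml-backend entry points ggml_backend_graph_compute and ggml_backend_free; the thread count of 8 and the elided graph are placeholders, and on a build without __AMX_INT8__ ggml_backend_amx_init simply returns nullptr, as the #else branch shows.

// Hypothetical caller, assuming a build with __AMX_INT8__ on a CPU whose
// kernel grants AMX tile state (see ggml_amx_init above).
#include "ggml-amx.h"
#include "ggml-backend.h"
#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_amx_init();
    if (backend == nullptr || !ggml_backend_is_amx(backend)) {
        std::fprintf(stderr, "AMX backend unavailable\n");
        return 1;
    }

    ggml_backend_amx_set_n_threads(backend, 8);  // placeholder thread count

    // ... allocate tensors in ggml_backend_amx_buffer_type() and build a
    // graph whose MUL_MAT nodes pass ggml_backend_amx_device_supports_op,
    // then run it with ggml_backend_graph_compute(backend, graph) ...

    ggml_backend_free(backend);  // dispatches to ggml_backend_amx_free
    return 0;
}
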