whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,593 @@
1
+
2
+
3
+ #include <iostream>
4
+ #include <fstream>
5
+ #include <sstream>
6
+ #include <string>
7
+ #include <stdexcept>
8
+ #include <array>
9
+ #include <vector>
10
+ #include <map>
11
+ #include <thread>
12
+ #include <mutex>
13
+ #include <future>
14
+ #include <queue>
15
+ #include <condition_variable>
16
+ #include <cstdio>
17
+ #include <cstring>
18
+ #include <cstdlib>
19
+ #include <cassert>
20
+ #include <sys/stat.h>
21
+ #include <sys/types.h>
22
+
23
+ #ifdef _WIN32
24
+ #include <windows.h>
25
+ #include <direct.h> // For _mkdir on Windows
26
+ #include <algorithm> // For std::replace on w64devkit
27
+ #else
28
+ #include <unistd.h>
29
+ #include <sys/wait.h>
30
+ #include <fcntl.h>
31
+ #endif
32
+
33
+ #include <vulkan/vulkan_core.h>
34
+
35
+ #define ASYNCIO_CONCURRENCY 64
36
+
37
+ std::mutex lock; // held while a compile job appends to shader_fnames
38
+ std::vector<std::pair<std::string, std::string>> shader_fnames; // (shader name, SPIR-V output path) pairs recorded by the compile jobs
39
+
40
+ std::string GLSLC = "glslc"; // compiler executable; overridden by --glslc
41
+ std::string input_dir = "vulkan-shaders"; // shader source directory; overridden by --input-dir
42
+ std::string output_dir = "/tmp"; // where the .spv files are written; overridden by --output-dir
43
+ std::string target_hpp = "ggml-vulkan-shaders.hpp"; // generated header path; overridden by --target-hpp
44
+ std::string target_cpp = "ggml-vulkan-shaders.cpp"; // generated source path; overridden by --target-cpp
45
+ bool no_clean = false; // set by --no-clean: keep the .spv files after they are embedded
46
+
47
+ const std::vector<std::string> type_names = { // data type suffixes the per-type shader variants are generated for
48
+ "f32",
49
+ "f16",
50
+ "q4_0",
51
+ "q4_1",
52
+ "q5_0",
53
+ "q5_1",
54
+ "q8_0",
55
+ "q2_k",
56
+ "q3_k",
57
+ "q4_k",
58
+ "q5_k",
59
+ "q6_k",
60
+ "iq4_nl"
61
+ };
62
+
63
+ namespace {
64
// Runs `command` and captures what it writes to stdout and stderr.
//
// On Windows the command line is passed to CreateProcessA; elsewhere it is run
// through `/bin/sh -c`. Captured output is appended to `stdout_str` and
// `stderr_str` (existing contents are kept). The child's exit status is not
// reported to the caller.
//
// Throws std::runtime_error when the pipes or the child process cannot be
// created; every pipe end opened so far is closed before throwing.
//
// NOTE(review): stdout is read to end-of-file before stderr is touched, so a
// child that writes more to stderr than the pipe can buffer before closing
// stdout would presumably stall — confirm the expected output stays small.
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
    HANDLE stdout_read = NULL, stdout_write = NULL;
    HANDLE stderr_read = NULL, stderr_write = NULL;
    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };

    // Closes whichever pipe ends have been opened so far (used on error paths).
    const auto close_pipes = [&]() {
        for (HANDLE h : {stdout_read, stdout_write, stderr_read, stderr_write}) {
            if (h != NULL) {
                CloseHandle(h);
            }
        }
    };

    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
        close_pipes();
        throw std::runtime_error("Failed to create stdout pipe");
    }

    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
        close_pipes();
        throw std::runtime_error("Failed to create stderr pipe");
    }

    PROCESS_INFORMATION pi;
    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
    si.dwFlags = STARTF_USESTDHANDLES;
    si.hStdOutput = stdout_write;
    si.hStdError = stderr_write;

    // CreateProcessA may modify the command line, so hand it a writable copy.
    std::vector<char> cmd(command.begin(), command.end());
    cmd.push_back('\0');

    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
        close_pipes();
        throw std::runtime_error("Failed to create process");
    }

    // The parent must drop its write ends, otherwise ReadFile never sees EOF.
    CloseHandle(stdout_write);
    CloseHandle(stderr_write);

    std::array<char, 128> buffer;
    DWORD bytes_read;

    while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
        stdout_str.append(buffer.data(), bytes_read);
    }

    while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
        stderr_str.append(buffer.data(), bytes_read);
    }

    CloseHandle(stdout_read);
    CloseHandle(stderr_read);
    WaitForSingleObject(pi.hProcess, INFINITE);
    CloseHandle(pi.hProcess);
    CloseHandle(pi.hThread);
#else
    int stdout_pipe[2];
    int stderr_pipe[2];

    if (pipe(stdout_pipe) != 0) {
        throw std::runtime_error("Failed to create pipes");
    }
    if (pipe(stderr_pipe) != 0) {
        // Do not leak the first pipe when the second one cannot be created.
        close(stdout_pipe[0]);
        close(stdout_pipe[1]);
        throw std::runtime_error("Failed to create pipes");
    }

    const pid_t pid = fork();
    if (pid < 0) {
        close(stdout_pipe[0]);
        close(stdout_pipe[1]);
        close(stderr_pipe[0]);
        close(stderr_pipe[1]);
        throw std::runtime_error("Failed to fork process");
    }

    if (pid == 0) {
        // Child: route stdout/stderr into the pipes, then exec the shell.
        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        if (dup2(stdout_pipe[1], STDOUT_FILENO) < 0 || dup2(stderr_pipe[1], STDERR_FILENO) < 0) {
            _exit(EXIT_FAILURE);
        }
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);
        execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
        _exit(EXIT_FAILURE); // only reached when execl itself failed
    }

    // Parent: drop the write ends so read() reports EOF once the child exits.
    close(stdout_pipe[1]);
    close(stderr_pipe[1]);

    std::array<char, 128> buffer;
    ssize_t bytes_read;

    while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
        stdout_str.append(buffer.data(), bytes_read);
    }

    while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
        stderr_str.append(buffer.data(), bytes_read);
    }

    close(stdout_pipe[0]);
    close(stderr_pipe[0]);
    waitpid(pid, nullptr, 0);
#endif
}
155
+
156
// Returns true when `path` can be stat()'ed and refers to a directory.
// Returns false when the path does not exist, cannot be accessed, or is any
// other kind of file.
bool directory_exists(const std::string& path) {
    struct stat info;
    if (stat(path.c_str(), &info) != 0) {
        return false; // Path doesn't exist or can't be accessed
    }
    // S_IFDIR is a value of the file-type field, not a flag: its bit is also
    // set in the type codes of block devices and sockets. Mask with S_IFMT and
    // compare so that only real directories match.
    return (info.st_mode & S_IFMT) == S_IFDIR;
}
163
+
164
// Creates the directory `path` (mode 0755 on non-Windows platforms).
// Returns true when the directory was created or already exists as a
// directory; returns false on any other failure, including the case where
// `path` exists but is not a directory.
bool create_directory(const std::string& path) {
#ifdef _WIN32
    const int rc = _mkdir(path.c_str());
#else
    const int rc = mkdir(path.c_str(), 0755); // 0755 is the directory permissions
#endif
    if (rc == 0) {
        return true;
    }
    if (errno != EEXIST) {
        return false;
    }
    // EEXIST is reported for any existing entry; only accept it when that
    // entry really is a directory.
    struct stat info;
    return stat(path.c_str(), &info) == 0 && (info.st_mode & S_IFMT) == S_IFDIR;
}
171
+
172
// Returns a copy of `input` with every character converted by std::toupper.
std::string to_uppercase(const std::string& input) {
    std::string result = input;
    for (char& c : result) {
        // std::toupper requires a value representable as unsigned char (or EOF);
        // passing a negative plain char is undefined behaviour.
        c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
    return result;
}
179
+
180
// Returns true when `str` ends with `suffix`. An empty suffix always matches.
bool string_ends_with(const std::string& str, const std::string& suffix) {
    if (suffix.size() > str.size()) {
        return false;
    }
    // std::string::compare needs no <algorithm>, which this file only includes
    // on Windows.
    return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
186
+
187
static const char path_separator = '/';

// Joins two path components with `path_separator`.
// An empty `path1` yields `path2` unchanged (instead of an absolute "/path2"),
// and no second separator is inserted when `path1` already ends with one.
std::string join_paths(const std::string& path1, const std::string& path2) {
    if (path1.empty()) {
        return path2;
    }
    if (path1.back() == path_separator) {
        return path1 + path2;
    }
    return path1 + path_separator + path2;
}
192
+
193
// Returns the part of `path` after its last '/' or '\\'.
// A path without any separator is returned unchanged.
std::string basename(const std::string &path) {
    const std::string::size_type last_sep = path.find_last_of("/\\");
    if (last_sep == std::string::npos) {
        return path;
    }
    return path.substr(last_sep + 1);
}
196
+
197
+ // variables to track number of compiles in progress
198
+ static uint32_t compile_count = 0;
199
+ static std::mutex compile_count_mutex;
200
+ static std::condition_variable compile_count_cond;
201
+
202
+ void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
203
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
204
+ std::string out_fname = join_paths(output_dir, name + ".spv");
205
+ std::string in_path = join_paths(input_dir, in_fname);
206
+
207
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
208
+
209
+ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
210
+ std::string opt_level = coopmat ? "" : "-O";
211
+
212
+ #ifdef _WIN32
213
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
214
+ #else
215
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
216
+ #endif
217
+
218
+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
219
+ cmd.push_back("-g");
220
+ #endif
221
+
222
+ for (const auto& define : defines) {
223
+ cmd.push_back("-D" + define.first + "=" + define.second);
224
+ }
225
+
226
+ std::string command;
227
+ for (const auto& part : cmd) {
228
+ command += part + " ";
229
+ }
230
+
231
+ std::string stdout_str, stderr_str;
232
+ try {
233
+ // std::cout << "Executing command: ";
234
+ // for (const auto& part : cmd) {
235
+ // std::cout << part << " ";
236
+ // }
237
+ // std::cout << std::endl;
238
+
239
+ execute_command(command, stdout_str, stderr_str);
240
+ if (!stderr_str.empty()) {
241
+ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
242
+ return;
243
+ }
244
+
245
+ std::lock_guard<std::mutex> guard(lock);
246
+ shader_fnames.push_back(std::make_pair(name, out_fname));
247
+ } catch (const std::exception& e) {
248
+ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
249
+ }
250
+ {
251
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
252
+ assert(compile_count > 0);
253
+ compile_count--;
254
+ }
255
+ compile_count_cond.notify_all();
256
+ }
257
+
258
// Returns the union of `a` and `b`. When a key is present in both maps the
// value from `a` is kept.
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
    std::map<std::string, std::string> merged(a);
    for (const auto& entry : b) {
        merged.insert(entry); // insert() leaves an existing key untouched
    }
    return merged;
}
263
+
264
+ static std::vector<std::future<void>> compiles;
265
+ void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
266
+ {
267
+ // wait until fewer than N compiles are in progress.
268
+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
269
+ uint32_t N = 16;
270
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
271
+ while (compile_count >= N) {
272
+ compile_count_cond.wait(guard);
273
+ }
274
+ compile_count++;
275
+ }
276
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
277
+ }
278
+
279
+ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
280
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
281
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
282
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
283
+
284
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
285
+ std::string shader_name = "matmul";
286
+
287
+ if (matmul_id) {
288
+ base_dict["MUL_MAT_ID"] = "1";
289
+ shader_name = "matmul_id";
290
+ }
291
+
292
+ if (fp16) {
293
+ base_dict["FLOAT16"] = "1";
294
+ }
295
+
296
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
297
+
298
+ if (coopmat) {
299
+ base_dict["COOPMAT"] = "1";
300
+ }
301
+
302
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
303
+
304
+ std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
305
+
306
+ // Shaders with f16 B_TYPE
307
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
308
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
309
+
310
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
311
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
312
+
313
+ for (const auto& tname : type_names) {
314
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
315
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
316
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
317
+ // For aligned matmul loads
318
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
319
+
320
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
321
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
322
+
323
+ if (tname != "f16" && tname != "f32") {
324
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
325
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
326
+ }
327
+ }
328
+ }
329
+
330
+ void process_shaders() {
331
+ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
332
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
333
+
334
+ // matmul
335
+ for (const auto& matmul_id : {false, true}) {
336
+ // No coopmats
337
+ // fp32
338
+ matmul_shaders(false, matmul_id, false, false, false);
339
+
340
+ // fp16, fp32acc and fp16acc
341
+ matmul_shaders(true, matmul_id, false, false, false);
342
+ matmul_shaders(true, matmul_id, false, false, true);
343
+
344
+ // Coopmat, fp32acc and fp16acc
345
+ matmul_shaders(true, matmul_id, true, false, false);
346
+ matmul_shaders(true, matmul_id, true, false, true);
347
+
348
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
349
+ // Coopmat2, fp32acc and fp16acc
350
+ matmul_shaders(true, matmul_id, false, true, false);
351
+ matmul_shaders(true, matmul_id, false, true, true);
352
+ #endif
353
+ }
354
+
355
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
356
+ // flash attention
357
+ for (const auto& f16acc : {false, true}) {
358
+ std::string acctype = f16acc ? "float16_t" : "float";
359
+
360
+ for (const auto& tname : type_names) {
361
+ if (tname == "f32") {
362
+ continue;
363
+ }
364
+
365
+ if (tname == "f16") {
366
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
367
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
368
+ } else {
369
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
370
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
371
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
372
+ }
373
+ }
374
+ }
375
+ #endif
376
+
377
+ for (const auto& tname : type_names) {
378
+ // mul mat vec
379
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
380
+ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
381
+
382
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
383
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
384
+
385
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
386
+
387
+ // Dequant shaders
388
+ if (tname != "f16") {
389
+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
390
+ }
391
+
392
+ if (!string_ends_with(tname, "_k")) {
393
+ shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
394
+
395
+ if (tname == "f16") {
396
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
397
+ } else {
398
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
399
+ }
400
+ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
401
+ }
402
+ }
403
+
404
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
405
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
406
+
407
+ // Norms
408
+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
409
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
410
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
411
+
412
+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
413
+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
414
+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
415
+ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
416
+ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
417
+ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
418
+
419
+ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
420
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
421
+
422
+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
423
+
424
+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
425
+
426
+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
427
+
428
+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
429
+
430
+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
431
+
432
+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
433
+
434
+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
435
+
436
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
437
+
438
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
439
+
440
+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
441
+
442
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
443
+
444
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
445
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
446
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
447
+
448
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
449
+
450
+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
452
+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
453
+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
454
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
455
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
456
+
457
+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
458
+
459
+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
460
+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
461
+
462
+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463
+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
464
+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
465
+
466
+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
467
+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
468
+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
469
+
470
+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
471
+
472
+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
473
+
474
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
475
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
476
+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
477
+
478
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
479
+
480
+ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
481
+
482
+ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
483
+
484
+ for (auto &c : compiles) {
485
+ c.wait();
486
+ }
487
+ }
488
+
489
+ void write_output_files() {
490
+ FILE* hdr = fopen(target_hpp.c_str(), "w");
491
+ FILE* src = fopen(target_cpp.c_str(), "w");
492
+
493
+ fprintf(hdr, "#include <cstdint>\n\n");
494
+ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
495
+
496
+ for (const auto& pair : shader_fnames) {
497
+ const std::string& name = pair.first;
498
+ #ifdef _WIN32
499
+ std::string path = pair.second;
500
+ std::replace(path.begin(), path.end(), '/', '\\' );
501
+ #else
502
+ const std::string& path = pair.second;
503
+ #endif
504
+
505
+ FILE* spv = fopen(path.c_str(), "rb");
506
+ if (!spv) {
507
+ std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
508
+ continue;
509
+ }
510
+
511
+ fseek(spv, 0, SEEK_END);
512
+ size_t size = ftell(spv);
513
+ fseek(spv, 0, SEEK_SET);
514
+
515
+ std::vector<unsigned char> data(size);
516
+ size_t read_size = fread(data.data(), 1, size, spv);
517
+ fclose(spv);
518
+ if (read_size != size) {
519
+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
520
+ continue;
521
+ }
522
+
523
+ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
524
+ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
525
+
526
+ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
527
+ for (size_t i = 0; i < size; ++i) {
528
+ fprintf(src, "0x%02x,", data[i]);
529
+ if ((i + 1) % 12 == 0) fprintf(src, "\n");
530
+ }
531
+ fprintf(src, "\n};\n\n");
532
+
533
+ if (!no_clean) {
534
+ std::remove(path.c_str());
535
+ }
536
+ }
537
+
538
+ fclose(hdr);
539
+ fclose(src);
540
+ }
541
+ }
542
+
543
+ int main(int argc, char** argv) {
544
+ std::map<std::string, std::string> args;
545
+ for (int i = 1; i < argc; ++i) {
546
+ std::string arg = argv[i];
547
+ if (arg.rfind("--", 0) == 0) {
548
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
549
+ args[arg] = argv[i + 1];
550
+ ++i;
551
+ } else {
552
+ args[arg] = "";
553
+ }
554
+ }
555
+ }
556
+
557
+ if (args.find("--glslc") != args.end()) {
558
+ GLSLC = args["--glslc"]; // Path to glslc
559
+ }
560
+ if (args.find("--input-dir") != args.end()) {
561
+ input_dir = args["--input-dir"]; // Directory containing shader sources
562
+ }
563
+ if (args.find("--output-dir") != args.end()) {
564
+ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
565
+ }
566
+ if (args.find("--target-hpp") != args.end()) {
567
+ target_hpp = args["--target-hpp"]; // Path to generated header file
568
+ }
569
+ if (args.find("--target-cpp") != args.end()) {
570
+ target_cpp = args["--target-cpp"]; // Path to generated cpp file
571
+ }
572
+ if (args.find("--no-clean") != args.end()) {
573
+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build
574
+ }
575
+
576
+ if (!directory_exists(input_dir)) {
577
+ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
578
+ return EXIT_FAILURE;
579
+ }
580
+
581
+ if (!directory_exists(output_dir)) {
582
+ if (!create_directory(output_dir)) {
583
+ std::cerr << "Error creating output directory: " << output_dir << "\n";
584
+ return EXIT_FAILURE;
585
+ }
586
+ }
587
+
588
+ process_shaders();
589
+
590
+ write_output_files();
591
+
592
+ return EXIT_SUCCESS;
593
+ }