whispercpp 1.2.0.2 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (135) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
@@ -0,0 +1,593 @@
1
+
2
+
3
#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <cerrno>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif

#include <vulkan/vulkan_core.h>
34
+
35
// Not referenced anywhere else in this file.
#define ASYNCIO_CONCURRENCY 64

// Protects shader_fnames, which compile worker threads append to.
std::mutex lock;
// One (shader name, compiled .spv path) entry per successful compile.
std::vector<std::pair<std::string, std::string>> shader_fnames;

// Tool settings; main() overrides them from --glslc, --input-dir,
// --output-dir, --target-hpp, --target-cpp and --no-clean.
std::string GLSLC{"glslc"};
std::string input_dir{"vulkan-shaders"};
std::string output_dir{"/tmp"};
std::string target_hpp{"ggml-vulkan-shaders.hpp"};
std::string target_cpp{"ggml-vulkan-shaders.cpp"};
bool no_clean{false};

// Type suffixes for which the per-type shaders (matmul, mul_mat_vec,
// dequant, get_rows) are generated; the order is the generation order.
const std::vector<std::string> type_names = {
    "f32", "f16",
    "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
    "q2_k", "q3_k", "q4_k", "q5_k", "q6_k",
    "iq4_nl",
};
62
+
63
+ namespace {
64
// Runs `read_out` on the calling thread while `read_err` runs on a helper
// thread, so that neither pipe can fill up while the other one is being read.
// If no helper thread can be started, falls back to reading one after the other.
template <typename ReadOut, typename ReadErr>
void drain_both(ReadOut read_out, ReadErr read_err) {
    std::thread err_thread;
    try {
        err_thread = std::thread(read_err);
    } catch (const std::exception&) {
        // Thread creation failed; read_err() runs sequentially below.
    }
    read_out();
    if (err_thread.joinable()) {
        err_thread.join();
    } else {
        read_err();
    }
}

// Runs `command` (via CreateProcessA on Windows, via `/bin/sh -c` elsewhere),
// appending everything the child writes to stdout/stderr to `stdout_str` and
// `stderr_str`, and waits for the child to exit. The exit status is not
// reported; callers detect failure by inspecting `stderr_str`.
//
// Both pipes are drained concurrently: reading one to EOF before touching the
// other deadlocks as soon as the child fills the OS buffer of the unread pipe.
//
// Throws std::runtime_error if a pipe or the child process cannot be created;
// every handle/descriptor opened up to that point is released first.
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
    HANDLE stdout_read, stdout_write;
    HANDLE stderr_read, stderr_write;
    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };

    auto close_handles = [](std::initializer_list<HANDLE> handles) {
        for (HANDLE h : handles) {
            CloseHandle(h);
        }
    };

    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0)) {
        throw std::runtime_error("Failed to create stdout pipe");
    }
    if (!SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
        close_handles({stdout_read, stdout_write});
        throw std::runtime_error("Failed to create stdout pipe");
    }
    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0)) {
        close_handles({stdout_read, stdout_write});
        throw std::runtime_error("Failed to create stderr pipe");
    }
    if (!SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
        close_handles({stdout_read, stdout_write, stderr_read, stderr_write});
        throw std::runtime_error("Failed to create stderr pipe");
    }

    PROCESS_INFORMATION pi;
    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
    si.dwFlags = STARTF_USESTDHANDLES;
    si.hStdOutput = stdout_write;
    si.hStdError = stderr_write;

    // CreateProcessA may modify the command-line buffer, so pass a mutable copy.
    std::vector<char> cmd(command.begin(), command.end());
    cmd.push_back('\0');

    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
        close_handles({stdout_read, stdout_write, stderr_read, stderr_write});
        throw std::runtime_error("Failed to create process");
    }

    // The parent's copies of the write ends must be closed, otherwise ReadFile never reports EOF.
    CloseHandle(stdout_write);
    CloseHandle(stderr_write);

    auto drain = [](HANDLE pipe, std::string& out) {
        std::array<char, 4096> buffer;
        DWORD bytes_read = 0;
        while (ReadFile(pipe, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
            out.append(buffer.data(), bytes_read);
        }
    };
    drain_both([&] { drain(stdout_read, stdout_str); },
               [&] { drain(stderr_read, stderr_str); });

    CloseHandle(stdout_read);
    CloseHandle(stderr_read);
    WaitForSingleObject(pi.hProcess, INFINITE);
    CloseHandle(pi.hProcess);
    CloseHandle(pi.hThread);
#else
    int stdout_pipe[2];
    int stderr_pipe[2];

    if (pipe(stdout_pipe) != 0) {
        throw std::runtime_error("Failed to create pipes");
    }
    if (pipe(stderr_pipe) != 0) {
        close(stdout_pipe[0]);
        close(stdout_pipe[1]);
        throw std::runtime_error("Failed to create pipes");
    }

    // Compiles run concurrently; close-on-exec keeps other glslc children from
    // inheriting (and holding open) this command's pipes. dup2() in the child
    // clears the flag on the duplicated stdout/stderr descriptors.
    for (int fd : {stdout_pipe[0], stdout_pipe[1], stderr_pipe[0], stderr_pipe[1]}) {
        fcntl(fd, F_SETFD, FD_CLOEXEC);
    }

    pid_t pid = fork();
    if (pid < 0) {
        close(stdout_pipe[0]);
        close(stdout_pipe[1]);
        close(stderr_pipe[0]);
        close(stderr_pipe[1]);
        throw std::runtime_error("Failed to fork process");
    }

    if (pid == 0) {
        // Child: route stdout/stderr into the pipes and exec the shell.
        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        dup2(stdout_pipe[1], STDOUT_FILENO);
        dup2(stderr_pipe[1], STDERR_FILENO);
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);
        execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
        _exit(EXIT_FAILURE);
    }

    // Parent: close the write ends so that read() sees EOF when the child exits.
    close(stdout_pipe[1]);
    close(stderr_pipe[1]);

    auto drain = [](int fd, std::string& out) {
        std::array<char, 4096> buffer;
        for (;;) {
            const ssize_t bytes_read = read(fd, buffer.data(), buffer.size());
            if (bytes_read > 0) {
                out.append(buffer.data(), static_cast<size_t>(bytes_read));
            } else if (bytes_read < 0 && errno == EINTR) {
                continue;  // interrupted by a signal before any data arrived; retry
            } else {
                break;     // EOF or an unrecoverable read error
            }
        }
    };
    drain_both([&] { drain(stdout_pipe[0], stdout_str); },
               [&] { drain(stderr_pipe[0], stderr_str); });

    close(stdout_pipe[0]);
    close(stderr_pipe[0]);
    while (waitpid(pid, nullptr, 0) < 0 && errno == EINTR) {
        // retry until the child has actually been reaped
    }
#endif
}
155
+
156
// Returns true when `path` exists and is a directory (stat() follows
// symlinks). Returns false if the path is missing, cannot be accessed, or
// names something other than a directory.
bool directory_exists(const std::string& path) {
    struct stat info;
    if (stat(path.c_str(), &info) != 0) {
        return false; // Path doesn't exist or can't be accessed
    }
    // Compare the whole file-type field: testing only the S_IFDIR bit also
    // accepts block devices (S_IFBLK) and sockets (S_IFSOCK), whose type
    // codes contain that bit.
    return (info.st_mode & S_IFMT) == S_IFDIR;
}
163
+
164
// Creates the single directory `path` (parent directories are not created;
// mode 0755 on POSIX). Returns true if the directory was created or already
// exists as a directory. Returns false on any other failure, including when
// `path` already exists but is not a directory.
bool create_directory(const std::string& path) {
#ifdef _WIN32
    const int rc = _mkdir(path.c_str());
#else
    const int rc = mkdir(path.c_str(), 0755); // 0755 is the directory permissions
#endif
    if (rc == 0) {
        return true;
    }
    if (errno != EEXIST) {
        return false;
    }
    // EEXIST only says that *something* lives at `path`; make sure it is a directory.
    struct stat info;
    return stat(path.c_str(), &info) == 0 && (info.st_mode & S_IFMT) == S_IFDIR;
}
171
+
172
// Returns a copy of `input` with every character passed through std::toupper
// (e.g. "q4_k" -> "Q4_K"); used to build DATA_A_* / QUANT_K_* define names.
std::string to_uppercase(const std::string& input) {
    std::string result = input;
    for (char& c : result) {
        // std::toupper needs a value representable as unsigned char; passing a
        // negative plain char (bytes >= 0x80 where char is signed) is undefined.
        c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
    return result;
}
179
+
180
// Reports whether `str` finishes with `suffix`. An empty suffix always
// matches; a suffix longer than `str` never does.
bool string_ends_with(const std::string& str, const std::string& suffix) {
    const std::size_t tail = suffix.size();
    return tail <= str.size() && str.compare(str.size() - tail, tail, suffix) == 0;
}
186
+
187
// Separator placed between path fragments by join_paths().
static const char path_separator = '/';

// Concatenates `path1`, one path_separator and `path2`. No normalisation is
// done: a trailing separator on `path1` yields a doubled separator.
std::string join_paths(const std::string& path1, const std::string& path2) {
    std::string joined;
    joined.reserve(path1.size() + 1 + path2.size());
    joined.append(path1).push_back(path_separator);
    joined.append(path2);
    return joined;
}
192
+
193
// Returns the last component of `path`, accepting both '/' and '\\' as
// separators. Without any separator the whole path comes back unchanged;
// a trailing separator produces an empty string.
std::string basename(const std::string &path) {
    const std::string::size_type last_sep = path.find_last_of("/\\");
    if (last_sep == std::string::npos) {
        return path;
    }
    return path.substr(last_sep + 1);
}
196
+
197
+ // variables to track number of compiles in progress
198
+ static uint32_t compile_count = 0;
199
+ static std::mutex compile_count_mutex;
200
+ static std::condition_variable compile_count_cond;
201
+
202
+ void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
203
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
204
+ std::string out_fname = join_paths(output_dir, name + ".spv");
205
+ std::string in_path = join_paths(input_dir, in_fname);
206
+
207
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
208
+
209
+ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
210
+ std::string opt_level = coopmat ? "" : "-O";
211
+
212
+ #ifdef _WIN32
213
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
214
+ #else
215
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
216
+ #endif
217
+
218
+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
219
+ cmd.push_back("-g");
220
+ #endif
221
+
222
+ for (const auto& define : defines) {
223
+ cmd.push_back("-D" + define.first + "=" + define.second);
224
+ }
225
+
226
+ std::string command;
227
+ for (const auto& part : cmd) {
228
+ command += part + " ";
229
+ }
230
+
231
+ std::string stdout_str, stderr_str;
232
+ try {
233
+ // std::cout << "Executing command: ";
234
+ // for (const auto& part : cmd) {
235
+ // std::cout << part << " ";
236
+ // }
237
+ // std::cout << std::endl;
238
+
239
+ execute_command(command, stdout_str, stderr_str);
240
+ if (!stderr_str.empty()) {
241
+ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
242
+ return;
243
+ }
244
+
245
+ std::lock_guard<std::mutex> guard(lock);
246
+ shader_fnames.push_back(std::make_pair(name, out_fname));
247
+ } catch (const std::exception& e) {
248
+ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
249
+ }
250
+ {
251
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
252
+ assert(compile_count > 0);
253
+ compile_count--;
254
+ }
255
+ compile_count_cond.notify_all();
256
+ }
257
+
258
// Returns the union of two define maps. On a key present in both, the value
// from `a` is kept, because std::map::insert never overwrites an existing key.
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
    std::map<std::string, std::string> merged(a);
    for (const auto& entry : b) {
        merged.insert(entry);  // no-op when the key already came from `a`
    }
    return merged;
}
263
+
264
+ static std::vector<std::future<void>> compiles;
265
+ void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
266
+ {
267
+ // wait until fewer than N compiles are in progress.
268
+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
269
+ uint32_t N = 16;
270
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
271
+ while (compile_count >= N) {
272
+ compile_count_cond.wait(guard);
273
+ }
274
+ compile_count++;
275
+ }
276
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
277
+ }
278
+
279
+ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
280
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
281
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
282
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
283
+
284
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
285
+ std::string shader_name = "matmul";
286
+
287
+ if (matmul_id) {
288
+ base_dict["MUL_MAT_ID"] = "1";
289
+ shader_name = "matmul_id";
290
+ }
291
+
292
+ if (fp16) {
293
+ base_dict["FLOAT16"] = "1";
294
+ }
295
+
296
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
297
+
298
+ if (coopmat) {
299
+ base_dict["COOPMAT"] = "1";
300
+ }
301
+
302
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
303
+
304
+ std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
305
+
306
+ // Shaders with f16 B_TYPE
307
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
308
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
309
+
310
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
311
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
312
+
313
+ for (const auto& tname : type_names) {
314
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
315
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
316
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
317
+ // For aligned matmul loads
318
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
319
+
320
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
321
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
322
+
323
+ if (tname != "f16" && tname != "f32") {
324
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
325
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
326
+ }
327
+ }
328
+ }
329
+
330
+ void process_shaders() {
331
+ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
332
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
333
+
334
+ // matmul
335
+ for (const auto& matmul_id : {false, true}) {
336
+ // No coopmats
337
+ // fp32
338
+ matmul_shaders(false, matmul_id, false, false, false);
339
+
340
+ // fp16, fp32acc and fp16acc
341
+ matmul_shaders(true, matmul_id, false, false, false);
342
+ matmul_shaders(true, matmul_id, false, false, true);
343
+
344
+ // Coopmat, fp32acc and fp16acc
345
+ matmul_shaders(true, matmul_id, true, false, false);
346
+ matmul_shaders(true, matmul_id, true, false, true);
347
+
348
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
349
+ // Coopmat2, fp32acc and fp16acc
350
+ matmul_shaders(true, matmul_id, false, true, false);
351
+ matmul_shaders(true, matmul_id, false, true, true);
352
+ #endif
353
+ }
354
+
355
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
356
+ // flash attention
357
+ for (const auto& f16acc : {false, true}) {
358
+ std::string acctype = f16acc ? "float16_t" : "float";
359
+
360
+ for (const auto& tname : type_names) {
361
+ if (tname == "f32") {
362
+ continue;
363
+ }
364
+
365
+ if (tname == "f16") {
366
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
367
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
368
+ } else {
369
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
370
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
371
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
372
+ }
373
+ }
374
+ }
375
+ #endif
376
+
377
+ for (const auto& tname : type_names) {
378
+ // mul mat vec
379
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
380
+ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
381
+
382
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
383
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
384
+
385
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
386
+
387
+ // Dequant shaders
388
+ if (tname != "f16") {
389
+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
390
+ }
391
+
392
+ if (!string_ends_with(tname, "_k")) {
393
+ shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
394
+
395
+ if (tname == "f16") {
396
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
397
+ } else {
398
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
399
+ }
400
+ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
401
+ }
402
+ }
403
+
404
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
405
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
406
+
407
+ // Norms
408
+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
409
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
410
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
411
+
412
+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
413
+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
414
+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
415
+ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
416
+ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
417
+ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
418
+
419
+ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
420
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
421
+
422
+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
423
+
424
+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
425
+
426
+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
427
+
428
+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
429
+
430
+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
431
+
432
+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
433
+
434
+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
435
+
436
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
437
+
438
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
439
+
440
+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
441
+
442
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
443
+
444
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
445
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
446
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
447
+
448
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
449
+
450
+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
452
+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
453
+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
454
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
455
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
456
+
457
+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
458
+
459
+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
460
+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
461
+
462
+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463
+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
464
+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
465
+
466
+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
467
+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
468
+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
469
+
470
+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
471
+
472
+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
473
+
474
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
475
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
476
+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
477
+
478
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
479
+
480
+ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
481
+
482
+ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
483
+
484
+ for (auto &c : compiles) {
485
+ c.wait();
486
+ }
487
+ }
488
+
489
+ void write_output_files() {
490
+ FILE* hdr = fopen(target_hpp.c_str(), "w");
491
+ FILE* src = fopen(target_cpp.c_str(), "w");
492
+
493
+ fprintf(hdr, "#include <cstdint>\n\n");
494
+ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
495
+
496
+ for (const auto& pair : shader_fnames) {
497
+ const std::string& name = pair.first;
498
+ #ifdef _WIN32
499
+ std::string path = pair.second;
500
+ std::replace(path.begin(), path.end(), '/', '\\' );
501
+ #else
502
+ const std::string& path = pair.second;
503
+ #endif
504
+
505
+ FILE* spv = fopen(path.c_str(), "rb");
506
+ if (!spv) {
507
+ std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
508
+ continue;
509
+ }
510
+
511
+ fseek(spv, 0, SEEK_END);
512
+ size_t size = ftell(spv);
513
+ fseek(spv, 0, SEEK_SET);
514
+
515
+ std::vector<unsigned char> data(size);
516
+ size_t read_size = fread(data.data(), 1, size, spv);
517
+ fclose(spv);
518
+ if (read_size != size) {
519
+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
520
+ continue;
521
+ }
522
+
523
+ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
524
+ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
525
+
526
+ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
527
+ for (size_t i = 0; i < size; ++i) {
528
+ fprintf(src, "0x%02x,", data[i]);
529
+ if ((i + 1) % 12 == 0) fprintf(src, "\n");
530
+ }
531
+ fprintf(src, "\n};\n\n");
532
+
533
+ if (!no_clean) {
534
+ std::remove(path.c_str());
535
+ }
536
+ }
537
+
538
+ fclose(hdr);
539
+ fclose(src);
540
+ }
541
+ }
542
+
543
+ int main(int argc, char** argv) {
544
+ std::map<std::string, std::string> args;
545
+ for (int i = 1; i < argc; ++i) {
546
+ std::string arg = argv[i];
547
+ if (arg.rfind("--", 0) == 0) {
548
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
549
+ args[arg] = argv[i + 1];
550
+ ++i;
551
+ } else {
552
+ args[arg] = "";
553
+ }
554
+ }
555
+ }
556
+
557
+ if (args.find("--glslc") != args.end()) {
558
+ GLSLC = args["--glslc"]; // Path to glslc
559
+ }
560
+ if (args.find("--input-dir") != args.end()) {
561
+ input_dir = args["--input-dir"]; // Directory containing shader sources
562
+ }
563
+ if (args.find("--output-dir") != args.end()) {
564
+ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
565
+ }
566
+ if (args.find("--target-hpp") != args.end()) {
567
+ target_hpp = args["--target-hpp"]; // Path to generated header file
568
+ }
569
+ if (args.find("--target-cpp") != args.end()) {
570
+ target_cpp = args["--target-cpp"]; // Path to generated cpp file
571
+ }
572
+ if (args.find("--no-clean") != args.end()) {
573
+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build
574
+ }
575
+
576
+ if (!directory_exists(input_dir)) {
577
+ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
578
+ return EXIT_FAILURE;
579
+ }
580
+
581
+ if (!directory_exists(output_dir)) {
582
+ if (!create_directory(output_dir)) {
583
+ std::cerr << "Error creating output directory: " << output_dir << "\n";
584
+ return EXIT_FAILURE;
585
+ }
586
+ }
587
+
588
+ process_shaders();
589
+
590
+ write_output_files();
591
+
592
+ return EXIT_SUCCESS;
593
+ }