whispercpp 1.2.0.2 → 1.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +46 -86
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -7
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/ggml/include/ggml.h +2285 -0
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/include/whisper.h +672 -0
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1608 -159
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/src/whisper.cpp +7393 -0
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -8616
- data/ext/ggml.h +0 -748
- data/ext/whisper.cpp +0 -4829
- data/ext/whisper.h +0 -402
@@ -0,0 +1,593 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <cerrno>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <future>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif

#include <vulkan/vulkan_core.h>
|
34
|
+
|
35
|
+
// Guards shader_fnames, which the compile worker threads append to.
std::mutex lock;
// (shader name, path of its compiled .spv file) for every successful compile;
// read by write_output_files() once all compiles have finished.
std::vector<std::pair<std::string, std::string>> shader_fnames;

// Defaults for the settings that main() lets the command line override.
std::string GLSLC = "glslc";                         // --glslc: compiler executable
std::string input_dir = "vulkan-shaders";            // --input-dir: .comp shader sources
std::string output_dir = "/tmp";                     // --output-dir: intermediate .spv files
std::string target_hpp = "ggml-vulkan-shaders.hpp";  // --target-hpp: generated header
std::string target_cpp = "ggml-vulkan-shaders.cpp";  // --target-cpp: generated source
bool no_clean = false;                               // --no-clean: keep .spv files after embedding

// Tensor types for which per-type shader variants are generated; each one
// is passed to the shaders as a DATA_A_<TYPE> define.
const std::vector<std::string> type_names = {
    "f32",
    "f16",
    "q4_0",
    "q4_1",
    "q5_0",
    "q5_1",
    "q8_0",
    "q2_k",
    "q3_k",
    "q4_k",
    "q5_k",
    "q6_k",
    "iq4_nl"
};
|
62
|
+
|
63
|
+
namespace {
|
64
|
+
// Runs `f_err` on a helper thread while `f_out` runs on the calling thread,
// then waits for both. Falls back to running them one after the other if the
// helper thread cannot be started.
template <typename OutFn, typename ErrFn>
void run_concurrently(OutFn f_out, ErrFn f_err) {
    std::thread helper;
    try {
        helper = std::thread(f_err);
    } catch (const std::exception&) {
        f_out();
        f_err();
        return;
    }
    f_out();
    helper.join();
}

// Runs `command` and appends everything the child writes on stdout/stderr to
// `stdout_str` / `stderr_str`. On Windows the command line is handed to
// CreateProcessA; elsewhere it is executed with `/bin/sh -c`.
//
// Both pipes are drained at the same time (stderr on a helper thread):
// reading them one after the other deadlocks as soon as the child fills the
// pipe that is not currently being read, e.g. with a long error listing.
//
// If the child finishes with a non-zero exit status without writing anything
// to stderr, a short diagnostic is appended to `stderr_str` so callers that
// only inspect stderr still notice the failure.
//
// Throws std::runtime_error if the pipes or the child process cannot be
// created; the pipe handles/descriptors are released on those paths.
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
    const size_t stderr_before = stderr_str.size();
#ifdef _WIN32
    HANDLE stdout_read = NULL, stdout_write = NULL;
    HANDLE stderr_read = NULL, stderr_write = NULL;
    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };

    auto close_handle = [](HANDLE& h) {
        if (h != NULL) {
            CloseHandle(h);
            h = NULL;
        }
    };

    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
        close_handle(stdout_read);
        close_handle(stdout_write);
        throw std::runtime_error("Failed to create stdout pipe");
    }

    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
        close_handle(stdout_read);
        close_handle(stdout_write);
        close_handle(stderr_read);
        close_handle(stderr_write);
        throw std::runtime_error("Failed to create stderr pipe");
    }

    PROCESS_INFORMATION pi;
    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
    si.dwFlags = STARTF_USESTDHANDLES;
    si.hStdOutput = stdout_write;
    si.hStdError = stderr_write;

    std::vector<char> cmd(command.begin(), command.end());
    cmd.push_back('\0');

    const BOOL created = CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi);

    // Close our copies of the write ends (the child inherited its own) so
    // ReadFile reports end-of-stream once the child exits.
    close_handle(stdout_write);
    close_handle(stderr_write);

    if (!created) {
        close_handle(stdout_read);
        close_handle(stderr_read);
        throw std::runtime_error("Failed to create process");
    }

    auto drain = [](HANDLE h, std::string& out) {
        std::array<char, 4096> buffer;
        DWORD bytes_read = 0;
        while (ReadFile(h, buffer.data(), static_cast<DWORD>(buffer.size()), &bytes_read, NULL) && bytes_read > 0) {
            out.append(buffer.data(), bytes_read);
        }
    };
    run_concurrently([&] { drain(stdout_read, stdout_str); },
                     [&] { drain(stderr_read, stderr_str); });

    close_handle(stdout_read);
    close_handle(stderr_read);

    WaitForSingleObject(pi.hProcess, INFINITE);
    DWORD exit_code = 0;
    if (GetExitCodeProcess(pi.hProcess, &exit_code) && exit_code != 0 && stderr_str.size() == stderr_before) {
        stderr_str += "command exited with status " + std::to_string(exit_code) + "\n";
    }
    CloseHandle(pi.hProcess);
    CloseHandle(pi.hThread);
#else
    int stdout_pipe[2] = { -1, -1 };
    int stderr_pipe[2] = { -1, -1 };

    auto close_fd = [](int& fd) {
        if (fd >= 0) {
            close(fd);
            fd = -1;
        }
    };
    auto close_all = [&]() {
        close_fd(stdout_pipe[0]);
        close_fd(stdout_pipe[1]);
        close_fd(stderr_pipe[0]);
        close_fd(stderr_pipe[1]);
    };

    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
        close_all();
        throw std::runtime_error("Failed to create pipes");
    }

    // Other worker threads fork concurrently: mark our pipe ends
    // close-on-exec so their children do not inherit (and keep open) our
    // write ends. dup2() in our own child clears the flag on fds 1 and 2.
    for (int fd : { stdout_pipe[0], stdout_pipe[1], stderr_pipe[0], stderr_pipe[1] }) {
        fcntl(fd, F_SETFD, FD_CLOEXEC);
    }

    pid_t pid = fork();
    if (pid < 0) {
        close_all();
        throw std::runtime_error("Failed to fork process");
    }

    if (pid == 0) {
        // Child: connect the pipes to stdout/stderr and exec the shell.
        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        dup2(stdout_pipe[1], STDOUT_FILENO);
        dup2(stderr_pipe[1], STDERR_FILENO);
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);
        execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
        _exit(EXIT_FAILURE);
    }

    // Parent: drop the write ends so read() sees EOF once the child exits.
    close_fd(stdout_pipe[1]);
    close_fd(stderr_pipe[1]);

    auto drain = [](int fd, std::string& out) {
        std::array<char, 4096> buffer;
        for (;;) {
            const ssize_t bytes_read = read(fd, buffer.data(), buffer.size());
            if (bytes_read > 0) {
                out.append(buffer.data(), static_cast<size_t>(bytes_read));
            } else if (bytes_read < 0 && errno == EINTR) {
                continue;  // interrupted by a signal, not end of stream
            } else {
                break;     // EOF or a genuine read error
            }
        }
    };
    run_concurrently([&] { drain(stdout_pipe[0], stdout_str); },
                     [&] { drain(stderr_pipe[0], stderr_str); });

    close_fd(stdout_pipe[0]);
    close_fd(stderr_pipe[0]);

    int status = 0;
    while (waitpid(pid, &status, 0) < 0 && errno == EINTR) {
    }

    const bool failed = !WIFEXITED(status) || WEXITSTATUS(status) != 0;
    if (failed && stderr_str.size() == stderr_before) {
        stderr_str += WIFEXITED(status)
            ? "command exited with status " + std::to_string(WEXITSTATUS(status)) + "\n"
            : std::string("command terminated abnormally\n");
    }
#endif
}
|
155
|
+
|
156
|
+
// Returns true if `path` exists and is a directory (stat() follows
// symlinks). Any stat() failure, such as a missing path or a permission
// problem, yields false.
bool directory_exists(const std::string& path) {
    struct stat info;
    if (stat(path.c_str(), &info) != 0) {
        return false;
    }
    // Compare the whole file-type field: the S_IFDIR bit is also set in other
    // file types (block devices, sockets), so a plain bit test misreports them.
    return (info.st_mode & S_IFMT) == S_IFDIR;
}
|
163
|
+
|
164
|
+
// Creates the directory `path` (parent directories are not created).
// Returns true when the directory was created or a directory already exists
// at `path`. Returns false on any other failure, including when `path`
// exists but is not a directory (mkdir reports EEXIST for that case too).
bool create_directory(const std::string& path) {
#ifdef _WIN32
    const bool made = _mkdir(path.c_str()) == 0;
#else
    const bool made = mkdir(path.c_str(), 0755) == 0; // 0755 is the directory permissions
#endif
    if (made) {
        return true;
    }
    if (errno != EEXIST) {
        return false;
    }
    // Something already exists there; only accept it if it is a directory.
    struct stat info;
    return stat(path.c_str(), &info) == 0 && (info.st_mode & S_IFMT) == S_IFDIR;
}
|
171
|
+
|
172
|
+
// Returns a copy of `input` with letters upper-cased via std::toupper
// (e.g. "q4_k" -> "Q4_K"). In the default "C" locale only ASCII letters
// change. Used to build DATA_A_* / QUANT_K_* define names from type names.
std::string to_uppercase(const std::string& input) {
    std::string result = input;
    for (char& c : result) {
        // std::toupper requires a value representable as unsigned char (or
        // EOF); passing a negative plain char is undefined behaviour.
        c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
    return result;
}
|
179
|
+
|
180
|
+
// Reports whether `str` ends with `suffix`; the empty suffix matches any
// string, and a suffix longer than `str` never matches.
bool string_ends_with(const std::string& str, const std::string& suffix) {
    const std::string::size_type tail = suffix.size();
    return tail <= str.size() && str.compare(str.size() - tail, tail, suffix) == 0;
}
|
186
|
+
|
187
|
+
// Separator placed between path fragments; '/' is used on every platform.
constexpr char path_separator = '/';

// Concatenates `path1`, path_separator and `path2` without any normalization
// (an empty `path1` gives "/<path2>", a trailing '/' gives a doubled one).
std::string join_paths(const std::string& path1, const std::string& path2) {
    std::string joined;
    joined.reserve(path1.size() + 1 + path2.size());
    joined.append(path1);
    joined.push_back(path_separator);
    joined.append(path2);
    return joined;
}
|
192
|
+
|
193
|
+
// Returns the last component of `path`, treating both '/' and '\\' as
// separators. A path without separators is returned unchanged; a path that
// ends in a separator yields an empty string.
std::string basename(const std::string &path) {
    const std::string::size_type sep = path.find_last_of("/\\");
    if (sep == std::string::npos) {
        return path;
    }
    return path.substr(sep + 1);
}
|
196
|
+
|
197
|
+
// Bookkeeping for the number of glslc compiles currently in flight.
// string_to_spv increments compile_count (waiting while the limit is
// reached) and string_to_spv_func decrements it when a job ends. Every
// access holds compile_count_mutex, and compile_count_cond is notified
// after each decrement.
static uint32_t compile_count{0};
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
|
201
|
+
|
202
|
+
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
203
|
+
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
204
|
+
std::string out_fname = join_paths(output_dir, name + ".spv");
|
205
|
+
std::string in_path = join_paths(input_dir, in_fname);
|
206
|
+
|
207
|
+
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
|
208
|
+
|
209
|
+
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
|
210
|
+
std::string opt_level = coopmat ? "" : "-O";
|
211
|
+
|
212
|
+
#ifdef _WIN32
|
213
|
+
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
|
214
|
+
#else
|
215
|
+
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
|
216
|
+
#endif
|
217
|
+
|
218
|
+
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
|
219
|
+
cmd.push_back("-g");
|
220
|
+
#endif
|
221
|
+
|
222
|
+
for (const auto& define : defines) {
|
223
|
+
cmd.push_back("-D" + define.first + "=" + define.second);
|
224
|
+
}
|
225
|
+
|
226
|
+
std::string command;
|
227
|
+
for (const auto& part : cmd) {
|
228
|
+
command += part + " ";
|
229
|
+
}
|
230
|
+
|
231
|
+
std::string stdout_str, stderr_str;
|
232
|
+
try {
|
233
|
+
// std::cout << "Executing command: ";
|
234
|
+
// for (const auto& part : cmd) {
|
235
|
+
// std::cout << part << " ";
|
236
|
+
// }
|
237
|
+
// std::cout << std::endl;
|
238
|
+
|
239
|
+
execute_command(command, stdout_str, stderr_str);
|
240
|
+
if (!stderr_str.empty()) {
|
241
|
+
std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
|
242
|
+
return;
|
243
|
+
}
|
244
|
+
|
245
|
+
std::lock_guard<std::mutex> guard(lock);
|
246
|
+
shader_fnames.push_back(std::make_pair(name, out_fname));
|
247
|
+
} catch (const std::exception& e) {
|
248
|
+
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
|
249
|
+
}
|
250
|
+
{
|
251
|
+
std::lock_guard<std::mutex> guard(compile_count_mutex);
|
252
|
+
assert(compile_count > 0);
|
253
|
+
compile_count--;
|
254
|
+
}
|
255
|
+
compile_count_cond.notify_all();
|
256
|
+
}
|
257
|
+
|
258
|
+
// Returns the union of the two define maps. For a key present in both maps
// the value from `a` is kept (std::map::emplace never overwrites), so `b`
// cannot be used to override entries of `a`.
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
    std::map<std::string, std::string> merged(a);
    for (const auto& entry : b) {
        merged.emplace(entry.first, entry.second);
    }
    return merged;
}
|
263
|
+
|
264
|
+
static std::vector<std::future<void>> compiles;
|
265
|
+
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
266
|
+
{
|
267
|
+
// wait until fewer than N compiles are in progress.
|
268
|
+
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
|
269
|
+
uint32_t N = 16;
|
270
|
+
std::unique_lock<std::mutex> guard(compile_count_mutex);
|
271
|
+
while (compile_count >= N) {
|
272
|
+
compile_count_cond.wait(guard);
|
273
|
+
}
|
274
|
+
compile_count++;
|
275
|
+
}
|
276
|
+
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
277
|
+
}
|
278
|
+
|
279
|
+
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
280
|
+
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
281
|
+
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
282
|
+
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
283
|
+
|
284
|
+
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
|
285
|
+
std::string shader_name = "matmul";
|
286
|
+
|
287
|
+
if (matmul_id) {
|
288
|
+
base_dict["MUL_MAT_ID"] = "1";
|
289
|
+
shader_name = "matmul_id";
|
290
|
+
}
|
291
|
+
|
292
|
+
if (fp16) {
|
293
|
+
base_dict["FLOAT16"] = "1";
|
294
|
+
}
|
295
|
+
|
296
|
+
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
297
|
+
|
298
|
+
if (coopmat) {
|
299
|
+
base_dict["COOPMAT"] = "1";
|
300
|
+
}
|
301
|
+
|
302
|
+
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
303
|
+
|
304
|
+
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
305
|
+
|
306
|
+
// Shaders with f16 B_TYPE
|
307
|
+
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
308
|
+
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
309
|
+
|
310
|
+
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
311
|
+
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
312
|
+
|
313
|
+
for (const auto& tname : type_names) {
|
314
|
+
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
315
|
+
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
316
|
+
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
|
317
|
+
// For aligned matmul loads
|
318
|
+
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
|
319
|
+
|
320
|
+
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
321
|
+
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
322
|
+
|
323
|
+
if (tname != "f16" && tname != "f32") {
|
324
|
+
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
325
|
+
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
326
|
+
}
|
327
|
+
}
|
328
|
+
}
|
329
|
+
|
330
|
+
void process_shaders() {
|
331
|
+
std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
|
332
|
+
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
333
|
+
|
334
|
+
// matmul
|
335
|
+
for (const auto& matmul_id : {false, true}) {
|
336
|
+
// No coopmats
|
337
|
+
// fp32
|
338
|
+
matmul_shaders(false, matmul_id, false, false, false);
|
339
|
+
|
340
|
+
// fp16, fp32acc and fp16acc
|
341
|
+
matmul_shaders(true, matmul_id, false, false, false);
|
342
|
+
matmul_shaders(true, matmul_id, false, false, true);
|
343
|
+
|
344
|
+
// Coopmat, fp32acc and fp16acc
|
345
|
+
matmul_shaders(true, matmul_id, true, false, false);
|
346
|
+
matmul_shaders(true, matmul_id, true, false, true);
|
347
|
+
|
348
|
+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
349
|
+
// Coopmat2, fp32acc and fp16acc
|
350
|
+
matmul_shaders(true, matmul_id, false, true, false);
|
351
|
+
matmul_shaders(true, matmul_id, false, true, true);
|
352
|
+
#endif
|
353
|
+
}
|
354
|
+
|
355
|
+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
356
|
+
// flash attention
|
357
|
+
for (const auto& f16acc : {false, true}) {
|
358
|
+
std::string acctype = f16acc ? "float16_t" : "float";
|
359
|
+
|
360
|
+
for (const auto& tname : type_names) {
|
361
|
+
if (tname == "f32") {
|
362
|
+
continue;
|
363
|
+
}
|
364
|
+
|
365
|
+
if (tname == "f16") {
|
366
|
+
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
367
|
+
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
|
368
|
+
} else {
|
369
|
+
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
370
|
+
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
371
|
+
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
372
|
+
}
|
373
|
+
}
|
374
|
+
}
|
375
|
+
#endif
|
376
|
+
|
377
|
+
for (const auto& tname : type_names) {
|
378
|
+
// mul mat vec
|
379
|
+
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
380
|
+
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
381
|
+
|
382
|
+
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
383
|
+
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
384
|
+
|
385
|
+
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
386
|
+
|
387
|
+
// Dequant shaders
|
388
|
+
if (tname != "f16") {
|
389
|
+
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
390
|
+
}
|
391
|
+
|
392
|
+
if (!string_ends_with(tname, "_k")) {
|
393
|
+
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
|
394
|
+
|
395
|
+
if (tname == "f16") {
|
396
|
+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
397
|
+
} else {
|
398
|
+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
|
399
|
+
}
|
400
|
+
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
|
401
|
+
}
|
402
|
+
}
|
403
|
+
|
404
|
+
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
405
|
+
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
406
|
+
|
407
|
+
// Norms
|
408
|
+
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
409
|
+
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
410
|
+
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
411
|
+
|
412
|
+
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
413
|
+
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
414
|
+
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
415
|
+
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
416
|
+
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
417
|
+
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
418
|
+
|
419
|
+
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
420
|
+
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
421
|
+
|
422
|
+
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
423
|
+
|
424
|
+
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
425
|
+
|
426
|
+
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
427
|
+
|
428
|
+
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
429
|
+
|
430
|
+
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
431
|
+
|
432
|
+
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
433
|
+
|
434
|
+
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
435
|
+
|
436
|
+
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
437
|
+
|
438
|
+
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
439
|
+
|
440
|
+
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
441
|
+
|
442
|
+
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
443
|
+
|
444
|
+
string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
445
|
+
string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
446
|
+
string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
|
447
|
+
|
448
|
+
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
449
|
+
|
450
|
+
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
451
|
+
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
452
|
+
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
453
|
+
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
454
|
+
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
455
|
+
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
456
|
+
|
457
|
+
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
458
|
+
|
459
|
+
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
460
|
+
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
461
|
+
|
462
|
+
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
463
|
+
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
464
|
+
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
465
|
+
|
466
|
+
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
467
|
+
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
468
|
+
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
469
|
+
|
470
|
+
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
471
|
+
|
472
|
+
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
473
|
+
|
474
|
+
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
475
|
+
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
476
|
+
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
477
|
+
|
478
|
+
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
479
|
+
|
480
|
+
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
481
|
+
|
482
|
+
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
483
|
+
|
484
|
+
for (auto &c : compiles) {
|
485
|
+
c.wait();
|
486
|
+
}
|
487
|
+
}
|
488
|
+
|
489
|
+
void write_output_files() {
|
490
|
+
FILE* hdr = fopen(target_hpp.c_str(), "w");
|
491
|
+
FILE* src = fopen(target_cpp.c_str(), "w");
|
492
|
+
|
493
|
+
fprintf(hdr, "#include <cstdint>\n\n");
|
494
|
+
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
|
495
|
+
|
496
|
+
for (const auto& pair : shader_fnames) {
|
497
|
+
const std::string& name = pair.first;
|
498
|
+
#ifdef _WIN32
|
499
|
+
std::string path = pair.second;
|
500
|
+
std::replace(path.begin(), path.end(), '/', '\\' );
|
501
|
+
#else
|
502
|
+
const std::string& path = pair.second;
|
503
|
+
#endif
|
504
|
+
|
505
|
+
FILE* spv = fopen(path.c_str(), "rb");
|
506
|
+
if (!spv) {
|
507
|
+
std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
|
508
|
+
continue;
|
509
|
+
}
|
510
|
+
|
511
|
+
fseek(spv, 0, SEEK_END);
|
512
|
+
size_t size = ftell(spv);
|
513
|
+
fseek(spv, 0, SEEK_SET);
|
514
|
+
|
515
|
+
std::vector<unsigned char> data(size);
|
516
|
+
size_t read_size = fread(data.data(), 1, size, spv);
|
517
|
+
fclose(spv);
|
518
|
+
if (read_size != size) {
|
519
|
+
std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
|
520
|
+
continue;
|
521
|
+
}
|
522
|
+
|
523
|
+
fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
|
524
|
+
fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
|
525
|
+
|
526
|
+
fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
|
527
|
+
for (size_t i = 0; i < size; ++i) {
|
528
|
+
fprintf(src, "0x%02x,", data[i]);
|
529
|
+
if ((i + 1) % 12 == 0) fprintf(src, "\n");
|
530
|
+
}
|
531
|
+
fprintf(src, "\n};\n\n");
|
532
|
+
|
533
|
+
if (!no_clean) {
|
534
|
+
std::remove(path.c_str());
|
535
|
+
}
|
536
|
+
}
|
537
|
+
|
538
|
+
fclose(hdr);
|
539
|
+
fclose(src);
|
540
|
+
}
|
541
|
+
}
|
542
|
+
|
543
|
+
int main(int argc, char** argv) {
|
544
|
+
std::map<std::string, std::string> args;
|
545
|
+
for (int i = 1; i < argc; ++i) {
|
546
|
+
std::string arg = argv[i];
|
547
|
+
if (arg.rfind("--", 0) == 0) {
|
548
|
+
if (i + 1 < argc && argv[i + 1][0] != '-') {
|
549
|
+
args[arg] = argv[i + 1];
|
550
|
+
++i;
|
551
|
+
} else {
|
552
|
+
args[arg] = "";
|
553
|
+
}
|
554
|
+
}
|
555
|
+
}
|
556
|
+
|
557
|
+
if (args.find("--glslc") != args.end()) {
|
558
|
+
GLSLC = args["--glslc"]; // Path to glslc
|
559
|
+
}
|
560
|
+
if (args.find("--input-dir") != args.end()) {
|
561
|
+
input_dir = args["--input-dir"]; // Directory containing shader sources
|
562
|
+
}
|
563
|
+
if (args.find("--output-dir") != args.end()) {
|
564
|
+
output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
|
565
|
+
}
|
566
|
+
if (args.find("--target-hpp") != args.end()) {
|
567
|
+
target_hpp = args["--target-hpp"]; // Path to generated header file
|
568
|
+
}
|
569
|
+
if (args.find("--target-cpp") != args.end()) {
|
570
|
+
target_cpp = args["--target-cpp"]; // Path to generated cpp file
|
571
|
+
}
|
572
|
+
if (args.find("--no-clean") != args.end()) {
|
573
|
+
no_clean = true; // Keep temporary SPIR-V files in output-dir after build
|
574
|
+
}
|
575
|
+
|
576
|
+
if (!directory_exists(input_dir)) {
|
577
|
+
std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
|
578
|
+
return EXIT_FAILURE;
|
579
|
+
}
|
580
|
+
|
581
|
+
if (!directory_exists(output_dir)) {
|
582
|
+
if (!create_directory(output_dir)) {
|
583
|
+
std::cerr << "Error creating output directory: " << output_dir << "\n";
|
584
|
+
return EXIT_FAILURE;
|
585
|
+
}
|
586
|
+
}
|
587
|
+
|
588
|
+
process_shaders();
|
589
|
+
|
590
|
+
write_output_files();
|
591
|
+
|
592
|
+
return EXIT_SUCCESS;
|
593
|
+
}
|