llama_cpp 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/ext/llama_cpp/extconf.rb +7 -0
- data/ext/llama_cpp/llama_cpp.cpp +60 -6
- data/ext/llama_cpp/src/ggml-cuda.h +2 -0
- data/ext/llama_cpp/src/ggml-opencl.c +246 -133
- data/ext/llama_cpp/src/ggml.c +362 -137
- data/ext/llama_cpp/src/ggml.h +13 -3
- data/ext/llama_cpp/src/llama-util.h +23 -23
- data/ext/llama_cpp/src/llama.cpp +173 -102
- data/ext/llama_cpp/src/llama.h +30 -17
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -0
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
@@ -10,87 +10,77 @@
|
|
10
10
|
#include "ggml.h"
|
11
11
|
|
12
12
|
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
13
|
-
const char *
|
13
|
+
static const char * program_source = MULTILINE_QUOTE(
|
14
14
|
|
15
|
+
typedef char int8_t;
|
15
16
|
typedef uchar uint8_t;
|
16
17
|
typedef int int32_t;
|
17
18
|
typedef uint uint32_t;
|
18
19
|
|
19
|
-
|
20
|
-
struct block_q4_0
|
20
|
+
struct __attribute__ ((packed)) block_q4_0
|
21
21
|
{
|
22
|
-
|
23
|
-
uint8_t qs[QK4_0 / 2
|
22
|
+
half d;
|
23
|
+
uint8_t qs[16]; /* QK4_0 / 2 */
|
24
24
|
};
|
25
25
|
|
26
|
-
|
27
|
-
struct block_q4_1
|
26
|
+
struct __attribute__ ((packed)) block_q4_1
|
28
27
|
{
|
29
|
-
|
30
|
-
|
31
|
-
uint8_t qs[QK4_1 / 2
|
28
|
+
half d;
|
29
|
+
half m;
|
30
|
+
uint8_t qs[16]; /* QK4_1 / 2 */
|
32
31
|
};
|
33
32
|
|
34
|
-
constant uint QK5_0 = 32;
|
35
33
|
struct __attribute__ ((packed)) block_q5_0
|
36
34
|
{
|
37
35
|
half d;
|
38
36
|
uint32_t qh;
|
39
|
-
uint8_t qs[QK5_0 / 2
|
37
|
+
uint8_t qs[16]; /* QK5_0 / 2 */
|
40
38
|
};
|
41
39
|
|
42
|
-
|
43
|
-
struct block_q5_1
|
40
|
+
struct __attribute__ ((packed)) block_q5_1
|
44
41
|
{
|
45
42
|
half d;
|
46
43
|
half m;
|
47
44
|
uint32_t qh;
|
48
|
-
uint8_t qs[QK5_1 / 2
|
45
|
+
uint8_t qs[16]; /* QK5_1 / 2 */
|
49
46
|
};
|
50
47
|
|
51
|
-
|
52
|
-
struct block_q8_0
|
48
|
+
struct __attribute__ ((packed)) block_q8_0
|
53
49
|
{
|
54
|
-
|
55
|
-
|
50
|
+
half d;
|
51
|
+
int8_t qs[32]; /* QK8_0 */
|
56
52
|
};
|
57
53
|
|
58
54
|
|
59
55
|
__kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
|
60
|
-
|
61
|
-
|
62
|
-
const uint i = get_global_id(0) / qk;
|
56
|
+
const uint i = get_global_id(0) / 32; /* QK4_0 */
|
63
57
|
const uint j = get_local_id(0);
|
64
58
|
|
65
|
-
const float d = x[i].d;
|
59
|
+
const float d = vload_half(0, (__global half*) &x[i].d);
|
66
60
|
|
67
61
|
const int x0 = (x[i].qs[j] & 0xf) - 8;
|
68
62
|
const int x1 = (x[i].qs[j] >> 4) - 8;
|
69
63
|
|
70
|
-
y[i*
|
71
|
-
y[i*
|
64
|
+
y[i*32 + j + 0 ] = x0*d;
|
65
|
+
y[i*32 + j + 16] = x1*d;
|
72
66
|
}
|
73
67
|
|
74
68
|
__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
|
75
|
-
|
76
|
-
|
77
|
-
const uint i = get_global_id(0) / qk;
|
69
|
+
const uint i = get_global_id(0) / 32; /* QK4_1 */
|
78
70
|
const uint j = get_local_id(0);
|
79
71
|
|
80
|
-
const float d = x[i].d;
|
81
|
-
const float m = x[i].m;
|
72
|
+
const float d = vload_half(0, (__global half*) &x[i].d);
|
73
|
+
const float m = vload_half(0, (__global half*) &x[i].m);
|
82
74
|
|
83
75
|
const int x0 = (x[i].qs[j] & 0xf);
|
84
76
|
const int x1 = (x[i].qs[j] >> 4);
|
85
77
|
|
86
|
-
y[i*
|
87
|
-
y[i*
|
78
|
+
y[i*32 + j + 0 ] = x0*d + m;
|
79
|
+
y[i*32 + j + 16] = x1*d + m;
|
88
80
|
}
|
89
81
|
|
90
82
|
__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
|
91
|
-
|
92
|
-
|
93
|
-
const uint i = get_global_id(0) / qk;
|
83
|
+
const uint i = get_global_id(0) / 32; /* QK5_0 */
|
94
84
|
const uint j = get_local_id(0);
|
95
85
|
|
96
86
|
const float d = vload_half(0, (__global half*) &x[i].d);
|
@@ -103,14 +93,12 @@ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float*
|
|
103
93
|
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
|
104
94
|
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
|
105
95
|
|
106
|
-
y[i*
|
107
|
-
y[i*
|
96
|
+
y[i*32 + j + 0 ] = x0*d;
|
97
|
+
y[i*32 + j + 16] = x1*d;
|
108
98
|
}
|
109
99
|
|
110
100
|
__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
|
111
|
-
|
112
|
-
|
113
|
-
const uint i = get_global_id(0) / qk;
|
101
|
+
const uint i = get_global_id(0) / 32; /* QK5_1 */
|
114
102
|
const uint j = get_local_id(0);
|
115
103
|
|
116
104
|
const float d = vload_half(0, (__global half*) &x[i].d);
|
@@ -124,28 +112,38 @@ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float*
|
|
124
112
|
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
|
125
113
|
const int x1 = (x[i].qs[j] >> 4) | xh_1;
|
126
114
|
|
127
|
-
y[i*
|
128
|
-
y[i*
|
115
|
+
y[i*32 + j + 0 ] = x0*d + m;
|
116
|
+
y[i*32 + j + 16] = x1*d + m;
|
129
117
|
}
|
130
118
|
|
131
119
|
__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
|
132
|
-
|
133
|
-
const uint i = get_global_id(0) / qk;
|
120
|
+
const uint i = get_global_id(0) / 32; /* QK8_0 */
|
134
121
|
const uint j = get_local_id(0);
|
135
122
|
|
136
|
-
const float d = x[i].d;
|
137
|
-
y[i*
|
123
|
+
const float d = vload_half(0, (__global half*) &x[i].d);
|
124
|
+
y[i*32 + j] = x[i].qs[j]*d;
|
138
125
|
}
|
139
126
|
|
140
127
|
);
|
141
128
|
|
142
|
-
#define CL_CHECK(err
|
143
|
-
do {
|
144
|
-
cl_int err_ = (err);
|
145
|
-
if (err_ != CL_SUCCESS) {
|
146
|
-
fprintf(stderr, "
|
147
|
-
|
148
|
-
|
129
|
+
#define CL_CHECK(err) \
|
130
|
+
do { \
|
131
|
+
cl_int err_ = (err); \
|
132
|
+
if (err_ != CL_SUCCESS) { \
|
133
|
+
fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
|
134
|
+
#err, err_, __FILE__, __LINE__); \
|
135
|
+
exit(1); \
|
136
|
+
} \
|
137
|
+
} while (0)
|
138
|
+
|
139
|
+
#define CLBLAST_CHECK(err) \
|
140
|
+
do { \
|
141
|
+
CLBlastStatusCode err_ = (err); \
|
142
|
+
if (err_ != CLBlastSuccess) { \
|
143
|
+
fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
|
144
|
+
#err, err_, __FILE__, __LINE__); \
|
145
|
+
exit(1); \
|
146
|
+
} \
|
149
147
|
} while (0)
|
150
148
|
|
151
149
|
static cl_platform_id platform;
|
@@ -188,48 +186,174 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
|
|
188
186
|
|
189
187
|
void ggml_cl_init(void) {
|
190
188
|
cl_int err = 0;
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
189
|
+
|
190
|
+
struct cl_device;
|
191
|
+
struct cl_platform {
|
192
|
+
cl_platform_id id;
|
193
|
+
unsigned number;
|
194
|
+
char name[128];
|
195
|
+
char vendor[128];
|
196
|
+
struct cl_device * devices;
|
197
|
+
unsigned n_devices;
|
198
|
+
struct cl_device * default_device;
|
199
|
+
};
|
200
|
+
|
201
|
+
struct cl_device {
|
202
|
+
struct cl_platform * platform;
|
203
|
+
cl_device_id id;
|
204
|
+
unsigned number;
|
205
|
+
cl_device_type type;
|
206
|
+
char name[128];
|
207
|
+
};
|
208
|
+
|
209
|
+
enum { NPLAT = 16, NDEV = 16 };
|
210
|
+
|
211
|
+
struct cl_platform platforms[NPLAT];
|
212
|
+
unsigned n_platforms = 0;
|
213
|
+
struct cl_device devices[NDEV];
|
214
|
+
unsigned n_devices = 0;
|
215
|
+
struct cl_device * default_device = NULL;
|
216
|
+
|
217
|
+
platform = NULL;
|
218
|
+
device = NULL;
|
219
|
+
|
220
|
+
cl_platform_id platform_ids[NPLAT];
|
221
|
+
CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
|
222
|
+
|
223
|
+
for (unsigned i = 0; i < n_platforms; i++) {
|
224
|
+
struct cl_platform * p = &platforms[i];
|
225
|
+
p->number = i;
|
226
|
+
p->id = platform_ids[i];
|
227
|
+
CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
|
228
|
+
CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
|
229
|
+
|
230
|
+
cl_device_id device_ids[NDEV];
|
231
|
+
cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
|
232
|
+
if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
|
233
|
+
p->n_devices = 0;
|
234
|
+
} else {
|
235
|
+
CL_CHECK(clGetDeviceIDsError);
|
236
|
+
}
|
237
|
+
p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
|
238
|
+
p->default_device = NULL;
|
239
|
+
|
240
|
+
for (unsigned j = 0; j < p->n_devices; j++) {
|
241
|
+
struct cl_device * d = &devices[n_devices];
|
242
|
+
d->number = n_devices++;
|
243
|
+
d->id = device_ids[j];
|
244
|
+
d->platform = p;
|
245
|
+
CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
|
246
|
+
CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
|
247
|
+
|
248
|
+
if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
|
249
|
+
p->default_device = d;
|
250
|
+
}
|
251
|
+
}
|
252
|
+
|
253
|
+
if (default_device == NULL && p->default_device != NULL) {
|
254
|
+
default_device = p->default_device;
|
255
|
+
}
|
256
|
+
}
|
257
|
+
|
258
|
+
if (n_devices == 0) {
|
259
|
+
fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
|
260
|
+
exit(1);
|
261
|
+
}
|
262
|
+
|
263
|
+
char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
|
264
|
+
char * user_device_string = getenv("GGML_OPENCL_DEVICE");
|
265
|
+
int user_platform_number = -1;
|
266
|
+
int user_device_number = -1;
|
267
|
+
|
268
|
+
unsigned n;
|
269
|
+
if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
|
270
|
+
user_platform_number = (int)n;
|
271
|
+
}
|
272
|
+
if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
|
273
|
+
user_device_number = (int)n;
|
274
|
+
}
|
275
|
+
|
276
|
+
struct cl_device * selected_devices = devices;
|
277
|
+
unsigned n_selected_devices = n_devices;
|
278
|
+
|
279
|
+
if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
|
280
|
+
for (unsigned i = 0; i < n_platforms; i++) {
|
281
|
+
struct cl_platform * p = &platforms[i];
|
282
|
+
if (strstr(p->name, user_platform_string) != NULL ||
|
283
|
+
strstr(p->vendor, user_platform_string) != NULL) {
|
284
|
+
user_platform_number = (int)i;
|
285
|
+
break;
|
286
|
+
}
|
287
|
+
}
|
288
|
+
if (user_platform_number == -1) {
|
289
|
+
fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
|
290
|
+
exit(1);
|
291
|
+
}
|
292
|
+
}
|
293
|
+
if (user_platform_number != -1) {
|
294
|
+
struct cl_platform * p = &platforms[user_platform_number];
|
295
|
+
selected_devices = p->devices;
|
296
|
+
n_selected_devices = p->n_devices;
|
297
|
+
default_device = p->default_device;
|
298
|
+
if (n_selected_devices == 0) {
|
299
|
+
fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
|
300
|
+
exit(1);
|
301
|
+
}
|
302
|
+
}
|
303
|
+
|
304
|
+
if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
|
305
|
+
for (unsigned i = 0; i < n_selected_devices; i++) {
|
306
|
+
struct cl_device * d = &selected_devices[i];
|
307
|
+
if (strstr(d->name, user_device_string) != NULL) {
|
308
|
+
user_device_number = d->number;
|
309
|
+
break;
|
310
|
+
}
|
311
|
+
}
|
312
|
+
if (user_device_number == -1) {
|
313
|
+
fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
|
314
|
+
exit(1);
|
315
|
+
}
|
316
|
+
}
|
317
|
+
if (user_device_number != -1) {
|
318
|
+
selected_devices = &devices[user_device_number];
|
319
|
+
n_selected_devices = 1;
|
320
|
+
default_device = &selected_devices[0];
|
321
|
+
}
|
322
|
+
|
323
|
+
GGML_ASSERT(n_selected_devices > 0);
|
324
|
+
|
325
|
+
if (default_device == NULL) {
|
326
|
+
default_device = &selected_devices[0];
|
327
|
+
}
|
328
|
+
|
329
|
+
fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
|
330
|
+
fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
|
331
|
+
if (default_device->type != CL_DEVICE_TYPE_GPU) {
|
332
|
+
fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
|
333
|
+
}
|
334
|
+
|
335
|
+
platform = default_device->platform->id;
|
336
|
+
device = default_device->id;
|
337
|
+
|
338
|
+
cl_context_properties properties[] = {
|
339
|
+
(intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
|
340
|
+
};
|
341
|
+
|
342
|
+
CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
|
343
|
+
|
344
|
+
CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
|
345
|
+
(err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
|
346
|
+
(queue = clCreateCommandQueue(context, device, 0, &err), err)
|
347
|
+
)));
|
348
|
+
|
349
|
+
program = build_program_from_source(context, device, program_source);
|
221
350
|
|
222
351
|
// Prepare dequantize kernels
|
223
|
-
kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
|
224
|
-
CL_CHECK(
|
225
|
-
|
226
|
-
CL_CHECK(
|
227
|
-
|
228
|
-
CL_CHECK(err, "clCreateKernel");
|
229
|
-
kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
|
230
|
-
CL_CHECK(err, "clCreateKernel");
|
231
|
-
kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
|
232
|
-
CL_CHECK(err, "clCreateKernel");
|
352
|
+
CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
|
353
|
+
CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
|
354
|
+
CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
|
355
|
+
CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
|
356
|
+
CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
|
233
357
|
}
|
234
358
|
|
235
359
|
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
|
@@ -242,9 +366,8 @@ static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags
|
|
242
366
|
clReleaseMemObject(*buf);
|
243
367
|
}
|
244
368
|
cl_int err;
|
245
|
-
*buf = clCreateBuffer(context, flags, req_size, NULL, &err);
|
369
|
+
CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
|
246
370
|
*cur_size = req_size;
|
247
|
-
CL_CHECK(err, "clCreateBuffer");
|
248
371
|
}
|
249
372
|
|
250
373
|
void ggml_cl_sgemm_wrapper(
|
@@ -253,7 +376,6 @@ void ggml_cl_sgemm_wrapper(
|
|
253
376
|
const float alpha, const void *host_a, const int lda,
|
254
377
|
const float *host_b, const int ldb, const float beta,
|
255
378
|
float *host_c, const int ldc, const int btype) {
|
256
|
-
cl_int err = 0;
|
257
379
|
|
258
380
|
cl_kernel kernel;
|
259
381
|
size_t global = n * k, local, size_qb;
|
@@ -267,13 +389,13 @@ void ggml_cl_sgemm_wrapper(
|
|
267
389
|
dequant = true;
|
268
390
|
kernel = kernel_q4_0;
|
269
391
|
local = 16;
|
270
|
-
size_qb = global * (sizeof(
|
392
|
+
size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
|
271
393
|
break;
|
272
394
|
case GGML_TYPE_Q4_1:
|
273
395
|
dequant = true;
|
274
396
|
kernel = kernel_q4_1;
|
275
397
|
local = 16;
|
276
|
-
size_qb = global * (sizeof(
|
398
|
+
size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
|
277
399
|
break;
|
278
400
|
case GGML_TYPE_Q5_0:
|
279
401
|
dequant = true;
|
@@ -291,7 +413,7 @@ void ggml_cl_sgemm_wrapper(
|
|
291
413
|
dequant = true;
|
292
414
|
kernel = kernel_q8_0;
|
293
415
|
local = 32;
|
294
|
-
size_qb = global * (sizeof(
|
416
|
+
size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
|
295
417
|
break;
|
296
418
|
default:
|
297
419
|
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
|
@@ -313,49 +435,40 @@ void ggml_cl_sgemm_wrapper(
|
|
313
435
|
cl_event ev_a, ev_qb, ev_b;
|
314
436
|
|
315
437
|
if (dequant) {
|
316
|
-
|
317
|
-
|
318
|
-
CL_CHECK(
|
319
|
-
err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
|
320
|
-
CL_CHECK(err, "clEnqueueWriteBuffer qb");
|
438
|
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
|
439
|
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
|
440
|
+
CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
|
321
441
|
} else {
|
322
|
-
|
323
|
-
CL_CHECK(err, "clEnqueueWriteBuffer b");
|
442
|
+
CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
|
324
443
|
}
|
325
444
|
|
326
|
-
|
327
|
-
CL_CHECK(err, "clEnqueueWriteBuffer a");
|
445
|
+
CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
|
328
446
|
if (dequant) {
|
329
|
-
|
330
|
-
CL_CHECK(
|
331
|
-
clReleaseEvent(ev_qb);
|
447
|
+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
|
448
|
+
CL_CHECK(clReleaseEvent(ev_qb));
|
332
449
|
}
|
333
|
-
clWaitForEvents(1, &ev_a);
|
334
|
-
clWaitForEvents(1, &ev_b);
|
335
|
-
clReleaseEvent(ev_a);
|
336
|
-
clReleaseEvent(ev_b);
|
450
|
+
CL_CHECK(clWaitForEvents(1, &ev_a));
|
451
|
+
CL_CHECK(clWaitForEvents(1, &ev_b));
|
452
|
+
CL_CHECK(clReleaseEvent(ev_a));
|
453
|
+
CL_CHECK(clReleaseEvent(ev_b));
|
337
454
|
|
338
455
|
cl_event ev_sgemm;
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
if (status != CLBlastSuccess) {
|
350
|
-
fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
|
351
|
-
abort();
|
352
|
-
}
|
456
|
+
CLBLAST_CHECK(CLBlastSgemm(
|
457
|
+
(CLBlastLayout)order,
|
458
|
+
(CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
|
459
|
+
m, n, k,
|
460
|
+
alpha,
|
461
|
+
cl_buffer_a, 0, lda,
|
462
|
+
cl_buffer_b, 0, ldb,
|
463
|
+
beta,
|
464
|
+
cl_buffer_c, 0, ldc,
|
465
|
+
&queue, &ev_sgemm));
|
353
466
|
|
354
467
|
cl_event ev_c;
|
355
|
-
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
|
468
|
+
CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
|
356
469
|
|
357
470
|
// Wait for completion
|
358
|
-
clWaitForEvents(1, &ev_c);
|
359
|
-
clReleaseEvent(ev_sgemm);
|
360
|
-
clReleaseEvent(ev_c);
|
471
|
+
CL_CHECK(clWaitForEvents(1, &ev_c));
|
472
|
+
CL_CHECK(clReleaseEvent(ev_sgemm));
|
473
|
+
CL_CHECK(clReleaseEvent(ev_c));
|
361
474
|
}
|