llama_cpp 0.1.0 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,164 +10,148 @@
10
10
  #include "ggml.h"
11
11
 
12
12
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
13
- const char * clblast_dequant = MULTILINE_QUOTE(
13
+ static const char * program_source = MULTILINE_QUOTE(
14
14
 
15
- struct block_q4_0
15
+ typedef char int8_t;
16
+ typedef uchar uint8_t;
17
+ typedef int int32_t;
18
+ typedef uint uint32_t;
19
+
20
+ struct __attribute__ ((packed)) block_q4_0
16
21
  {
17
- float d;
18
- uchar qs[16];
22
+ half d;
23
+ uint8_t qs[16]; /* QK4_0 / 2 */
19
24
  };
20
25
 
21
- __kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
22
- const uint i = get_global_id(0) / 32;
23
- const uint l = get_local_id(0);
24
-
25
- const float d = blocks[i].d;
26
+ struct __attribute__ ((packed)) block_q4_1
27
+ {
28
+ half d;
29
+ half m;
30
+ uint8_t qs[16]; /* QK4_1 / 2 */
31
+ };
26
32
 
27
- const uchar vi = blocks[i].qs[l];
33
+ struct __attribute__ ((packed)) block_q5_0
34
+ {
35
+ half d;
36
+ uint32_t qh;
37
+ uint8_t qs[16]; /* QK5_0 / 2 */
38
+ };
28
39
 
29
- const uint index = i*32 + l*2;
30
- result[index + 0] = ((vi & 0xf) - 8)*d;
31
- result[index + 1] = ((vi >> 4) - 8)*d;
32
- }
40
+ struct __attribute__ ((packed)) block_q5_1
41
+ {
42
+ half d;
43
+ half m;
44
+ uint32_t qh;
45
+ uint8_t qs[16]; /* QK5_1 / 2 */
46
+ };
33
47
 
34
- struct block_q4_1
48
+ struct __attribute__ ((packed)) block_q8_0
35
49
  {
36
- float d;
37
- float m;
38
- uchar qs[16];
50
+ half d;
51
+ int8_t qs[32]; /* QK8_0 */
39
52
  };
40
53
 
41
- __kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
42
- const uint i = get_global_id(0) / 32;
43
- const uint l = get_local_id(0);
44
54
 
45
- const float d = blocks[i].d;
46
- const float m = blocks[i].m;
55
+ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
56
+ const uint i = get_global_id(0) / 32; /* QK4_0 */
57
+ const uint j = get_local_id(0);
47
58
 
48
- const uchar vi = blocks[i].qs[l];
59
+ const float d = vload_half(0, (__global half*) &x[i].d);
49
60
 
50
- const uint index = i*32 + l*2;
51
- result[index + 0] = (vi & 0xf) * d + m;
52
- result[index + 1] = (vi >> 4) * d + m;
53
- }
61
+ const int x0 = (x[i].qs[j] & 0xf) - 8;
62
+ const int x1 = (x[i].qs[j] >> 4) - 8;
54
63
 
55
- struct block_q4_2
56
- {
57
- ushort d;
58
- uchar qs[8];
59
- };
64
+ y[i*32 + j + 0 ] = x0*d;
65
+ y[i*32 + j + 16] = x1*d;
66
+ }
60
67
 
61
- __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
62
- const uint i = get_global_id(0) / 16;
63
- const uint l = get_local_id(0);
68
+ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
69
+ const uint i = get_global_id(0) / 32; /* QK4_1 */
70
+ const uint j = get_local_id(0);
64
71
 
65
- const float d = vload_half(0, (__global half*) &blocks[i].d);
72
+ const float d = vload_half(0, (__global half*) &x[i].d);
73
+ const float m = vload_half(0, (__global half*) &x[i].m);
66
74
 
67
- const uchar vi = blocks[i].qs[l];
75
+ const int x0 = (x[i].qs[j] & 0xf);
76
+ const int x1 = (x[i].qs[j] >> 4);
68
77
 
69
- const uint index = i*16 + l*2;
70
- result[index + 0] = ((vi & 0xf) - 8)*d;
71
- result[index + 1] = ((vi >> 4) - 8)*d;
78
+ y[i*32 + j + 0 ] = x0*d + m;
79
+ y[i*32 + j + 16] = x1*d + m;
72
80
  }
73
81
 
82
+ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
83
+ const uint i = get_global_id(0) / 32; /* QK5_0 */
84
+ const uint j = get_local_id(0);
74
85
 
75
- struct block_q5_0
76
- {
77
- float d;
78
- uint qh;
79
- uchar qs[16];
80
- };
81
-
82
- __kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
83
- const uint i = get_global_id(0) / 32;
84
- const uint l = get_local_id(0);
85
-
86
- const float d = blocks[i].d;
86
+ const float d = vload_half(0, (__global half*) &x[i].d);
87
87
 
88
- const uchar vi = blocks[i].qs[l];
88
+ uint32_t qh = x[i].qh;
89
89
 
90
- const uint l2 = l * 2;
90
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
91
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
91
92
 
92
- const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
93
- const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
93
+ const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
94
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
94
95
 
95
- const uint index = i*32 + l2;
96
- result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
97
- result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
96
+ y[i*32 + j + 0 ] = x0*d;
97
+ y[i*32 + j + 16] = x1*d;
98
98
  }
99
99
 
100
- struct block_q5_1
101
- {
102
- ushort d;
103
- ushort m;
104
- uint qh;
105
- uchar qs[16];
106
- };
107
-
108
- __kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
109
- const uint i = get_global_id(0) / 32;
110
- const uint l = get_local_id(0);
100
+ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
101
+ const uint i = get_global_id(0) / 32; /* QK5_1 */
102
+ const uint j = get_local_id(0);
111
103
 
112
- const float d = vload_half(0, (__global half*) &blocks[i].d);
113
- const float m = vload_half(0, (__global half*) &blocks[i].m);
104
+ const float d = vload_half(0, (__global half*) &x[i].d);
105
+ const float m = vload_half(0, (__global half*) &x[i].m);
114
106
 
115
- const uchar vi = blocks[i].qs[l];
107
+ uint32_t qh = x[i].qh;
116
108
 
117
- const uint l2 = l * 2;
109
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
110
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
118
111
 
119
- const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
120
- const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
112
+ const int x0 = (x[i].qs[j] & 0xf) | xh_0;
113
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
121
114
 
122
- const uint index = i*32 + l2;
123
- result[index + 0] = ((vi & 0xf) | vh0)*d + m;
124
- result[index + 1] = ((vi >> 4) | vh1)*d + m;
115
+ y[i*32 + j + 0 ] = x0*d + m;
116
+ y[i*32 + j + 16] = x1*d + m;
125
117
  }
126
118
 
127
- struct block_q8_0
128
- {
129
- float d;
130
- char qs[32];
131
- };
119
+ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
120
+ const uint i = get_global_id(0) / 32; /* QK8_0 */
121
+ const uint j = get_local_id(0);
132
122
 
133
- __kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
134
- const uint i = get_global_id(0) / 32;
135
- const uint l = get_local_id(0);
136
-
137
- result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
123
+ const float d = vload_half(0, (__global half*) &x[i].d);
124
+ y[i*32 + j] = x[i].qs[j]*d;
138
125
  }
139
126
 
140
127
  );
141
128
 
142
- #define CL_CHECK(err, name) \
143
- do { \
144
- cl_int err_ = (err); \
145
- if (err_ != CL_SUCCESS) { \
146
- fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
147
- exit(1); \
148
- } \
129
+ #define CL_CHECK(err) \
130
+ do { \
131
+ cl_int err_ = (err); \
132
+ if (err_ != CL_SUCCESS) { \
133
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
134
+ #err, err_, __FILE__, __LINE__); \
135
+ exit(1); \
136
+ } \
149
137
  } while (0)
150
138
 
151
- #define QK5_0 32
152
- typedef struct {
153
- ggml_fp16_t d; // delta
154
- uint8_t qh[4]; // 5-th bit of quants
155
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
156
- } block_q5_0;
157
-
158
-
159
- typedef struct {
160
- float d; // delta
161
- uint32_t qh; // 5-th bit of quants
162
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
163
- } cl_block_q5_0;
139
+ #define CLBLAST_CHECK(err) \
140
+ do { \
141
+ CLBlastStatusCode err_ = (err); \
142
+ if (err_ != CLBlastSuccess) { \
143
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
144
+ #err, err_, __FILE__, __LINE__); \
145
+ exit(1); \
146
+ } \
147
+ } while (0)
164
148
 
165
149
  static cl_platform_id platform;
166
150
  static cl_device_id device;
167
151
  static cl_context context;
168
152
  static cl_command_queue queue;
169
153
  static cl_program program;
170
- static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
154
+ static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
171
155
  static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
172
156
  static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
173
157
 
@@ -202,50 +186,174 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
202
186
 
203
187
  void ggml_cl_init(void) {
204
188
  cl_int err = 0;
205
- char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
206
- char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
207
- int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
208
- int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
209
- printf("\nInitializing CLBlast (First Run)...");
210
- printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
211
- cl_uint num_platforms;
212
- clGetPlatformIDs(0, NULL, &num_platforms);
213
- cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
214
- clGetPlatformIDs(num_platforms, platforms, NULL);
215
- platform = platforms[plat_num];
216
- char platform_buffer[1024];
217
- clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
218
- cl_uint num_devices;
219
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
220
- cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
221
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
222
- device = devices[dev_num];
223
- char device_buffer[1024];
224
- clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
225
- printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
226
- context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
227
- CL_CHECK(err, "clCreateContext");
228
- queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
229
- CL_CHECK(err, "clCreateCommandQueue");
230
-
231
- free(platforms);
232
- free(devices);
233
-
234
- program = build_program_from_source(context, device, clblast_dequant);
189
+
190
+ struct cl_device;
191
+ struct cl_platform {
192
+ cl_platform_id id;
193
+ unsigned number;
194
+ char name[128];
195
+ char vendor[128];
196
+ struct cl_device * devices;
197
+ unsigned n_devices;
198
+ struct cl_device * default_device;
199
+ };
200
+
201
+ struct cl_device {
202
+ struct cl_platform * platform;
203
+ cl_device_id id;
204
+ unsigned number;
205
+ cl_device_type type;
206
+ char name[128];
207
+ };
208
+
209
+ enum { NPLAT = 16, NDEV = 16 };
210
+
211
+ struct cl_platform platforms[NPLAT];
212
+ unsigned n_platforms = 0;
213
+ struct cl_device devices[NDEV];
214
+ unsigned n_devices = 0;
215
+ struct cl_device * default_device = NULL;
216
+
217
+ platform = NULL;
218
+ device = NULL;
219
+
220
+ cl_platform_id platform_ids[NPLAT];
221
+ CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
222
+
223
+ for (unsigned i = 0; i < n_platforms; i++) {
224
+ struct cl_platform * p = &platforms[i];
225
+ p->number = i;
226
+ p->id = platform_ids[i];
227
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
228
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
229
+
230
+ cl_device_id device_ids[NDEV];
231
+ cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
232
+ if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
233
+ p->n_devices = 0;
234
+ } else {
235
+ CL_CHECK(clGetDeviceIDsError);
236
+ }
237
+ p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
238
+ p->default_device = NULL;
239
+
240
+ for (unsigned j = 0; j < p->n_devices; j++) {
241
+ struct cl_device * d = &devices[n_devices];
242
+ d->number = n_devices++;
243
+ d->id = device_ids[j];
244
+ d->platform = p;
245
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
246
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
247
+
248
+ if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
249
+ p->default_device = d;
250
+ }
251
+ }
252
+
253
+ if (default_device == NULL && p->default_device != NULL) {
254
+ default_device = p->default_device;
255
+ }
256
+ }
257
+
258
+ if (n_devices == 0) {
259
+ fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
260
+ exit(1);
261
+ }
262
+
263
+ char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
264
+ char * user_device_string = getenv("GGML_OPENCL_DEVICE");
265
+ int user_platform_number = -1;
266
+ int user_device_number = -1;
267
+
268
+ unsigned n;
269
+ if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
270
+ user_platform_number = (int)n;
271
+ }
272
+ if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
273
+ user_device_number = (int)n;
274
+ }
275
+
276
+ struct cl_device * selected_devices = devices;
277
+ unsigned n_selected_devices = n_devices;
278
+
279
+ if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
280
+ for (unsigned i = 0; i < n_platforms; i++) {
281
+ struct cl_platform * p = &platforms[i];
282
+ if (strstr(p->name, user_platform_string) != NULL ||
283
+ strstr(p->vendor, user_platform_string) != NULL) {
284
+ user_platform_number = (int)i;
285
+ break;
286
+ }
287
+ }
288
+ if (user_platform_number == -1) {
289
+ fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
290
+ exit(1);
291
+ }
292
+ }
293
+ if (user_platform_number != -1) {
294
+ struct cl_platform * p = &platforms[user_platform_number];
295
+ selected_devices = p->devices;
296
+ n_selected_devices = p->n_devices;
297
+ default_device = p->default_device;
298
+ if (n_selected_devices == 0) {
299
+ fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
300
+ exit(1);
301
+ }
302
+ }
303
+
304
+ if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
305
+ for (unsigned i = 0; i < n_selected_devices; i++) {
306
+ struct cl_device * d = &selected_devices[i];
307
+ if (strstr(d->name, user_device_string) != NULL) {
308
+ user_device_number = d->number;
309
+ break;
310
+ }
311
+ }
312
+ if (user_device_number == -1) {
313
+ fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
314
+ exit(1);
315
+ }
316
+ }
317
+ if (user_device_number != -1) {
318
+ selected_devices = &devices[user_device_number];
319
+ n_selected_devices = 1;
320
+ default_device = &selected_devices[0];
321
+ }
322
+
323
+ GGML_ASSERT(n_selected_devices > 0);
324
+
325
+ if (default_device == NULL) {
326
+ default_device = &selected_devices[0];
327
+ }
328
+
329
+ fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
330
+ fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
331
+ if (default_device->type != CL_DEVICE_TYPE_GPU) {
332
+ fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
333
+ }
334
+
335
+ platform = default_device->platform->id;
336
+ device = default_device->id;
337
+
338
+ cl_context_properties properties[] = {
339
+ (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
340
+ };
341
+
342
+ CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
343
+
344
+ CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
345
+ (err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
346
+ (queue = clCreateCommandQueue(context, device, 0, &err), err)
347
+ )));
348
+
349
+ program = build_program_from_source(context, device, program_source);
235
350
 
236
351
  // Prepare dequantize kernels
237
- kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
238
- CL_CHECK(err, "clCreateKernel");
239
- kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
240
- CL_CHECK(err, "clCreateKernel");
241
- kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
242
- CL_CHECK(err, "clCreateKernel");
243
- kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
244
- CL_CHECK(err, "clCreateKernel");
245
- kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
246
- CL_CHECK(err, "clCreateKernel");
247
- kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
248
- CL_CHECK(err, "clCreateKernel");
352
+ CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
353
+ CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
354
+ CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
355
+ CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
356
+ CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
249
357
  }
250
358
 
251
359
  static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -258,9 +366,8 @@ static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags
258
366
  clReleaseMemObject(*buf);
259
367
  }
260
368
  cl_int err;
261
- *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
369
+ CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
262
370
  *cur_size = req_size;
263
- CL_CHECK(err, "clCreateBuffer");
264
371
  }
265
372
 
266
373
  void ggml_cl_sgemm_wrapper(
@@ -269,12 +376,10 @@ void ggml_cl_sgemm_wrapper(
269
376
  const float alpha, const void *host_a, const int lda,
270
377
  const float *host_b, const int ldb, const float beta,
271
378
  float *host_c, const int ldc, const int btype) {
272
- cl_int err = 0;
273
379
 
274
380
  cl_kernel kernel;
275
381
  size_t global = n * k, local, size_qb;
276
382
  bool dequant;
277
- cl_block_q5_0* cl_host_b;
278
383
 
279
384
  switch (btype) {
280
385
  case GGML_TYPE_F32:
@@ -284,36 +389,19 @@ void ggml_cl_sgemm_wrapper(
284
389
  dequant = true;
285
390
  kernel = kernel_q4_0;
286
391
  local = 16;
287
- size_qb = global * (sizeof(float) + local) / 32;
392
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
288
393
  break;
289
394
  case GGML_TYPE_Q4_1:
290
395
  dequant = true;
291
396
  kernel = kernel_q4_1;
292
397
  local = 16;
293
- size_qb = global * (sizeof(float) * 2 + local) / 32;
294
- break;
295
- case GGML_TYPE_Q4_2:
296
- dequant = true;
297
- kernel = kernel_q4_2;
298
- local = 8;
299
- size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
398
+ size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
300
399
  break;
301
400
  case GGML_TYPE_Q5_0:
302
401
  dequant = true;
303
402
  kernel = kernel_q5_0;
304
403
  local = 16;
305
- // For some reason OpenCL seems to be incapable of working with structs of size 22.
306
- // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
307
- // TODO Find the reason, fix and remove workaround.
308
- const block_q5_0* b = (const block_q5_0*) host_b;
309
- cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
310
- for (size_t i = 0; i < global / 32; i++) {
311
- cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
312
- memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
313
- memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
314
- }
315
- host_b = (const float*) cl_host_b;
316
- size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
404
+ size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32;
317
405
  break;
318
406
  case GGML_TYPE_Q5_1:
319
407
  dequant = true;
@@ -325,7 +413,7 @@ void ggml_cl_sgemm_wrapper(
325
413
  dequant = true;
326
414
  kernel = kernel_q8_0;
327
415
  local = 32;
328
- size_qb = global * (sizeof(float) + local) / 32;
416
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
329
417
  break;
330
418
  default:
331
419
  fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
@@ -347,52 +435,40 @@ void ggml_cl_sgemm_wrapper(
347
435
  cl_event ev_a, ev_qb, ev_b;
348
436
 
349
437
  if (dequant) {
350
- err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
351
- err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
352
- CL_CHECK(err, "clSetKernelArg");
353
- err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
354
- CL_CHECK(err, "clEnqueueWriteBuffer qb");
438
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
439
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
440
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
355
441
  } else {
356
- err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
357
- CL_CHECK(err, "clEnqueueWriteBuffer b");
442
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
358
443
  }
359
444
 
360
- err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
361
- CL_CHECK(err, "clEnqueueWriteBuffer a");
445
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
362
446
  if (dequant) {
363
- err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
364
- CL_CHECK(err, "clEnqueueNDRangeKernel");
365
- clReleaseEvent(ev_qb);
447
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
448
+ CL_CHECK(clReleaseEvent(ev_qb));
366
449
  }
367
- clWaitForEvents(1, &ev_a);
368
- clWaitForEvents(1, &ev_b);
369
- clReleaseEvent(ev_a);
370
- clReleaseEvent(ev_b);
450
+ CL_CHECK(clWaitForEvents(1, &ev_a));
451
+ CL_CHECK(clWaitForEvents(1, &ev_b));
452
+ CL_CHECK(clReleaseEvent(ev_a));
453
+ CL_CHECK(clReleaseEvent(ev_b));
371
454
 
372
455
  cl_event ev_sgemm;
373
- CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
374
- (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
375
- m, n, k,
376
- alpha,
377
- cl_buffer_a, 0, lda,
378
- cl_buffer_b, 0, ldb,
379
- beta,
380
- cl_buffer_c, 0, ldc,
381
- &queue, &ev_sgemm);
382
-
383
- if (status != CLBlastSuccess) {
384
- fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
385
- abort();
386
- }
456
+ CLBLAST_CHECK(CLBlastSgemm(
457
+ (CLBlastLayout)order,
458
+ (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
459
+ m, n, k,
460
+ alpha,
461
+ cl_buffer_a, 0, lda,
462
+ cl_buffer_b, 0, ldb,
463
+ beta,
464
+ cl_buffer_c, 0, ldc,
465
+ &queue, &ev_sgemm));
387
466
 
388
467
  cl_event ev_c;
389
- clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
468
+ CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
390
469
 
391
470
  // Wait for completion
392
- clWaitForEvents(1, &ev_c);
393
- clReleaseEvent(ev_sgemm);
394
- clReleaseEvent(ev_c);
395
- if (btype == GGML_TYPE_Q5_0) {
396
- free((void*) cl_host_b);
397
- }
471
+ CL_CHECK(clWaitForEvents(1, &ev_c));
472
+ CL_CHECK(clReleaseEvent(ev_sgemm));
473
+ CL_CHECK(clReleaseEvent(ev_c));
398
474
  }