llama_cpp 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,164 +10,148 @@
10
10
  #include "ggml.h"
11
11
 
12
12
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
13
- const char * clblast_dequant = MULTILINE_QUOTE(
13
+ static const char * program_source = MULTILINE_QUOTE(
14
14
 
15
- struct block_q4_0
15
+ typedef char int8_t;
16
+ typedef uchar uint8_t;
17
+ typedef int int32_t;
18
+ typedef uint uint32_t;
19
+
20
+ struct __attribute__ ((packed)) block_q4_0
16
21
  {
17
- float d;
18
- uchar qs[16];
22
+ half d;
23
+ uint8_t qs[16]; /* QK4_0 / 2 */
19
24
  };
20
25
 
21
- __kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
22
- const uint i = get_global_id(0) / 32;
23
- const uint l = get_local_id(0);
24
-
25
- const float d = blocks[i].d;
26
+ struct __attribute__ ((packed)) block_q4_1
27
+ {
28
+ half d;
29
+ half m;
30
+ uint8_t qs[16]; /* QK4_1 / 2 */
31
+ };
26
32
 
27
- const uchar vi = blocks[i].qs[l];
33
+ struct __attribute__ ((packed)) block_q5_0
34
+ {
35
+ half d;
36
+ uint32_t qh;
37
+ uint8_t qs[16]; /* QK5_0 / 2 */
38
+ };
28
39
 
29
- const uint index = i*32 + l*2;
30
- result[index + 0] = ((vi & 0xf) - 8)*d;
31
- result[index + 1] = ((vi >> 4) - 8)*d;
32
- }
40
+ struct __attribute__ ((packed)) block_q5_1
41
+ {
42
+ half d;
43
+ half m;
44
+ uint32_t qh;
45
+ uint8_t qs[16]; /* QK5_1 / 2 */
46
+ };
33
47
 
34
- struct block_q4_1
48
+ struct __attribute__ ((packed)) block_q8_0
35
49
  {
36
- float d;
37
- float m;
38
- uchar qs[16];
50
+ half d;
51
+ int8_t qs[32]; /* QK8_0 */
39
52
  };
40
53
 
41
- __kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
42
- const uint i = get_global_id(0) / 32;
43
- const uint l = get_local_id(0);
44
54
 
45
- const float d = blocks[i].d;
46
- const float m = blocks[i].m;
55
+ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
56
+ const uint i = get_global_id(0) / 32; /* QK4_0 */
57
+ const uint j = get_local_id(0);
47
58
 
48
- const uchar vi = blocks[i].qs[l];
59
+ const float d = vload_half(0, (__global half*) &x[i].d);
49
60
 
50
- const uint index = i*32 + l*2;
51
- result[index + 0] = (vi & 0xf) * d + m;
52
- result[index + 1] = (vi >> 4) * d + m;
53
- }
61
+ const int x0 = (x[i].qs[j] & 0xf) - 8;
62
+ const int x1 = (x[i].qs[j] >> 4) - 8;
54
63
 
55
- struct block_q4_2
56
- {
57
- ushort d;
58
- uchar qs[8];
59
- };
64
+ y[i*32 + j + 0 ] = x0*d;
65
+ y[i*32 + j + 16] = x1*d;
66
+ }
60
67
 
61
- __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
62
- const uint i = get_global_id(0) / 16;
63
- const uint l = get_local_id(0);
68
+ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
69
+ const uint i = get_global_id(0) / 32; /* QK4_1 */
70
+ const uint j = get_local_id(0);
64
71
 
65
- const float d = vload_half(0, (__global half*) &blocks[i].d);
72
+ const float d = vload_half(0, (__global half*) &x[i].d);
73
+ const float m = vload_half(0, (__global half*) &x[i].m);
66
74
 
67
- const uchar vi = blocks[i].qs[l];
75
+ const int x0 = (x[i].qs[j] & 0xf);
76
+ const int x1 = (x[i].qs[j] >> 4);
68
77
 
69
- const uint index = i*16 + l*2;
70
- result[index + 0] = ((vi & 0xf) - 8)*d;
71
- result[index + 1] = ((vi >> 4) - 8)*d;
78
+ y[i*32 + j + 0 ] = x0*d + m;
79
+ y[i*32 + j + 16] = x1*d + m;
72
80
  }
73
81
 
82
+ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
83
+ const uint i = get_global_id(0) / 32; /* QK5_0 */
84
+ const uint j = get_local_id(0);
74
85
 
75
- struct block_q5_0
76
- {
77
- float d;
78
- uint qh;
79
- uchar qs[16];
80
- };
81
-
82
- __kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
83
- const uint i = get_global_id(0) / 32;
84
- const uint l = get_local_id(0);
85
-
86
- const float d = blocks[i].d;
86
+ const float d = vload_half(0, (__global half*) &x[i].d);
87
87
 
88
- const uchar vi = blocks[i].qs[l];
88
+ uint32_t qh = x[i].qh;
89
89
 
90
- const uint l2 = l * 2;
90
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
91
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
91
92
 
92
- const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
93
- const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
93
+ const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
94
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
94
95
 
95
- const uint index = i*32 + l2;
96
- result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
97
- result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
96
+ y[i*32 + j + 0 ] = x0*d;
97
+ y[i*32 + j + 16] = x1*d;
98
98
  }
99
99
 
100
- struct block_q5_1
101
- {
102
- ushort d;
103
- ushort m;
104
- uint qh;
105
- uchar qs[16];
106
- };
107
-
108
- __kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
109
- const uint i = get_global_id(0) / 32;
110
- const uint l = get_local_id(0);
100
+ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
101
+ const uint i = get_global_id(0) / 32; /* QK5_1 */
102
+ const uint j = get_local_id(0);
111
103
 
112
- const float d = vload_half(0, (__global half*) &blocks[i].d);
113
- const float m = vload_half(0, (__global half*) &blocks[i].m);
104
+ const float d = vload_half(0, (__global half*) &x[i].d);
105
+ const float m = vload_half(0, (__global half*) &x[i].m);
114
106
 
115
- const uchar vi = blocks[i].qs[l];
107
+ uint32_t qh = x[i].qh;
116
108
 
117
- const uint l2 = l * 2;
109
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
110
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
118
111
 
119
- const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
120
- const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
112
+ const int x0 = (x[i].qs[j] & 0xf) | xh_0;
113
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
121
114
 
122
- const uint index = i*32 + l2;
123
- result[index + 0] = ((vi & 0xf) | vh0)*d + m;
124
- result[index + 1] = ((vi >> 4) | vh1)*d + m;
115
+ y[i*32 + j + 0 ] = x0*d + m;
116
+ y[i*32 + j + 16] = x1*d + m;
125
117
  }
126
118
 
127
- struct block_q8_0
128
- {
129
- float d;
130
- char qs[32];
131
- };
119
+ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
120
+ const uint i = get_global_id(0) / 32; /* QK8_0 */
121
+ const uint j = get_local_id(0);
132
122
 
133
- __kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
134
- const uint i = get_global_id(0) / 32;
135
- const uint l = get_local_id(0);
136
-
137
- result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
123
+ const float d = vload_half(0, (__global half*) &x[i].d);
124
+ y[i*32 + j] = x[i].qs[j]*d;
138
125
  }
139
126
 
140
127
  );
141
128
 
142
- #define CL_CHECK(err, name) \
143
- do { \
144
- cl_int err_ = (err); \
145
- if (err_ != CL_SUCCESS) { \
146
- fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
147
- exit(1); \
148
- } \
129
+ #define CL_CHECK(err) \
130
+ do { \
131
+ cl_int err_ = (err); \
132
+ if (err_ != CL_SUCCESS) { \
133
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
134
+ #err, err_, __FILE__, __LINE__); \
135
+ exit(1); \
136
+ } \
149
137
  } while (0)
150
138
 
151
- #define QK5_0 32
152
- typedef struct {
153
- ggml_fp16_t d; // delta
154
- uint8_t qh[4]; // 5-th bit of quants
155
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
156
- } block_q5_0;
157
-
158
-
159
- typedef struct {
160
- float d; // delta
161
- uint32_t qh; // 5-th bit of quants
162
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
163
- } cl_block_q5_0;
139
+ #define CLBLAST_CHECK(err) \
140
+ do { \
141
+ CLBlastStatusCode err_ = (err); \
142
+ if (err_ != CLBlastSuccess) { \
143
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
144
+ #err, err_, __FILE__, __LINE__); \
145
+ exit(1); \
146
+ } \
147
+ } while (0)
164
148
 
165
149
  static cl_platform_id platform;
166
150
  static cl_device_id device;
167
151
  static cl_context context;
168
152
  static cl_command_queue queue;
169
153
  static cl_program program;
170
- static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
154
+ static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
171
155
  static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
172
156
  static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
173
157
 
@@ -202,50 +186,174 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
202
186
 
203
187
  void ggml_cl_init(void) {
204
188
  cl_int err = 0;
205
- char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
206
- char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
207
- int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
208
- int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
209
- printf("\nInitializing CLBlast (First Run)...");
210
- printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
211
- cl_uint num_platforms;
212
- clGetPlatformIDs(0, NULL, &num_platforms);
213
- cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
214
- clGetPlatformIDs(num_platforms, platforms, NULL);
215
- platform = platforms[plat_num];
216
- char platform_buffer[1024];
217
- clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
218
- cl_uint num_devices;
219
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
220
- cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
221
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
222
- device = devices[dev_num];
223
- char device_buffer[1024];
224
- clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
225
- printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
226
- context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
227
- CL_CHECK(err, "clCreateContext");
228
- queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
229
- CL_CHECK(err, "clCreateCommandQueue");
230
-
231
- free(platforms);
232
- free(devices);
233
-
234
- program = build_program_from_source(context, device, clblast_dequant);
189
+
190
+ struct cl_device;
191
+ struct cl_platform {
192
+ cl_platform_id id;
193
+ unsigned number;
194
+ char name[128];
195
+ char vendor[128];
196
+ struct cl_device * devices;
197
+ unsigned n_devices;
198
+ struct cl_device * default_device;
199
+ };
200
+
201
+ struct cl_device {
202
+ struct cl_platform * platform;
203
+ cl_device_id id;
204
+ unsigned number;
205
+ cl_device_type type;
206
+ char name[128];
207
+ };
208
+
209
+ enum { NPLAT = 16, NDEV = 16 };
210
+
211
+ struct cl_platform platforms[NPLAT];
212
+ unsigned n_platforms = 0;
213
+ struct cl_device devices[NDEV];
214
+ unsigned n_devices = 0;
215
+ struct cl_device * default_device = NULL;
216
+
217
+ platform = NULL;
218
+ device = NULL;
219
+
220
+ cl_platform_id platform_ids[NPLAT];
221
+ CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
222
+
223
+ for (unsigned i = 0; i < n_platforms; i++) {
224
+ struct cl_platform * p = &platforms[i];
225
+ p->number = i;
226
+ p->id = platform_ids[i];
227
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
228
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
229
+
230
+ cl_device_id device_ids[NDEV];
231
+ cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
232
+ if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
233
+ p->n_devices = 0;
234
+ } else {
235
+ CL_CHECK(clGetDeviceIDsError);
236
+ }
237
+ p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
238
+ p->default_device = NULL;
239
+
240
+ for (unsigned j = 0; j < p->n_devices; j++) {
241
+ struct cl_device * d = &devices[n_devices];
242
+ d->number = n_devices++;
243
+ d->id = device_ids[j];
244
+ d->platform = p;
245
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
246
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
247
+
248
+ if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
249
+ p->default_device = d;
250
+ }
251
+ }
252
+
253
+ if (default_device == NULL && p->default_device != NULL) {
254
+ default_device = p->default_device;
255
+ }
256
+ }
257
+
258
+ if (n_devices == 0) {
259
+ fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
260
+ exit(1);
261
+ }
262
+
263
+ char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
264
+ char * user_device_string = getenv("GGML_OPENCL_DEVICE");
265
+ int user_platform_number = -1;
266
+ int user_device_number = -1;
267
+
268
+ unsigned n;
269
+ if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
270
+ user_platform_number = (int)n;
271
+ }
272
+ if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
273
+ user_device_number = (int)n;
274
+ }
275
+
276
+ struct cl_device * selected_devices = devices;
277
+ unsigned n_selected_devices = n_devices;
278
+
279
+ if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
280
+ for (unsigned i = 0; i < n_platforms; i++) {
281
+ struct cl_platform * p = &platforms[i];
282
+ if (strstr(p->name, user_platform_string) != NULL ||
283
+ strstr(p->vendor, user_platform_string) != NULL) {
284
+ user_platform_number = (int)i;
285
+ break;
286
+ }
287
+ }
288
+ if (user_platform_number == -1) {
289
+ fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
290
+ exit(1);
291
+ }
292
+ }
293
+ if (user_platform_number != -1) {
294
+ struct cl_platform * p = &platforms[user_platform_number];
295
+ selected_devices = p->devices;
296
+ n_selected_devices = p->n_devices;
297
+ default_device = p->default_device;
298
+ if (n_selected_devices == 0) {
299
+ fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
300
+ exit(1);
301
+ }
302
+ }
303
+
304
+ if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
305
+ for (unsigned i = 0; i < n_selected_devices; i++) {
306
+ struct cl_device * d = &selected_devices[i];
307
+ if (strstr(d->name, user_device_string) != NULL) {
308
+ user_device_number = d->number;
309
+ break;
310
+ }
311
+ }
312
+ if (user_device_number == -1) {
313
+ fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
314
+ exit(1);
315
+ }
316
+ }
317
+ if (user_device_number != -1) {
318
+ selected_devices = &devices[user_device_number];
319
+ n_selected_devices = 1;
320
+ default_device = &selected_devices[0];
321
+ }
322
+
323
+ GGML_ASSERT(n_selected_devices > 0);
324
+
325
+ if (default_device == NULL) {
326
+ default_device = &selected_devices[0];
327
+ }
328
+
329
+ fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
330
+ fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
331
+ if (default_device->type != CL_DEVICE_TYPE_GPU) {
332
+ fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
333
+ }
334
+
335
+ platform = default_device->platform->id;
336
+ device = default_device->id;
337
+
338
+ cl_context_properties properties[] = {
339
+ (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
340
+ };
341
+
342
+ CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
343
+
344
+ CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
345
+ (err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
346
+ (queue = clCreateCommandQueue(context, device, 0, &err), err)
347
+ )));
348
+
349
+ program = build_program_from_source(context, device, program_source);
235
350
 
236
351
  // Prepare dequantize kernels
237
- kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
238
- CL_CHECK(err, "clCreateKernel");
239
- kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
240
- CL_CHECK(err, "clCreateKernel");
241
- kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
242
- CL_CHECK(err, "clCreateKernel");
243
- kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
244
- CL_CHECK(err, "clCreateKernel");
245
- kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
246
- CL_CHECK(err, "clCreateKernel");
247
- kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
248
- CL_CHECK(err, "clCreateKernel");
352
+ CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
353
+ CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
354
+ CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
355
+ CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
356
+ CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
249
357
  }
250
358
 
251
359
  static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -258,9 +366,8 @@ static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags
258
366
  clReleaseMemObject(*buf);
259
367
  }
260
368
  cl_int err;
261
- *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
369
+ CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
262
370
  *cur_size = req_size;
263
- CL_CHECK(err, "clCreateBuffer");
264
371
  }
265
372
 
266
373
  void ggml_cl_sgemm_wrapper(
@@ -269,12 +376,10 @@ void ggml_cl_sgemm_wrapper(
269
376
  const float alpha, const void *host_a, const int lda,
270
377
  const float *host_b, const int ldb, const float beta,
271
378
  float *host_c, const int ldc, const int btype) {
272
- cl_int err = 0;
273
379
 
274
380
  cl_kernel kernel;
275
381
  size_t global = n * k, local, size_qb;
276
382
  bool dequant;
277
- cl_block_q5_0* cl_host_b;
278
383
 
279
384
  switch (btype) {
280
385
  case GGML_TYPE_F32:
@@ -284,36 +389,19 @@ void ggml_cl_sgemm_wrapper(
284
389
  dequant = true;
285
390
  kernel = kernel_q4_0;
286
391
  local = 16;
287
- size_qb = global * (sizeof(float) + local) / 32;
392
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
288
393
  break;
289
394
  case GGML_TYPE_Q4_1:
290
395
  dequant = true;
291
396
  kernel = kernel_q4_1;
292
397
  local = 16;
293
- size_qb = global * (sizeof(float) * 2 + local) / 32;
294
- break;
295
- case GGML_TYPE_Q4_2:
296
- dequant = true;
297
- kernel = kernel_q4_2;
298
- local = 8;
299
- size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
398
+ size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
300
399
  break;
301
400
  case GGML_TYPE_Q5_0:
302
401
  dequant = true;
303
402
  kernel = kernel_q5_0;
304
403
  local = 16;
305
- // For some reason OpenCL seems to be incapable of working with structs of size 22.
306
- // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
307
- // TODO Find the reason, fix and remove workaround.
308
- const block_q5_0* b = (const block_q5_0*) host_b;
309
- cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
310
- for (size_t i = 0; i < global / 32; i++) {
311
- cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
312
- memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
313
- memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
314
- }
315
- host_b = (const float*) cl_host_b;
316
- size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
404
+ size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32;
317
405
  break;
318
406
  case GGML_TYPE_Q5_1:
319
407
  dequant = true;
@@ -325,7 +413,7 @@ void ggml_cl_sgemm_wrapper(
325
413
  dequant = true;
326
414
  kernel = kernel_q8_0;
327
415
  local = 32;
328
- size_qb = global * (sizeof(float) + local) / 32;
416
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
329
417
  break;
330
418
  default:
331
419
  fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
@@ -347,52 +435,40 @@ void ggml_cl_sgemm_wrapper(
347
435
  cl_event ev_a, ev_qb, ev_b;
348
436
 
349
437
  if (dequant) {
350
- err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
351
- err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
352
- CL_CHECK(err, "clSetKernelArg");
353
- err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
354
- CL_CHECK(err, "clEnqueueWriteBuffer qb");
438
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
439
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
440
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
355
441
  } else {
356
- err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
357
- CL_CHECK(err, "clEnqueueWriteBuffer b");
442
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
358
443
  }
359
444
 
360
- err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
361
- CL_CHECK(err, "clEnqueueWriteBuffer a");
445
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
362
446
  if (dequant) {
363
- err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
364
- CL_CHECK(err, "clEnqueueNDRangeKernel");
365
- clReleaseEvent(ev_qb);
447
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
448
+ CL_CHECK(clReleaseEvent(ev_qb));
366
449
  }
367
- clWaitForEvents(1, &ev_a);
368
- clWaitForEvents(1, &ev_b);
369
- clReleaseEvent(ev_a);
370
- clReleaseEvent(ev_b);
450
+ CL_CHECK(clWaitForEvents(1, &ev_a));
451
+ CL_CHECK(clWaitForEvents(1, &ev_b));
452
+ CL_CHECK(clReleaseEvent(ev_a));
453
+ CL_CHECK(clReleaseEvent(ev_b));
371
454
 
372
455
  cl_event ev_sgemm;
373
- CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
374
- (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
375
- m, n, k,
376
- alpha,
377
- cl_buffer_a, 0, lda,
378
- cl_buffer_b, 0, ldb,
379
- beta,
380
- cl_buffer_c, 0, ldc,
381
- &queue, &ev_sgemm);
382
-
383
- if (status != CLBlastSuccess) {
384
- fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
385
- abort();
386
- }
456
+ CLBLAST_CHECK(CLBlastSgemm(
457
+ (CLBlastLayout)order,
458
+ (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
459
+ m, n, k,
460
+ alpha,
461
+ cl_buffer_a, 0, lda,
462
+ cl_buffer_b, 0, ldb,
463
+ beta,
464
+ cl_buffer_c, 0, ldc,
465
+ &queue, &ev_sgemm));
387
466
 
388
467
  cl_event ev_c;
389
- clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
468
+ CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
390
469
 
391
470
  // Wait for completion
392
- clWaitForEvents(1, &ev_c);
393
- clReleaseEvent(ev_sgemm);
394
- clReleaseEvent(ev_c);
395
- if (btype == GGML_TYPE_Q5_0) {
396
- free((void*) cl_host_b);
397
- }
471
+ CL_CHECK(clWaitForEvents(1, &ev_c));
472
+ CL_CHECK(clReleaseEvent(ev_sgemm));
473
+ CL_CHECK(clReleaseEvent(ev_c));
398
474
  }