llama_cpp 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,87 +10,77 @@
10
10
  #include "ggml.h"
11
11
 
12
12
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
13
- const char * clblast_dequant = MULTILINE_QUOTE(
13
+ static const char * program_source = MULTILINE_QUOTE(
14
14
 
15
+ typedef char int8_t;
15
16
  typedef uchar uint8_t;
16
17
  typedef int int32_t;
17
18
  typedef uint uint32_t;
18
19
 
19
- constant uint QK4_0 = 32;
20
- struct block_q4_0
20
+ struct __attribute__ ((packed)) block_q4_0
21
21
  {
22
- float d;
23
- uint8_t qs[QK4_0 / 2];
22
+ half d;
23
+ uint8_t qs[16]; /* QK4_0 / 2 */
24
24
  };
25
25
 
26
- constant uint QK4_1 = 32;
27
- struct block_q4_1
26
+ struct __attribute__ ((packed)) block_q4_1
28
27
  {
29
- float d;
30
- float m;
31
- uint8_t qs[QK4_1 / 2];
28
+ half d;
29
+ half m;
30
+ uint8_t qs[16]; /* QK4_1 / 2 */
32
31
  };
33
32
 
34
- constant uint QK5_0 = 32;
35
33
  struct __attribute__ ((packed)) block_q5_0
36
34
  {
37
35
  half d;
38
36
  uint32_t qh;
39
- uint8_t qs[QK5_0 / 2];
37
+ uint8_t qs[16]; /* QK5_0 / 2 */
40
38
  };
41
39
 
42
- constant uint QK5_1 = 32;
43
- struct block_q5_1
40
+ struct __attribute__ ((packed)) block_q5_1
44
41
  {
45
42
  half d;
46
43
  half m;
47
44
  uint32_t qh;
48
- uint8_t qs[QK5_1 / 2];
45
+ uint8_t qs[16]; /* QK5_1 / 2 */
49
46
  };
50
47
 
51
- constant uint QK8_0 = 32;
52
- struct block_q8_0
48
+ struct __attribute__ ((packed)) block_q8_0
53
49
  {
54
- float d;
55
- uint8_t qs[QK8_0];
50
+ half d;
51
+ int8_t qs[32]; /* QK8_0 */
56
52
  };
57
53
 
58
54
 
59
55
  __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
60
- constant uint qk = QK4_0;
61
-
62
- const uint i = get_global_id(0) / qk;
56
+ const uint i = get_global_id(0) / 32; /* QK4_0 */
63
57
  const uint j = get_local_id(0);
64
58
 
65
- const float d = x[i].d;
59
+ const float d = vload_half(0, (__global half*) &x[i].d);
66
60
 
67
61
  const int x0 = (x[i].qs[j] & 0xf) - 8;
68
62
  const int x1 = (x[i].qs[j] >> 4) - 8;
69
63
 
70
- y[i*qk + j + 0 ] = x0*d;
71
- y[i*qk + j + qk/2] = x1*d;
64
+ y[i*32 + j + 0 ] = x0*d;
65
+ y[i*32 + j + 16] = x1*d;
72
66
  }
73
67
 
74
68
  __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
75
- constant uint qk = QK4_1;
76
-
77
- const uint i = get_global_id(0) / qk;
69
+ const uint i = get_global_id(0) / 32; /* QK4_1 */
78
70
  const uint j = get_local_id(0);
79
71
 
80
- const float d = x[i].d;
81
- const float m = x[i].m;
72
+ const float d = vload_half(0, (__global half*) &x[i].d);
73
+ const float m = vload_half(0, (__global half*) &x[i].m);
82
74
 
83
75
  const int x0 = (x[i].qs[j] & 0xf);
84
76
  const int x1 = (x[i].qs[j] >> 4);
85
77
 
86
- y[i*qk + j + 0 ] = x0*d + m;
87
- y[i*qk + j + qk/2] = x1*d + m;
78
+ y[i*32 + j + 0 ] = x0*d + m;
79
+ y[i*32 + j + 16] = x1*d + m;
88
80
  }
89
81
 
90
82
  __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
91
- constant uint qk = QK5_0;
92
-
93
- const uint i = get_global_id(0) / qk;
83
+ const uint i = get_global_id(0) / 32; /* QK5_0 */
94
84
  const uint j = get_local_id(0);
95
85
 
96
86
  const float d = vload_half(0, (__global half*) &x[i].d);
@@ -103,14 +93,12 @@ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float*
103
93
  const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
104
94
  const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
105
95
 
106
- y[i*qk + j + 0 ] = x0*d;
107
- y[i*qk + j + qk/2] = x1*d;
96
+ y[i*32 + j + 0 ] = x0*d;
97
+ y[i*32 + j + 16] = x1*d;
108
98
  }
109
99
 
110
100
  __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
111
- constant uint qk = QK5_1;
112
-
113
- const uint i = get_global_id(0) / qk;
101
+ const uint i = get_global_id(0) / 32; /* QK5_1 */
114
102
  const uint j = get_local_id(0);
115
103
 
116
104
  const float d = vload_half(0, (__global half*) &x[i].d);
@@ -124,28 +112,38 @@ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float*
124
112
  const int x0 = (x[i].qs[j] & 0xf) | xh_0;
125
113
  const int x1 = (x[i].qs[j] >> 4) | xh_1;
126
114
 
127
- y[i*qk + j + 0 ] = x0*d + m;
128
- y[i*qk + j + qk/2] = x1*d + m;
115
+ y[i*32 + j + 0 ] = x0*d + m;
116
+ y[i*32 + j + 16] = x1*d + m;
129
117
  }
130
118
 
131
119
  __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
132
- constant uint qk = QK8_0;
133
- const uint i = get_global_id(0) / qk;
120
+ const uint i = get_global_id(0) / 32; /* QK8_0 */
134
121
  const uint j = get_local_id(0);
135
122
 
136
- const float d = x[i].d;
137
- y[i*qk + j] = x[i].qs[j]*d;
123
+ const float d = vload_half(0, (__global half*) &x[i].d);
124
+ y[i*32 + j] = x[i].qs[j]*d;
138
125
  }
139
126
 
140
127
  );
141
128
 
142
- #define CL_CHECK(err, name) \
143
- do { \
144
- cl_int err_ = (err); \
145
- if (err_ != CL_SUCCESS) { \
146
- fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
147
- exit(1); \
148
- } \
129
+ #define CL_CHECK(err) \
130
+ do { \
131
+ cl_int err_ = (err); \
132
+ if (err_ != CL_SUCCESS) { \
133
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
134
+ #err, err_, __FILE__, __LINE__); \
135
+ exit(1); \
136
+ } \
137
+ } while (0)
138
+
139
+ #define CLBLAST_CHECK(err) \
140
+ do { \
141
+ CLBlastStatusCode err_ = (err); \
142
+ if (err_ != CLBlastSuccess) { \
143
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
144
+ #err, err_, __FILE__, __LINE__); \
145
+ exit(1); \
146
+ } \
149
147
  } while (0)
150
148
 
151
149
  static cl_platform_id platform;
@@ -188,48 +186,174 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
188
186
 
189
187
  void ggml_cl_init(void) {
190
188
  cl_int err = 0;
191
- char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
192
- char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
193
- int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
194
- int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
195
- printf("\nInitializing CLBlast (First Run)...");
196
- printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
197
- cl_uint num_platforms;
198
- clGetPlatformIDs(0, NULL, &num_platforms);
199
- cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
200
- clGetPlatformIDs(num_platforms, platforms, NULL);
201
- platform = platforms[plat_num];
202
- char platform_buffer[1024];
203
- clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
204
- cl_uint num_devices;
205
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
206
- cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
207
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
208
- device = devices[dev_num];
209
- char device_buffer[1024];
210
- clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
211
- printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
212
- context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
213
- CL_CHECK(err, "clCreateContext");
214
- queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
215
- CL_CHECK(err, "clCreateCommandQueue");
216
-
217
- free(platforms);
218
- free(devices);
219
-
220
- program = build_program_from_source(context, device, clblast_dequant);
189
+
190
+ struct cl_device;
191
+ struct cl_platform {
192
+ cl_platform_id id;
193
+ unsigned number;
194
+ char name[128];
195
+ char vendor[128];
196
+ struct cl_device * devices;
197
+ unsigned n_devices;
198
+ struct cl_device * default_device;
199
+ };
200
+
201
+ struct cl_device {
202
+ struct cl_platform * platform;
203
+ cl_device_id id;
204
+ unsigned number;
205
+ cl_device_type type;
206
+ char name[128];
207
+ };
208
+
209
+ enum { NPLAT = 16, NDEV = 16 };
210
+
211
+ struct cl_platform platforms[NPLAT];
212
+ unsigned n_platforms = 0;
213
+ struct cl_device devices[NDEV];
214
+ unsigned n_devices = 0;
215
+ struct cl_device * default_device = NULL;
216
+
217
+ platform = NULL;
218
+ device = NULL;
219
+
220
+ cl_platform_id platform_ids[NPLAT];
221
+ CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
222
+
223
+ for (unsigned i = 0; i < n_platforms; i++) {
224
+ struct cl_platform * p = &platforms[i];
225
+ p->number = i;
226
+ p->id = platform_ids[i];
227
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
228
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
229
+
230
+ cl_device_id device_ids[NDEV];
231
+ cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
232
+ if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
233
+ p->n_devices = 0;
234
+ } else {
235
+ CL_CHECK(clGetDeviceIDsError);
236
+ }
237
+ p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
238
+ p->default_device = NULL;
239
+
240
+ for (unsigned j = 0; j < p->n_devices; j++) {
241
+ struct cl_device * d = &devices[n_devices];
242
+ d->number = n_devices++;
243
+ d->id = device_ids[j];
244
+ d->platform = p;
245
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
246
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
247
+
248
+ if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
249
+ p->default_device = d;
250
+ }
251
+ }
252
+
253
+ if (default_device == NULL && p->default_device != NULL) {
254
+ default_device = p->default_device;
255
+ }
256
+ }
257
+
258
+ if (n_devices == 0) {
259
+ fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
260
+ exit(1);
261
+ }
262
+
263
+ char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
264
+ char * user_device_string = getenv("GGML_OPENCL_DEVICE");
265
+ int user_platform_number = -1;
266
+ int user_device_number = -1;
267
+
268
+ unsigned n;
269
+ if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
270
+ user_platform_number = (int)n;
271
+ }
272
+ if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
273
+ user_device_number = (int)n;
274
+ }
275
+
276
+ struct cl_device * selected_devices = devices;
277
+ unsigned n_selected_devices = n_devices;
278
+
279
+ if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
280
+ for (unsigned i = 0; i < n_platforms; i++) {
281
+ struct cl_platform * p = &platforms[i];
282
+ if (strstr(p->name, user_platform_string) != NULL ||
283
+ strstr(p->vendor, user_platform_string) != NULL) {
284
+ user_platform_number = (int)i;
285
+ break;
286
+ }
287
+ }
288
+ if (user_platform_number == -1) {
289
+ fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
290
+ exit(1);
291
+ }
292
+ }
293
+ if (user_platform_number != -1) {
294
+ struct cl_platform * p = &platforms[user_platform_number];
295
+ selected_devices = p->devices;
296
+ n_selected_devices = p->n_devices;
297
+ default_device = p->default_device;
298
+ if (n_selected_devices == 0) {
299
+ fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
300
+ exit(1);
301
+ }
302
+ }
303
+
304
+ if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
305
+ for (unsigned i = 0; i < n_selected_devices; i++) {
306
+ struct cl_device * d = &selected_devices[i];
307
+ if (strstr(d->name, user_device_string) != NULL) {
308
+ user_device_number = d->number;
309
+ break;
310
+ }
311
+ }
312
+ if (user_device_number == -1) {
313
+ fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
314
+ exit(1);
315
+ }
316
+ }
317
+ if (user_device_number != -1) {
318
+ selected_devices = &devices[user_device_number];
319
+ n_selected_devices = 1;
320
+ default_device = &selected_devices[0];
321
+ }
322
+
323
+ GGML_ASSERT(n_selected_devices > 0);
324
+
325
+ if (default_device == NULL) {
326
+ default_device = &selected_devices[0];
327
+ }
328
+
329
+ fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
330
+ fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
331
+ if (default_device->type != CL_DEVICE_TYPE_GPU) {
332
+ fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
333
+ }
334
+
335
+ platform = default_device->platform->id;
336
+ device = default_device->id;
337
+
338
+ cl_context_properties properties[] = {
339
+ (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
340
+ };
341
+
342
+ CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
343
+
344
+ CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
345
+ (err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
346
+ (queue = clCreateCommandQueue(context, device, 0, &err), err)
347
+ )));
348
+
349
+ program = build_program_from_source(context, device, program_source);
221
350
 
222
351
  // Prepare dequantize kernels
223
- kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
224
- CL_CHECK(err, "clCreateKernel");
225
- kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
226
- CL_CHECK(err, "clCreateKernel");
227
- kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
228
- CL_CHECK(err, "clCreateKernel");
229
- kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
230
- CL_CHECK(err, "clCreateKernel");
231
- kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
232
- CL_CHECK(err, "clCreateKernel");
352
+ CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
353
+ CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
354
+ CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
355
+ CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
356
+ CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
233
357
  }
234
358
 
235
359
  static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -242,9 +366,8 @@ static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags
242
366
  clReleaseMemObject(*buf);
243
367
  }
244
368
  cl_int err;
245
- *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
369
+ CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
246
370
  *cur_size = req_size;
247
- CL_CHECK(err, "clCreateBuffer");
248
371
  }
249
372
 
250
373
  void ggml_cl_sgemm_wrapper(
@@ -253,7 +376,6 @@ void ggml_cl_sgemm_wrapper(
253
376
  const float alpha, const void *host_a, const int lda,
254
377
  const float *host_b, const int ldb, const float beta,
255
378
  float *host_c, const int ldc, const int btype) {
256
- cl_int err = 0;
257
379
 
258
380
  cl_kernel kernel;
259
381
  size_t global = n * k, local, size_qb;
@@ -267,13 +389,13 @@ void ggml_cl_sgemm_wrapper(
267
389
  dequant = true;
268
390
  kernel = kernel_q4_0;
269
391
  local = 16;
270
- size_qb = global * (sizeof(float) + local) / 32;
392
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
271
393
  break;
272
394
  case GGML_TYPE_Q4_1:
273
395
  dequant = true;
274
396
  kernel = kernel_q4_1;
275
397
  local = 16;
276
- size_qb = global * (sizeof(float) * 2 + local) / 32;
398
+ size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
277
399
  break;
278
400
  case GGML_TYPE_Q5_0:
279
401
  dequant = true;
@@ -291,7 +413,7 @@ void ggml_cl_sgemm_wrapper(
291
413
  dequant = true;
292
414
  kernel = kernel_q8_0;
293
415
  local = 32;
294
- size_qb = global * (sizeof(float) + local) / 32;
416
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
295
417
  break;
296
418
  default:
297
419
  fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
@@ -313,49 +435,40 @@ void ggml_cl_sgemm_wrapper(
313
435
  cl_event ev_a, ev_qb, ev_b;
314
436
 
315
437
  if (dequant) {
316
- err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
317
- err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
318
- CL_CHECK(err, "clSetKernelArg");
319
- err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
320
- CL_CHECK(err, "clEnqueueWriteBuffer qb");
438
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
439
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
440
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
321
441
  } else {
322
- err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
323
- CL_CHECK(err, "clEnqueueWriteBuffer b");
442
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
324
443
  }
325
444
 
326
- err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
327
- CL_CHECK(err, "clEnqueueWriteBuffer a");
445
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
328
446
  if (dequant) {
329
- err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
330
- CL_CHECK(err, "clEnqueueNDRangeKernel");
331
- clReleaseEvent(ev_qb);
447
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
448
+ CL_CHECK(clReleaseEvent(ev_qb));
332
449
  }
333
- clWaitForEvents(1, &ev_a);
334
- clWaitForEvents(1, &ev_b);
335
- clReleaseEvent(ev_a);
336
- clReleaseEvent(ev_b);
450
+ CL_CHECK(clWaitForEvents(1, &ev_a));
451
+ CL_CHECK(clWaitForEvents(1, &ev_b));
452
+ CL_CHECK(clReleaseEvent(ev_a));
453
+ CL_CHECK(clReleaseEvent(ev_b));
337
454
 
338
455
  cl_event ev_sgemm;
339
- CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
340
- (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
341
- m, n, k,
342
- alpha,
343
- cl_buffer_a, 0, lda,
344
- cl_buffer_b, 0, ldb,
345
- beta,
346
- cl_buffer_c, 0, ldc,
347
- &queue, &ev_sgemm);
348
-
349
- if (status != CLBlastSuccess) {
350
- fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
351
- abort();
352
- }
456
+ CLBLAST_CHECK(CLBlastSgemm(
457
+ (CLBlastLayout)order,
458
+ (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
459
+ m, n, k,
460
+ alpha,
461
+ cl_buffer_a, 0, lda,
462
+ cl_buffer_b, 0, ldb,
463
+ beta,
464
+ cl_buffer_c, 0, ldc,
465
+ &queue, &ev_sgemm));
353
466
 
354
467
  cl_event ev_c;
355
- clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
468
+ CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
356
469
 
357
470
  // Wait for completion
358
- clWaitForEvents(1, &ev_c);
359
- clReleaseEvent(ev_sgemm);
360
- clReleaseEvent(ev_c);
471
+ CL_CHECK(clWaitForEvents(1, &ev_c));
472
+ CL_CHECK(clReleaseEvent(ev_sgemm));
473
+ CL_CHECK(clReleaseEvent(ev_c));
361
474
  }