llama_cpp 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,87 +10,77 @@
10
10
  #include "ggml.h"
11
11
 
12
12
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
13
- const char * clblast_dequant = MULTILINE_QUOTE(
13
+ static const char * program_source = MULTILINE_QUOTE(
14
14
 
15
+ typedef char int8_t;
15
16
  typedef uchar uint8_t;
16
17
  typedef int int32_t;
17
18
  typedef uint uint32_t;
18
19
 
19
- constant uint QK4_0 = 32;
20
- struct block_q4_0
20
+ struct __attribute__ ((packed)) block_q4_0
21
21
  {
22
- float d;
23
- uint8_t qs[QK4_0 / 2];
22
+ half d;
23
+ uint8_t qs[16]; /* QK4_0 / 2 */
24
24
  };
25
25
 
26
- constant uint QK4_1 = 32;
27
- struct block_q4_1
26
+ struct __attribute__ ((packed)) block_q4_1
28
27
  {
29
- float d;
30
- float m;
31
- uint8_t qs[QK4_1 / 2];
28
+ half d;
29
+ half m;
30
+ uint8_t qs[16]; /* QK4_1 / 2 */
32
31
  };
33
32
 
34
- constant uint QK5_0 = 32;
35
33
  struct __attribute__ ((packed)) block_q5_0
36
34
  {
37
35
  half d;
38
36
  uint32_t qh;
39
- uint8_t qs[QK5_0 / 2];
37
+ uint8_t qs[16]; /* QK5_0 / 2 */
40
38
  };
41
39
 
42
- constant uint QK5_1 = 32;
43
- struct block_q5_1
40
+ struct __attribute__ ((packed)) block_q5_1
44
41
  {
45
42
  half d;
46
43
  half m;
47
44
  uint32_t qh;
48
- uint8_t qs[QK5_1 / 2];
45
+ uint8_t qs[16]; /* QK5_1 / 2 */
49
46
  };
50
47
 
51
- constant uint QK8_0 = 32;
52
- struct block_q8_0
48
+ struct __attribute__ ((packed)) block_q8_0
53
49
  {
54
- float d;
55
- uint8_t qs[QK8_0];
50
+ half d;
51
+ int8_t qs[32]; /* QK8_0 */
56
52
  };
57
53
 
58
54
 
59
55
  __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
60
- constant uint qk = QK4_0;
61
-
62
- const uint i = get_global_id(0) / qk;
56
+ const uint i = get_global_id(0) / 32; /* QK4_0 */
63
57
  const uint j = get_local_id(0);
64
58
 
65
- const float d = x[i].d;
59
+ const float d = vload_half(0, (__global half*) &x[i].d);
66
60
 
67
61
  const int x0 = (x[i].qs[j] & 0xf) - 8;
68
62
  const int x1 = (x[i].qs[j] >> 4) - 8;
69
63
 
70
- y[i*qk + j + 0 ] = x0*d;
71
- y[i*qk + j + qk/2] = x1*d;
64
+ y[i*32 + j + 0 ] = x0*d;
65
+ y[i*32 + j + 16] = x1*d;
72
66
  }
73
67
 
74
68
  __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
75
- constant uint qk = QK4_1;
76
-
77
- const uint i = get_global_id(0) / qk;
69
+ const uint i = get_global_id(0) / 32; /* QK4_1 */
78
70
  const uint j = get_local_id(0);
79
71
 
80
- const float d = x[i].d;
81
- const float m = x[i].m;
72
+ const float d = vload_half(0, (__global half*) &x[i].d);
73
+ const float m = vload_half(0, (__global half*) &x[i].m);
82
74
 
83
75
  const int x0 = (x[i].qs[j] & 0xf);
84
76
  const int x1 = (x[i].qs[j] >> 4);
85
77
 
86
- y[i*qk + j + 0 ] = x0*d + m;
87
- y[i*qk + j + qk/2] = x1*d + m;
78
+ y[i*32 + j + 0 ] = x0*d + m;
79
+ y[i*32 + j + 16] = x1*d + m;
88
80
  }
89
81
 
90
82
  __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
91
- constant uint qk = QK5_0;
92
-
93
- const uint i = get_global_id(0) / qk;
83
+ const uint i = get_global_id(0) / 32; /* QK5_0 */
94
84
  const uint j = get_local_id(0);
95
85
 
96
86
  const float d = vload_half(0, (__global half*) &x[i].d);
@@ -103,14 +93,12 @@ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float*
103
93
  const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
104
94
  const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
105
95
 
106
- y[i*qk + j + 0 ] = x0*d;
107
- y[i*qk + j + qk/2] = x1*d;
96
+ y[i*32 + j + 0 ] = x0*d;
97
+ y[i*32 + j + 16] = x1*d;
108
98
  }
109
99
 
110
100
  __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
111
- constant uint qk = QK5_1;
112
-
113
- const uint i = get_global_id(0) / qk;
101
+ const uint i = get_global_id(0) / 32; /* QK5_1 */
114
102
  const uint j = get_local_id(0);
115
103
 
116
104
  const float d = vload_half(0, (__global half*) &x[i].d);
@@ -124,28 +112,38 @@ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float*
124
112
  const int x0 = (x[i].qs[j] & 0xf) | xh_0;
125
113
  const int x1 = (x[i].qs[j] >> 4) | xh_1;
126
114
 
127
- y[i*qk + j + 0 ] = x0*d + m;
128
- y[i*qk + j + qk/2] = x1*d + m;
115
+ y[i*32 + j + 0 ] = x0*d + m;
116
+ y[i*32 + j + 16] = x1*d + m;
129
117
  }
130
118
 
131
119
  __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
132
- constant uint qk = QK8_0;
133
- const uint i = get_global_id(0) / qk;
120
+ const uint i = get_global_id(0) / 32; /* QK8_0 */
134
121
  const uint j = get_local_id(0);
135
122
 
136
- const float d = x[i].d;
137
- y[i*qk + j] = x[i].qs[j]*d;
123
+ const float d = vload_half(0, (__global half*) &x[i].d);
124
+ y[i*32 + j] = x[i].qs[j]*d;
138
125
  }
139
126
 
140
127
  );
141
128
 
142
- #define CL_CHECK(err, name) \
143
- do { \
144
- cl_int err_ = (err); \
145
- if (err_ != CL_SUCCESS) { \
146
- fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
147
- exit(1); \
148
- } \
129
+ #define CL_CHECK(err) \
130
+ do { \
131
+ cl_int err_ = (err); \
132
+ if (err_ != CL_SUCCESS) { \
133
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
134
+ #err, err_, __FILE__, __LINE__); \
135
+ exit(1); \
136
+ } \
137
+ } while (0)
138
+
139
+ #define CLBLAST_CHECK(err) \
140
+ do { \
141
+ CLBlastStatusCode err_ = (err); \
142
+ if (err_ != CLBlastSuccess) { \
143
+ fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
144
+ #err, err_, __FILE__, __LINE__); \
145
+ exit(1); \
146
+ } \
149
147
  } while (0)
150
148
 
151
149
  static cl_platform_id platform;
@@ -188,48 +186,174 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
188
186
 
189
187
  void ggml_cl_init(void) {
190
188
  cl_int err = 0;
191
- char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
192
- char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
193
- int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
194
- int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
195
- printf("\nInitializing CLBlast (First Run)...");
196
- printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
197
- cl_uint num_platforms;
198
- clGetPlatformIDs(0, NULL, &num_platforms);
199
- cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
200
- clGetPlatformIDs(num_platforms, platforms, NULL);
201
- platform = platforms[plat_num];
202
- char platform_buffer[1024];
203
- clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
204
- cl_uint num_devices;
205
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
206
- cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
207
- clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
208
- device = devices[dev_num];
209
- char device_buffer[1024];
210
- clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
211
- printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
212
- context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
213
- CL_CHECK(err, "clCreateContext");
214
- queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
215
- CL_CHECK(err, "clCreateCommandQueue");
216
-
217
- free(platforms);
218
- free(devices);
219
-
220
- program = build_program_from_source(context, device, clblast_dequant);
189
+
190
+ struct cl_device;
191
+ struct cl_platform {
192
+ cl_platform_id id;
193
+ unsigned number;
194
+ char name[128];
195
+ char vendor[128];
196
+ struct cl_device * devices;
197
+ unsigned n_devices;
198
+ struct cl_device * default_device;
199
+ };
200
+
201
+ struct cl_device {
202
+ struct cl_platform * platform;
203
+ cl_device_id id;
204
+ unsigned number;
205
+ cl_device_type type;
206
+ char name[128];
207
+ };
208
+
209
+ enum { NPLAT = 16, NDEV = 16 };
210
+
211
+ struct cl_platform platforms[NPLAT];
212
+ unsigned n_platforms = 0;
213
+ struct cl_device devices[NDEV];
214
+ unsigned n_devices = 0;
215
+ struct cl_device * default_device = NULL;
216
+
217
+ platform = NULL;
218
+ device = NULL;
219
+
220
+ cl_platform_id platform_ids[NPLAT];
221
+ CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
222
+
223
+ for (unsigned i = 0; i < n_platforms; i++) {
224
+ struct cl_platform * p = &platforms[i];
225
+ p->number = i;
226
+ p->id = platform_ids[i];
227
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
228
+ CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
229
+
230
+ cl_device_id device_ids[NDEV];
231
+ cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
232
+ if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
233
+ p->n_devices = 0;
234
+ } else {
235
+ CL_CHECK(clGetDeviceIDsError);
236
+ }
237
+ p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
238
+ p->default_device = NULL;
239
+
240
+ for (unsigned j = 0; j < p->n_devices; j++) {
241
+ struct cl_device * d = &devices[n_devices];
242
+ d->number = n_devices++;
243
+ d->id = device_ids[j];
244
+ d->platform = p;
245
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
246
+ CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
247
+
248
+ if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
249
+ p->default_device = d;
250
+ }
251
+ }
252
+
253
+ if (default_device == NULL && p->default_device != NULL) {
254
+ default_device = p->default_device;
255
+ }
256
+ }
257
+
258
+ if (n_devices == 0) {
259
+ fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
260
+ exit(1);
261
+ }
262
+
263
+ char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
264
+ char * user_device_string = getenv("GGML_OPENCL_DEVICE");
265
+ int user_platform_number = -1;
266
+ int user_device_number = -1;
267
+
268
+ unsigned n;
269
+ if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
270
+ user_platform_number = (int)n;
271
+ }
272
+ if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
273
+ user_device_number = (int)n;
274
+ }
275
+
276
+ struct cl_device * selected_devices = devices;
277
+ unsigned n_selected_devices = n_devices;
278
+
279
+ if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
280
+ for (unsigned i = 0; i < n_platforms; i++) {
281
+ struct cl_platform * p = &platforms[i];
282
+ if (strstr(p->name, user_platform_string) != NULL ||
283
+ strstr(p->vendor, user_platform_string) != NULL) {
284
+ user_platform_number = (int)i;
285
+ break;
286
+ }
287
+ }
288
+ if (user_platform_number == -1) {
289
+ fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
290
+ exit(1);
291
+ }
292
+ }
293
+ if (user_platform_number != -1) {
294
+ struct cl_platform * p = &platforms[user_platform_number];
295
+ selected_devices = p->devices;
296
+ n_selected_devices = p->n_devices;
297
+ default_device = p->default_device;
298
+ if (n_selected_devices == 0) {
299
+ fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
300
+ exit(1);
301
+ }
302
+ }
303
+
304
+ if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
305
+ for (unsigned i = 0; i < n_selected_devices; i++) {
306
+ struct cl_device * d = &selected_devices[i];
307
+ if (strstr(d->name, user_device_string) != NULL) {
308
+ user_device_number = d->number;
309
+ break;
310
+ }
311
+ }
312
+ if (user_device_number == -1) {
313
+ fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
314
+ exit(1);
315
+ }
316
+ }
317
+ if (user_device_number != -1) {
318
+ selected_devices = &devices[user_device_number];
319
+ n_selected_devices = 1;
320
+ default_device = &selected_devices[0];
321
+ }
322
+
323
+ GGML_ASSERT(n_selected_devices > 0);
324
+
325
+ if (default_device == NULL) {
326
+ default_device = &selected_devices[0];
327
+ }
328
+
329
+ fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
330
+ fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
331
+ if (default_device->type != CL_DEVICE_TYPE_GPU) {
332
+ fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
333
+ }
334
+
335
+ platform = default_device->platform->id;
336
+ device = default_device->id;
337
+
338
+ cl_context_properties properties[] = {
339
+ (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
340
+ };
341
+
342
+ CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
343
+
344
+ CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
345
+ (err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
346
+ (queue = clCreateCommandQueue(context, device, 0, &err), err)
347
+ )));
348
+
349
+ program = build_program_from_source(context, device, program_source);
221
350
 
222
351
  // Prepare dequantize kernels
223
- kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
224
- CL_CHECK(err, "clCreateKernel");
225
- kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
226
- CL_CHECK(err, "clCreateKernel");
227
- kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
228
- CL_CHECK(err, "clCreateKernel");
229
- kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
230
- CL_CHECK(err, "clCreateKernel");
231
- kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
232
- CL_CHECK(err, "clCreateKernel");
352
+ CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
353
+ CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
354
+ CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
355
+ CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
356
+ CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
233
357
  }
234
358
 
235
359
  static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -242,9 +366,8 @@ static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags
242
366
  clReleaseMemObject(*buf);
243
367
  }
244
368
  cl_int err;
245
- *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
369
+ CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
246
370
  *cur_size = req_size;
247
- CL_CHECK(err, "clCreateBuffer");
248
371
  }
249
372
 
250
373
  void ggml_cl_sgemm_wrapper(
@@ -253,7 +376,6 @@ void ggml_cl_sgemm_wrapper(
253
376
  const float alpha, const void *host_a, const int lda,
254
377
  const float *host_b, const int ldb, const float beta,
255
378
  float *host_c, const int ldc, const int btype) {
256
- cl_int err = 0;
257
379
 
258
380
  cl_kernel kernel;
259
381
  size_t global = n * k, local, size_qb;
@@ -267,13 +389,13 @@ void ggml_cl_sgemm_wrapper(
267
389
  dequant = true;
268
390
  kernel = kernel_q4_0;
269
391
  local = 16;
270
- size_qb = global * (sizeof(float) + local) / 32;
392
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
271
393
  break;
272
394
  case GGML_TYPE_Q4_1:
273
395
  dequant = true;
274
396
  kernel = kernel_q4_1;
275
397
  local = 16;
276
- size_qb = global * (sizeof(float) * 2 + local) / 32;
398
+ size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
277
399
  break;
278
400
  case GGML_TYPE_Q5_0:
279
401
  dequant = true;
@@ -291,7 +413,7 @@ void ggml_cl_sgemm_wrapper(
291
413
  dequant = true;
292
414
  kernel = kernel_q8_0;
293
415
  local = 32;
294
- size_qb = global * (sizeof(float) + local) / 32;
416
+ size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
295
417
  break;
296
418
  default:
297
419
  fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
@@ -313,49 +435,40 @@ void ggml_cl_sgemm_wrapper(
313
435
  cl_event ev_a, ev_qb, ev_b;
314
436
 
315
437
  if (dequant) {
316
- err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
317
- err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
318
- CL_CHECK(err, "clSetKernelArg");
319
- err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
320
- CL_CHECK(err, "clEnqueueWriteBuffer qb");
438
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
439
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
440
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
321
441
  } else {
322
- err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
323
- CL_CHECK(err, "clEnqueueWriteBuffer b");
442
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
324
443
  }
325
444
 
326
- err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
327
- CL_CHECK(err, "clEnqueueWriteBuffer a");
445
+ CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
328
446
  if (dequant) {
329
- err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
330
- CL_CHECK(err, "clEnqueueNDRangeKernel");
331
- clReleaseEvent(ev_qb);
447
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
448
+ CL_CHECK(clReleaseEvent(ev_qb));
332
449
  }
333
- clWaitForEvents(1, &ev_a);
334
- clWaitForEvents(1, &ev_b);
335
- clReleaseEvent(ev_a);
336
- clReleaseEvent(ev_b);
450
+ CL_CHECK(clWaitForEvents(1, &ev_a));
451
+ CL_CHECK(clWaitForEvents(1, &ev_b));
452
+ CL_CHECK(clReleaseEvent(ev_a));
453
+ CL_CHECK(clReleaseEvent(ev_b));
337
454
 
338
455
  cl_event ev_sgemm;
339
- CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
340
- (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
341
- m, n, k,
342
- alpha,
343
- cl_buffer_a, 0, lda,
344
- cl_buffer_b, 0, ldb,
345
- beta,
346
- cl_buffer_c, 0, ldc,
347
- &queue, &ev_sgemm);
348
-
349
- if (status != CLBlastSuccess) {
350
- fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
351
- abort();
352
- }
456
+ CLBLAST_CHECK(CLBlastSgemm(
457
+ (CLBlastLayout)order,
458
+ (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
459
+ m, n, k,
460
+ alpha,
461
+ cl_buffer_a, 0, lda,
462
+ cl_buffer_b, 0, ldb,
463
+ beta,
464
+ cl_buffer_c, 0, ldc,
465
+ &queue, &ev_sgemm));
353
466
 
354
467
  cl_event ev_c;
355
- clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
468
+ CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
356
469
 
357
470
  // Wait for completion
358
- clWaitForEvents(1, &ev_c);
359
- clReleaseEvent(ev_sgemm);
360
- clReleaseEvent(ev_c);
471
+ CL_CHECK(clWaitForEvents(1, &ev_c));
472
+ CL_CHECK(clReleaseEvent(ev_sgemm));
473
+ CL_CHECK(clReleaseEvent(ev_c));
361
474
  }