llama_cpp 0.9.2 → 0.9.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,5 +1,6 @@
1
1
  #import "ggml-metal.h"
2
2
 
3
+ #import "ggml-backend-impl.h"
3
4
  #import "ggml.h"
4
5
 
5
6
  #import <Foundation/Foundation.h>
@@ -23,7 +24,7 @@
23
24
 
24
25
  #define UNUSED(x) (void)(x)
25
26
 
26
- #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
27
+ #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
27
28
 
28
29
  struct ggml_metal_buffer {
29
30
  const char * name;
@@ -85,6 +86,7 @@ struct ggml_metal_context {
85
86
  GGML_METAL_DECL_KERNEL(rms_norm);
86
87
  GGML_METAL_DECL_KERNEL(norm);
87
88
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
88
90
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
89
91
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
90
92
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -113,6 +115,7 @@ struct ggml_metal_context {
113
115
  GGML_METAL_DECL_KERNEL(rope_f32);
114
116
  GGML_METAL_DECL_KERNEL(rope_f16);
115
117
  GGML_METAL_DECL_KERNEL(alibi_f32);
118
+ GGML_METAL_DECL_KERNEL(im2col_f16);
116
119
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
117
120
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
118
121
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -125,7 +128,7 @@ struct ggml_metal_context {
125
128
  // MSL code
126
129
  // TODO: move the contents here when ready
127
130
  // for now it is easier to work in a separate file
128
- static NSString * const msl_library_source = @"see metal.metal";
131
+ //static NSString * const msl_library_source = @"see metal.metal";
129
132
 
130
133
  // Here to assist with NSBundle Path Hack
131
134
  @interface GGMLMetalClass : NSObject
@@ -141,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
141
144
  ggml_metal_log_user_data = user_data;
142
145
  }
143
146
 
144
- static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
147
+ GGML_ATTRIBUTE_FORMAT(2, 3)
148
+ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
145
149
  if (ggml_metal_log_callback != NULL) {
146
150
  va_list args;
147
151
  va_start(args, format);
@@ -209,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
209
213
  } else {
210
214
  GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
211
215
 
212
- NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
216
+ NSString * sourcePath;
217
+ NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
218
+ if (ggmlMetalPathResources) {
219
+ sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
220
+ } else {
221
+ sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
222
+ }
213
223
  if (sourcePath == nil) {
214
224
  GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
215
225
  sourcePath = @"ggml-metal.metal";
@@ -280,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
280
290
  GGML_METAL_ADD_KERNEL(rms_norm);
281
291
  GGML_METAL_ADD_KERNEL(norm);
282
292
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
293
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
283
294
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
284
295
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
285
296
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -310,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
310
321
  GGML_METAL_ADD_KERNEL(rope_f32);
311
322
  GGML_METAL_ADD_KERNEL(rope_f16);
312
323
  GGML_METAL_ADD_KERNEL(alibi_f32);
324
+ GGML_METAL_ADD_KERNEL(im2col_f16);
313
325
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
314
326
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
315
327
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -328,15 +340,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
328
340
  // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
329
341
  for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
330
342
  if ([ctx->device supportsFamily:i]) {
331
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
343
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
332
344
  break;
333
345
  }
334
346
  }
335
347
 
336
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
337
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
348
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
349
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
338
350
  if (ctx->device.maxTransferRate != 0) {
339
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
351
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
340
352
  } else {
341
353
  GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
342
354
  }
@@ -379,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
379
391
  GGML_METAL_DEL_KERNEL(rms_norm);
380
392
  GGML_METAL_DEL_KERNEL(norm);
381
393
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
394
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
382
395
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
383
396
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
384
397
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -409,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
409
422
  GGML_METAL_DEL_KERNEL(rope_f32);
410
423
  GGML_METAL_DEL_KERNEL(rope_f16);
411
424
  GGML_METAL_DEL_KERNEL(alibi_f32);
425
+ GGML_METAL_DEL_KERNEL(im2col_f16);
412
426
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
413
427
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
414
428
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -466,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
466
480
 
467
481
  const int64_t tsize = ggml_nbytes(t);
468
482
 
483
+ if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
484
+ ctx = t->buffer->backend->context;
485
+ }
486
+
469
487
  // find the view that contains the tensor fully
470
488
  for (int i = 0; i < ctx->n_buffers; ++i) {
471
489
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -523,11 +541,11 @@ bool ggml_metal_add_buffer(
523
541
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
524
542
 
525
543
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
526
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
544
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
527
545
  return false;
528
546
  }
529
547
 
530
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
548
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
531
549
 
532
550
  ++ctx->n_buffers;
533
551
  } else {
@@ -547,11 +565,11 @@ bool ggml_metal_add_buffer(
547
565
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
548
566
 
549
567
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
550
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
568
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
551
569
  return false;
552
570
  }
553
571
 
554
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
572
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
555
573
  if (i + size_step < size) {
556
574
  GGML_METAL_LOG_INFO("\n");
557
575
  }
@@ -566,7 +584,7 @@ bool ggml_metal_add_buffer(
566
584
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
567
585
 
568
586
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
569
- GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
587
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
570
588
  } else {
571
589
  GGML_METAL_LOG_INFO("\n");
572
590
  }
@@ -744,6 +762,20 @@ void ggml_metal_graph_compute(
744
762
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
745
763
  struct ggml_tensor * dst = gf->nodes[i];
746
764
 
765
+ switch (dst->op) {
766
+ case GGML_OP_NONE:
767
+ case GGML_OP_RESHAPE:
768
+ case GGML_OP_VIEW:
769
+ case GGML_OP_TRANSPOSE:
770
+ case GGML_OP_PERMUTE:
771
+ {
772
+ // noop -> next node
773
+ } continue;
774
+ default:
775
+ {
776
+ } break;
777
+ }
778
+
747
779
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
748
780
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
749
781
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -797,14 +829,6 @@ void ggml_metal_graph_compute(
797
829
  //}
798
830
 
799
831
  switch (dst->op) {
800
- case GGML_OP_NONE:
801
- case GGML_OP_RESHAPE:
802
- case GGML_OP_VIEW:
803
- case GGML_OP_TRANSPOSE:
804
- case GGML_OP_PERMUTE:
805
- {
806
- // noop
807
- } break;
808
832
  case GGML_OP_CONCAT:
809
833
  {
810
834
  const int64_t nb = ne00;
@@ -1017,7 +1041,7 @@ void ggml_metal_graph_compute(
1017
1041
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1018
1042
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1019
1043
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1020
- [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
1044
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1021
1045
 
1022
1046
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1023
1047
  } break;
@@ -1126,6 +1150,7 @@ void ggml_metal_graph_compute(
1126
1150
  switch (src0t) {
1127
1151
  case GGML_TYPE_F32:
1128
1152
  {
1153
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1129
1154
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1130
1155
  nrows = 4;
1131
1156
  } break;
@@ -1133,13 +1158,18 @@ void ggml_metal_graph_compute(
1133
1158
  {
1134
1159
  nth0 = 32;
1135
1160
  nth1 = 1;
1136
- if (ne11 * ne12 < 4) {
1137
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1138
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1139
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1140
- nrows = ne11;
1161
+ if (src1t == GGML_TYPE_F32) {
1162
+ if (ne11 * ne12 < 4) {
1163
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1164
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1165
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1166
+ nrows = ne11;
1167
+ } else {
1168
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1169
+ nrows = 4;
1170
+ }
1141
1171
  } else {
1142
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1172
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1143
1173
  nrows = 4;
1144
1174
  }
1145
1175
  } break;
@@ -1329,7 +1359,7 @@ void ggml_metal_graph_compute(
1329
1359
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1330
1360
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1331
1361
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1332
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1362
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1333
1363
 
1334
1364
  const int64_t nrows = ggml_nrows(src0);
1335
1365
 
@@ -1348,7 +1378,7 @@ void ggml_metal_graph_compute(
1348
1378
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1349
1379
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1350
1380
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1351
- [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
1381
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1352
1382
 
1353
1383
  const int64_t nrows = ggml_nrows(src0);
1354
1384
 
@@ -1403,8 +1433,7 @@ void ggml_metal_graph_compute(
1403
1433
  const int n_past = ((int32_t *) dst->op_params)[0];
1404
1434
  const int n_dims = ((int32_t *) dst->op_params)[1];
1405
1435
  const int mode = ((int32_t *) dst->op_params)[2];
1406
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1407
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1436
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
1408
1437
 
1409
1438
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1410
1439
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,6 +1481,58 @@ void ggml_metal_graph_compute(
1452
1481
 
1453
1482
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1454
1483
  } break;
1484
+ case GGML_OP_IM2COL:
1485
+ {
1486
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
1487
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
1488
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
1489
+
1490
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
1491
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
1492
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
1493
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
1494
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
1495
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
1496
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1497
+
1498
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
1499
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
1500
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
1501
+ const int32_t IW = src1->ne[0];
1502
+
1503
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
1504
+ const int32_t KW = src0->ne[0];
1505
+
1506
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
1507
+ const int32_t OW = dst->ne[1];
1508
+
1509
+ const int32_t CHW = IC * KH * KW;
1510
+
1511
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
1512
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1513
+
1514
+ switch (src0->type) {
1515
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
1516
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
1517
+ default: GGML_ASSERT(false);
1518
+ };
1519
+
1520
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1521
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1522
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
1523
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
1524
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
1525
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
1526
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
1527
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
1528
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
1529
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
1530
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
1531
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
1532
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
1533
+
1534
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1535
+ } break;
1455
1536
  case GGML_OP_DUP:
1456
1537
  case GGML_OP_CPY:
1457
1538
  case GGML_OP_CONT:
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
792
792
  constant int64_t & ne0,
793
793
  constant int64_t & ne1,
794
794
  uint3 tgpig[[threadgroup_position_in_grid]],
795
- uint tiisg[[thread_index_in_simdgroup]]) {
795
+ uint tiisg[[thread_index_in_simdgroup]]) {
796
796
 
797
797
  const int64_t r0 = tgpig.x;
798
798
  const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
844
844
  }
845
845
  }
846
846
 
847
+ #define N_F16_F16 4
848
+
849
+ kernel void kernel_mul_mv_f16_f16(
850
+ device const char * src0,
851
+ device const char * src1,
852
+ device float * dst,
853
+ constant int64_t & ne00,
854
+ constant int64_t & ne01,
855
+ constant int64_t & ne02,
856
+ constant uint64_t & nb00,
857
+ constant uint64_t & nb01,
858
+ constant uint64_t & nb02,
859
+ constant int64_t & ne10,
860
+ constant int64_t & ne11,
861
+ constant int64_t & ne12,
862
+ constant uint64_t & nb10,
863
+ constant uint64_t & nb11,
864
+ constant uint64_t & nb12,
865
+ constant int64_t & ne0,
866
+ constant int64_t & ne1,
867
+ uint3 tgpig[[threadgroup_position_in_grid]],
868
+ uint tiisg[[thread_index_in_simdgroup]]) {
869
+
870
+ const int64_t r0 = tgpig.x;
871
+ const int64_t rb = tgpig.y*N_F16_F16;
872
+ const int64_t im = tgpig.z;
873
+
874
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
875
+
876
+ if (ne00 < 128) {
877
+ for (int row = 0; row < N_F16_F16; ++row) {
878
+ int r1 = rb + row;
879
+ if (r1 >= ne11) {
880
+ break;
881
+ }
882
+
883
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
884
+
885
+ float sumf = 0;
886
+ for (int i = tiisg; i < ne00; i += 32) {
887
+ sumf += (half) x[i] * (half) y[i];
888
+ }
889
+
890
+ float all_sum = simd_sum(sumf);
891
+ if (tiisg == 0) {
892
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
893
+ }
894
+ }
895
+ } else {
896
+ device const half4 * x4 = (device const half4 *)x;
897
+ for (int row = 0; row < N_F16_F16; ++row) {
898
+ int r1 = rb + row;
899
+ if (r1 >= ne11) {
900
+ break;
901
+ }
902
+
903
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
904
+ device const half4 * y4 = (device const half4 *) y;
905
+
906
+ float sumf = 0;
907
+ for (int i = tiisg; i < ne00/4; i += 32) {
908
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
909
+ }
910
+
911
+ float all_sum = simd_sum(sumf);
912
+ if (tiisg == 0) {
913
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
914
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
915
+ }
916
+ }
917
+ }
918
+ }
919
+
847
920
  kernel void kernel_mul_mv_f16_f32_1row(
848
921
  device const char * src0,
849
922
  device const char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
1229
1302
  template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1230
1303
  template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1231
1304
 
1305
+ kernel void kernel_im2col_f16(
1306
+ device const float * x,
1307
+ device half * dst,
1308
+ constant int32_t & ofs0,
1309
+ constant int32_t & ofs1,
1310
+ constant int32_t & IW,
1311
+ constant int32_t & IH,
1312
+ constant int32_t & CHW,
1313
+ constant int32_t & s0,
1314
+ constant int32_t & s1,
1315
+ constant int32_t & p0,
1316
+ constant int32_t & p1,
1317
+ constant int32_t & d0,
1318
+ constant int32_t & d1,
1319
+ uint3 tgpig[[threadgroup_position_in_grid]],
1320
+ uint3 tgpg[[threadgroups_per_grid]],
1321
+ uint3 tpitg[[thread_position_in_threadgroup]],
1322
+ uint3 ntg[[threads_per_threadgroup]]) {
1323
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
1324
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
1325
+
1326
+ const int32_t offset_dst =
1327
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1328
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1329
+
1330
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1331
+ dst[offset_dst] = 0.0f;
1332
+ } else {
1333
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1334
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
1335
+ }
1336
+ }
1337
+
1232
1338
  kernel void kernel_cpy_f16_f16(
1233
1339
  device const half * src0,
1234
1340
  device half * dst,