llama_cpp 0.9.2 → 0.9.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ggml-metal.m

@@ -1,5 +1,6 @@
  #import "ggml-metal.h"

+ #import "ggml-backend-impl.h"
  #import "ggml.h"

  #import <Foundation/Foundation.h>
@@ -23,7 +24,7 @@

  #define UNUSED(x) (void)(x)

- #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+ #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)

  struct ggml_metal_buffer {
  const char * name;
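The concurrency table is now sized from the graph's default capacity rather than the removed GGML_MAX_NODES limit. A minimal sketch of the resulting value, assuming ggml.h in this release defines GGML_DEFAULT_GRAPH_SIZE as 2048:

    // assumption: ggml.h defines GGML_DEFAULT_GRAPH_SIZE as 2048
    // GGML_MAX_CONCUR = 2*2048 = 4096 concurrency slots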
@@ -85,6 +86,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -113,6 +115,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
+ GGML_METAL_DECL_KERNEL(im2col_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -125,7 +128,7 @@ struct ggml_metal_context {
  // MSL code
  // TODO: move the contents here when ready
  // for now it is easier to work in a separate file
- static NSString * const msl_library_source = @"see metal.metal";
+ //static NSString * const msl_library_source = @"see metal.metal";

  // Here to assist with NSBundle Path Hack
  @interface GGMLMetalClass : NSObject
@@ -141,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
  ggml_metal_log_user_data = user_data;
  }

- static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+ GGML_ATTRIBUTE_FORMAT(2, 3)
+ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  if (ggml_metal_log_callback != NULL) {
  va_list args;
  va_start(args, format);
@@ -209,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  } else {
  GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);

- NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ NSString * sourcePath;
+ NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+ if (ggmlMetalPathResources) {
+ sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+ } else {
+ sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ }
  if (sourcePath == nil) {
  GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
  sourcePath = @"ggml-metal.metal";
@@ -280,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -310,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
+ GGML_METAL_ADD_KERNEL(im2col_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -328,15 +340,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
  for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
  if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
  break;
  }
  }

- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
  }
@@ -379,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(rms_norm);
  GGML_METAL_DEL_KERNEL(norm);
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -409,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
+ GGML_METAL_DEL_KERNEL(im2col_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -466,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

+ if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+ ctx = t->buffer->backend->context;
+ }
+
  // find the view that contains the tensor fully
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
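Buffers allocated through the ggml-backend API carry a pointer back to their owning backend, so the lookup now prefers that backend's Metal context over the caller-supplied one. A sketch of the fallback, assuming (per the ggml-backend-impl.h header this diff imports) that t->buffer->backend->context is the backend's ggml_metal_context:

    // sketch only: mirrors the added guard above
    static struct ggml_metal_context * resolve_ctx(
            struct ggml_metal_context * ctx, const struct ggml_tensor * t) {
        if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
            return t->buffer->backend->context; // context owned by the backend
        }
        return ctx; // fall back to the context passed in
    }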
@@ -523,11 +541,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -547,11 +565,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
  GGML_METAL_LOG_INFO("\n");
  }
@@ -566,7 +584,7 @@ bool ggml_metal_add_buffer(
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
  } else {
  GGML_METAL_LOG_INFO("\n");
  }
@@ -744,6 +762,20 @@ void ggml_metal_graph_compute(
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
  struct ggml_tensor * dst = gf->nodes[i];

+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop -> next node
+ } continue;
+ default:
+ {
+ } break;
+ }
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
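Metadata-only ops are now skipped at the top of the node loop, before any of the per-node shape bookkeeping runs, instead of falling through to an empty case in the big op switch (which the next hunk deletes). A minimal sketch of why these ops need no kernel, assuming a 2-D tensor a:

    #include "ggml.h"

    void demo(struct ggml_context * ctx0, struct ggml_tensor * a) {
        // GGML_OP_RESHAPE and friends only rewrite ne/nb metadata;
        // the view aliases a->data, so there is nothing to dispatch
        struct ggml_tensor * v = ggml_reshape_2d(ctx0, a, a->ne[1], a->ne[0]);
        (void) v;
    }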
@@ -797,14 +829,6 @@ void ggml_metal_graph_compute(
  //}

  switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop
- } break;
  case GGML_OP_CONCAT:
  {
  const int64_t nb = ne00;
@@ -1017,7 +1041,7 @@
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
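GGML_PAD rounds the requested threadgroup memory up to a 16-byte multiple, whereas MAX(16, ...) only fixed the small case and could still request an unaligned length. A worked check, using GGML_PAD as ggml.h defines it (round x up to a multiple of n):

    #include <assert.h>

    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // as in ggml.h

    int main(void) {
        // nth = 96: 96/32 simdgroups * 4 bytes = 12 -> padded up to 16
        assert(GGML_PAD(96/32*sizeof(float), 16) == 16);
        // nth = 160: 20 bytes -> 32; MAX(16, 20) would have left an
        // unaligned 20-byte request
        assert(GGML_PAD(160/32*sizeof(float), 16) == 32);
        return 0;
    }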
@@ -1126,6 +1150,7 @@
  switch (src0t) {
  case GGML_TYPE_F32:
  {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
  nrows = 4;
  } break;
@@ -1133,13 +1158,18 @@
  {
  nth0 = 32;
  nth1 = 1;
- if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
- nrows = ne11;
+ if (src1t == GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+ nrows = ne11;
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ nrows = 4;
+ }
  } else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
  nrows = 4;
  }
  } break;
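Taken together with the new kernel, the pipeline choice for an F16 src0 now reads (summarizing the branch above):

    src1 == F32, ne11*ne12 < 4                       -> mul_mv_f16_f32_1row
    src1 == F32, ne00 >= 128, ne01 >= 8, ne00%4 == 0 -> mul_mv_f16_f32_l4 (nrows = ne11)
    src1 == F32, otherwise                           -> mul_mv_f16_f32    (nrows = 4)
    src1 == F16                                      -> mul_mv_f16_f16    (nrows = 4)

so an F16 activation tensor no longer has to be converted to F32 before the matrix-vector product.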
@@ -1329,7 +1359,7 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

  const int64_t nrows = ggml_nrows(src0);

@@ -1348,7 +1378,7 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];

  const int64_t nrows = ggml_nrows(src0);

@@ -1403,8 +1433,7 @@ void ggml_metal_graph_compute(
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];

  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,6 +1481,58 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
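The kernel writes one flattened (IC*KH*KW)-wide patch per output position, so OH and OW here are whatever ggml computed when it created the im2col tensor. A sketch of that arithmetic, assuming the usual dilated-convolution output formula:

    #include <stdio.h>

    // output extent for input size in, kernel k, stride s, padding p,
    // dilation d (a sketch of the standard formula, not library code)
    static int conv_out(int in, int k, int s, int p, int d) {
        return (in + 2*p - d*(k - 1) - 1) / s + 1;
    }

    int main(void) {
        // e.g. a 32x32 input with a 3x3 kernel, s0 = s1 = 1, p0 = p1 = 1,
        // d0 = d1 = 1 gives OW = OH = 32
        printf("OW = %d\n", conv_out(32, 3, 1, 1, 1));
        return 0;
    }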
ggml-metal.metal

@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
  constant int64_t & ne0,
  constant int64_t & ne1,
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {

  const int64_t r0 = tgpig.x;
  const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
  }
  }

+ #define N_F16_F16 4
+
+ kernel void kernel_mul_mv_f16_f16(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F16_F16;
+ const int64_t im = tgpig.z;
+
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (half) x[i] * (half) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+ device const half4 * y4 = (device const half4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+ }
+

  kernel void kernel_mul_mv_f16_f32_1row(
  device const char * src0,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
  template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
  template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;

+ kernel void kernel_im2col_f16(
+ device const float * x,
+ device half * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+ const int32_t offset_dst =
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+ }
+
  kernel void kernel_cpy_f16_f16(
  device const half * src0,
  device half * dst,