llama_cpp 0.2.1 → 0.2.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/examples/README.md +32 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +553 -313
- data/ext/llama_cpp/src/ggml-metal.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.m +157 -19
- data/ext/llama_cpp/src/ggml-metal.metal +149 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +493 -4
- data/ext/llama_cpp/src/ggml.c +736 -98
- data/ext/llama_cpp/src/ggml.h +140 -9
- data/ext/llama_cpp/src/llama.cpp +58 -31
- data/ext/llama_cpp/src/llama.h +8 -9
- data/lib/llama_cpp/version.rb +2 -2
- metadata +3 -2
data/ext/llama_cpp/src/ggml-metal.h

@@ -41,12 +41,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
 // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
 // - the mapping is used during computation to determine the arguments of the compute kernels
 // - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
+// - max_size specifies the maximum size of a tensor and is used to create shared views such
+//   that it is guaranteed that the tensor will fit in at least one of the views
 //
 bool ggml_metal_add_buffer(
          struct ggml_metal_context * ctx,
                        const char * name,
                              void * data,
-                             size_t size);
+                             size_t size,
+                             size_t max_size);

 // set data from host memory into the device
 void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
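The new `max_size` argument is the size of the largest tensor that will live in the mapped region; the backend uses it to size the overlap between buffer views so that every tensor lands entirely inside at least one view. A minimal caller sketch (the wrapper name and the way the sizes are obtained are hypothetical; only the `ggml_metal_add_buffer` call shape comes from the header above):

```c
#include <stdbool.h>
#include <stddef.h>
#include "ggml-metal.h"

// Hypothetical helper: register one host allocation with the Metal backend.
// `data`/`size` describe the whole mapped region, `max_tensor_size` is the
// size in bytes of the largest tensor stored inside it.
static bool map_region(struct ggml_metal_context * ctx,
                       void * data, size_t size, size_t max_tensor_size) {
    // max_size lets ggml_metal_add_buffer split an oversized region into
    // overlapping MTLBuffer views without ever cutting a tensor in half
    return ggml_metal_add_buffer(ctx, "data", data, size, max_tensor_size);
}
```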
data/ext/llama_cpp/src/ggml-metal.m

@@ -57,6 +57,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);

 #undef GGML_METAL_DECL_KERNEL
 };
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,12 +175,22 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);

 #undef GGML_METAL_ADD_KERNEL
     }

+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    if (ctx->device.maxTransferRate != 0) {
+        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+    } else {
+        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+    }
+
     return ctx;
 }

@@ -193,10 +207,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
     //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

+    const int64_t tsize = ggml_nbytes(t);
+
+    // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

-        if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;

             //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
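Since a single logical buffer can now be backed by several overlapping views, the lookup requires the whole tensor, not just its first byte, to fall inside a view. The predicate boils down to this (a plain-C sketch for illustration, not code from the gem):

```c
#include <stdbool.h>
#include <stdint.h>

// A tensor occupying bytes [ioffs, ioffs + tsize) of a mapped region fits a
// view of view_size bytes only if both its first and last byte are inside it.
static bool tensor_fits_view(int64_t ioffs, int64_t tsize, int64_t view_size) {
    return ioffs >= 0 && ioffs + tsize <= view_size;
}
```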
@@ -214,7 +231,8 @@ bool ggml_metal_add_buffer(
          struct ggml_metal_context * ctx,
                        const char * name,
                              void * data,
-                             size_t size) {
+                             size_t size,
+                             size_t max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
         fprintf(stderr, "%s: too many buffers\n", __func__);
         return false;
@@ -231,30 +249,68 @@ bool ggml_metal_add_buffer(
         }
     }

-    size_t page_size = getpagesize();
-    size_t aligned_size = size;
-    if ((aligned_size % page_size) != 0) {
-        aligned_size += (page_size - (aligned_size % page_size));
+    const size_t size_page = getpagesize();
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
     }

-    ctx->buffers[ctx->n_buffers].name = name;
-    ctx->buffers[ctx->n_buffers].data = data;
-    ctx->buffers[ctx->n_buffers].size = size;
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= ctx->device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].name = name;
+        ctx->buffers[ctx->n_buffers].data = data;
+        ctx->buffers[ctx->n_buffers].size = size;

-
-
-
-
-
+        ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+        if (ctx->buffers[ctx->n_buffers].metal == nil) {
+            fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+            return false;
+        }
+
+        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

-        if (ctx->buffers[ctx->n_buffers].metal == nil) {
-            fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
-            return false;
+        ++ctx->n_buffers;
     } else {
-        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
+        const size_t size_view = ctx->device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].name = name;
+            ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+            ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+            ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+
+            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+            if (i + size_step < size) {
+                fprintf(stderr, "\n");
+            }
+
+            ++ctx->n_buffers;
+        }
     }

-        ++ctx->n_buffers;
+    fprintf(stderr, ", (%8.2f / %8.2f)",
+        ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+        ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+    if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
+        fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+    } else {
+        fprintf(stderr, "\n");
+    }
     }

     return true;
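The splitting arithmetic above is what makes `max_size` matter: each view is `maxBufferLength` bytes long, but view starts advance by only `maxBufferLength - size_ovlp` bytes, so consecutive views overlap by a bit more than the largest tensor and no tensor can straddle a view boundary without also fitting fully in the next view. A plain-C sketch of the same arithmetic (the helper and its parameters are illustrative; only the formulas mirror the code above):

```c
#include <stdio.h>
#include <stddef.h>

// Reproduce the view layout computed by ggml_metal_add_buffer for a region
// that is larger than the device's maxBufferLength (max_buf stands in for it).
static void print_view_layout(size_t size, size_t size_aligned,
                              size_t max_size, size_t size_page, size_t max_buf) {
    const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // max_size rounded up ~2 pages
    const size_t size_step = max_buf - size_ovlp; // distance between consecutive view starts
    const size_t size_view = max_buf;             // each view is as large as the device allows

    for (size_t i = 0; i < size; i += size_step) {
        const size_t len = (i + size_view <= size) ? size_view : (size_aligned - i);
        printf("view: offs = %zu, len = %zu\n", i, len);
    }
}
```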
@@ -735,6 +791,70 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    } break;
+                case GGML_OP_NORM:
+                    {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        const float eps = 1e-5f;
+
+                        const int nth = 256;
+
+                        [encoder setComputePipelineState:ctx->pipeline_norm];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                        [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                        const int64_t nrows = ggml_nrows(src0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
+                case GGML_OP_ALIBI:
+                    {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+                        const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+                        const int   n_head   = ((int32_t *) src1->data)[1];
+                        const float max_bias = ((float *)   src1->data)[2];
+
+                        if (__builtin_popcount(n_head) != 1) {
+                            GGML_ASSERT(false && "only power-of-two n_head implemented");
+                        }
+
+                        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                        const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+                        [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                        [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+                        const int nth = 32;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                case GGML_OP_ROPE:
                    {
                        if (encoder == nil) {
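For reference, the `m0` computed in the `GGML_OP_ALIBI` case is the base ALiBi slope; the Metal kernel then raises it to `head + 1` for each attention head. A small C sketch of that per-head slope (the helper is illustrative; it assumes `n_head` is a power of two, which the assert above enforces):

```c
#include <math.h>

// ALiBi slope for one attention head, mirroring the host-side setup above
// combined with the per-head pow() done in kernel_alibi_f32.
static float alibi_slope(int head, int n_head, float max_bias) {
    const int   n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float m0 = powf(2.0f, -max_bias / (float) n_heads_log2_floor);
    return powf(m0, (float) (head + 1)); // the kernel's m_k
}
```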
@@ -788,6 +908,14 @@ void ggml_metal_graph_compute(
                            default: GGML_ASSERT(false && "not implemented");
                        };
                    } break;
+                case GGML_TYPE_F16:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+                    } break;
                default: GGML_ASSERT(false && "not implemented");
                }

@@ -831,4 +959,14 @@ void ggml_metal_graph_compute(
    dispatch_barrier_sync(queue, ^{});

    [command_buffers[n_cb - 1] waitUntilCompleted];
+
+    // check status of command buffers
+    // needed to detect if the device ran out-of-memory for example (#1881)
+    for (int i = 0; i < n_cb; i++) {
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        if (status != MTLCommandBufferStatusCompleted) {
+            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            GGML_ASSERT(false);
+        }
+    }
 }
data/ext/llama_cpp/src/ggml-metal.metal

@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
        (device float *) ((device char *) dst + i*nb1), ne00);
 }

+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
        device const void * src0,
        device       float * dst,
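kernel_norm is a plain (non-RMS) layer normalization of each row: subtract the row mean, then divide by the square root of the variance plus `eps`. The threadgroup does this with a parallel tree reduction, but per row the math reduces to the following CPU sketch (illustrative C, not part of the gem):

```c
#include <math.h>
#include <stdint.h>

// CPU reference for one row of kernel_norm: y = (x - mean(x)) / sqrt(var(x) + eps)
static void norm_row(const float * x, float * y, int64_t ne00, float eps) {
    float mean = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        mean += x[i];
    }
    mean /= (float) ne00;

    float variance = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        y[i] = x[i] - mean;        // recenter
        variance += y[i] * y[i];
    }
    variance /= (float) ne00;

    const float scale = 1.0f / sqrtf(variance + eps);
    for (int64_t i = 0; i < ne00; ++i) {
        y[i] *= scale;
    }
}
```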
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }

+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant     float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
        device const void * src0,
        device       float * dst,
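Each row therefore gets the bias `m_k * (i00 - ne00 + 1)` added to it: zero at the last position and increasingly negative toward earlier positions. A one-row CPU sketch of the same update (illustrative only):

```c
#include <stdint.h>

// CPU sketch of kernel_alibi_f32 for a single row of ne00 scores:
// the last column gets bias 0, earlier columns get m_k * (i - ne00 + 1), which
// is negative for m_k > 0.
static void alibi_row(const float * src, float * dst, int64_t ne00, float m_k) {
    for (int64_t i = 0; i < ne00; ++i) {
        dst[i] = src[i] + m_k * (float) (i - ne00 + 1);
    }
}
```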
@@ -540,6 +648,47 @@ kernel void kernel_rope(
     }
 }

+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device       half  * dst,