llama_cpp 0.2.1 → 0.2.2
This diff shows the changes between publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/examples/README.md +32 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +553 -313
- data/ext/llama_cpp/src/ggml-metal.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.m +157 -19
- data/ext/llama_cpp/src/ggml-metal.metal +149 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +493 -4
- data/ext/llama_cpp/src/ggml.c +736 -98
- data/ext/llama_cpp/src/ggml.h +140 -9
- data/ext/llama_cpp/src/llama.cpp +58 -31
- data/ext/llama_cpp/src/llama.h +8 -9
- data/lib/llama_cpp/version.rb +2 -2
- metadata +3 -2

data/ext/llama_cpp/src/ggml-metal.h

@@ -41,12 +41,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
 // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
 // - the mapping is used during computation to determine the arguments of the compute kernels
 // - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
+// - max_size specifies the maximum size of a tensor and is used to create shared views such
+//   that it is guaranteed that the tensor will fit in at least one of the views
 //
 bool ggml_metal_add_buffer(
         struct ggml_metal_context * ctx,
         const char * name,
         void * data,
-        size_t size);
+        size_t size,
+        size_t max_size);
 
 // set data from host memory into the device
 void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
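
Note: the new max_size argument is the size of the largest tensor that will live in the mapped buffer; when the buffer exceeds the device's maxBufferLength, ggml_metal_add_buffer uses it to size the overlapping views it creates (see the ggml-metal.m hunks below). A minimal caller-side sketch, where ctx_metal, model_data, model_data_size, and max_tensor_size are hypothetical names standing in for whatever the host code actually tracks:

    // hypothetical caller: map a host allocation for the Metal backend;
    // max_tensor_size should be >= ggml_nbytes() of the largest tensor stored in this buffer
    if (!ggml_metal_add_buffer(ctx_metal, "data", model_data, model_data_size, max_tensor_size)) {
        fprintf(stderr, "failed to map the 'data' buffer for Metal\n");
    }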

data/ext/llama_cpp/src/ggml-metal.m

@@ -57,6 +57,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,12 +175,22 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_ADD_KERNEL
     }
 
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    if (ctx->device.maxTransferRate != 0) {
+        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+    } else {
+        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+    }
+
     return ctx;
 }
 
@@ -193,10 +207,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
     //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
+    const int64_t tsize = ggml_nbytes(t);
+
+    // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
-        if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
             //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@@ -214,7 +231,8 @@ bool ggml_metal_add_buffer(
         struct ggml_metal_context * ctx,
         const char * name,
         void * data,
-        size_t size) {
+        size_t size,
+        size_t max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
         fprintf(stderr, "%s: too many buffers\n", __func__);
         return false;
@@ -231,30 +249,68 @@ bool ggml_metal_add_buffer(
         }
     }
 
-        size_t
-
-
-
+        const size_t size_page = getpagesize();
+
+        size_t size_aligned = size;
+        if ((size_aligned % size_page) != 0) {
+            size_aligned += (size_page - (size_aligned % size_page));
         }
 
-
-        ctx->
-
+        // the buffer fits into the max buffer size allowed by the device
+        if (size_aligned <= ctx->device.maxBufferLength) {
+            ctx->buffers[ctx->n_buffers].name = name;
+            ctx->buffers[ctx->n_buffers].data = data;
+            ctx->buffers[ctx->n_buffers].size = size;
 
-
-
-
-
-
+            ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+
+            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
-
-            fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
-            return false;
+            ++ctx->n_buffers;
         } else {
-
+            // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+            // one of the views
+            const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+            const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
+            const size_t size_view = ctx->device.maxBufferLength;
+
+            for (size_t i = 0; i < size; i += size_step) {
+                const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+                ctx->buffers[ctx->n_buffers].name = name;
+                ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+                ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+                ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+
+                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                if (i + size_step < size) {
+                    fprintf(stderr, "\n");
+                }
+
+                ++ctx->n_buffers;
+            }
         }
 
-
+        fprintf(stderr, ", (%8.2f / %8.2f)",
+            ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+            ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+        if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
+            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+        } else {
+            fprintf(stderr, "\n");
+        }
     }
 
     return true;
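
The else-branch above is where the new max_size parameter is used: an allocation larger than maxBufferLength is mapped as several views of length maxBufferLength that advance by size_step and therefore overlap by size_ovlp (max_size rounded up by roughly two pages), so any single tensor is guaranteed to sit entirely inside one view. A standalone sketch with invented sizes, only to illustrate the stepping:

    #include <stdio.h>

    // illustration only: the sizes are invented, not taken from any real Metal device
    int main(void) {
        const size_t max_buf  = 1024;  // stand-in for ctx->device.maxBufferLength
        const size_t page     = 16;    // stand-in for getpagesize()
        const size_t max_size = 100;   // largest tensor in the buffer
        const size_t total    = 3000;  // total host allocation to map

        const size_t size_ovlp = ((max_size + page - 1) / page + 1) * page; // 128
        const size_t size_step = max_buf - size_ovlp;                        // 896

        for (size_t i = 0; i < total; i += size_step) {
            const size_t len = (i + max_buf <= total) ? max_buf : (total - i);
            // consecutive views overlap by size_ovlp bytes, which is >= max_size
            printf("view [%zu, %zu)\n", i, i + len);
        }
        return 0;
    }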
@@ -735,6 +791,70 @@ void ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
+                case GGML_OP_NORM:
+                    {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        const float eps = 1e-5f;
+
+                        const int nth = 256;
+
+                        [encoder setComputePipelineState:ctx->pipeline_norm];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                        [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                        const int64_t nrows = ggml_nrows(src0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
+                case GGML_OP_ALIBI:
+                    {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+                        const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+                        const int n_head = ((int32_t *) src1->data)[1];
+                        const float max_bias = ((float *) src1->data)[2];
+
+                        if (__builtin_popcount(n_head) != 1) {
+                            GGML_ASSERT(false && "only power-of-two n_head implemented");
+                        }
+
+                        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                        const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+                        [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+                        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+                        const int nth = 32;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_ROPE:
                     {
                         if (encoder == nil) {
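
In the GGML_OP_ALIBI case above, only the base slope m0 is sent to the GPU; kernel_alibi_f32 (added in ggml-metal.metal below) rebuilds each head's slope as m_k = pow(m0, i2 + 1) and adds m_k * (i00 - ne00 + 1) to every element. A scalar sketch of that slope schedule, with an invented n_head and max_bias purely for illustration:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_head   = 8;     // invented; the Metal path asserts this is a power of two
        const float max_bias = 8.0f;  // invented; corresponds to ((float *) src1->data)[2] in the diff

        const int   n_heads_log2_floor = 1 << (int) floor(log2(n_head));
        const float m0 = powf(2.0f, -max_bias / n_heads_log2_floor);

        for (int h = 0; h < n_head; ++h) {
            // per-head ALiBi slope, mirroring m_k = pow(m0, i2 + 1) in kernel_alibi_f32
            const float m_k = powf(m0, h + 1);
            printf("head %d: slope = %f\n", h, m_k);
        }
        return 0;
    }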
@@ -788,6 +908,14 @@ void ggml_metal_graph_compute(
                                     default: GGML_ASSERT(false && "not implemented");
                                 };
                             } break;
+                        case GGML_TYPE_F16:
+                            {
+                                switch (dstt) {
+                                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                                    case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                    default: GGML_ASSERT(false && "not implemented");
+                                };
+                            } break;
                         default: GGML_ASSERT(false && "not implemented");
                     }
 
@@ -831,4 +959,14 @@ void ggml_metal_graph_compute(
         dispatch_barrier_sync(queue, ^{});
 
         [command_buffers[n_cb - 1] waitUntilCompleted];
+
+        // check status of command buffers
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        for (int i = 0; i < n_cb; i++) {
+            MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                GGML_ASSERT(false);
+            }
+        }
     }

data/ext/llama_cpp/src/ggml-metal.metal

@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const  void * src0,
         device       float * dst,
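
kernel_norm above is a standard (non-RMS) layer normalization done with two threadgroup reductions: first the row mean, then the variance of the recentered row, and finally a scale by 1/sqrt(variance + eps). A scalar reference of the same per-row math (the helper name norm_row is made up, and this is a simplified sketch rather than ggml's CPU implementation):

    #include <math.h>
    #include <stddef.h>

    // per-row normalization: recenter by the mean, then scale by 1/sqrt(variance + eps)
    static void norm_row(const float * x, float * y, size_t n, float eps) {
        float mean = 0.0f;
        for (size_t i = 0; i < n; ++i) mean += x[i];
        mean /= n;

        float var = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            y[i] = x[i] - mean;   // the Metal kernel also writes the centered row into dst first
            var += y[i] * y[i];
        }
        var /= n;

        const float scale = 1.0f / sqrtf(var + eps);
        for (size_t i = 0; i < n; ++i) y[i] *= scale;
    }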
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant     float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const  void * src0,
         device       float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device  const half * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,