llama_cpp 0.8.0 → 0.9.1
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +3 -11
- data/ext/llama_cpp/llama_cpp.cpp +228 -165
- data/ext/llama_cpp/src/ggml-cuda.cu +441 -77
- data/ext/llama_cpp/src/ggml-impl.h +237 -0
- data/ext/llama_cpp/src/ggml-metal.m +71 -42
- data/ext/llama_cpp/src/ggml-metal.metal +171 -35
- data/ext/llama_cpp/src/ggml-opencl.cpp +161 -169
- data/ext/llama_cpp/src/{k_quants.c → ggml-quants.c} +3329 -1099
- data/ext/llama_cpp/src/{k_quants.h → ggml-quants.h} +81 -22
- data/ext/llama_cpp/src/ggml.c +1303 -3419
- data/ext/llama_cpp/src/ggml.h +33 -11
- data/ext/llama_cpp/src/llama.cpp +1925 -2655
- data/ext/llama_cpp/src/llama.h +48 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +34 -14
- metadata +5 -4
data/ext/llama_cpp/src/ggml-impl.h (new file)

@@ -0,0 +1,237 @@
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// static_assert should be a #define, but if it's not,
+// fall back to the _Static_assert C11 keyword.
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#endif
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) ((float) (x))
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#ifdef __F16C__
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+extern float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
+// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
+
+// TODO: backend v2 PR
+
+#ifdef __cplusplus
+}
+#endif
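Note: this header only declares the 64 K-entry lookup table used by ggml_lookup_fp16_to_fp32; the table itself is defined in ggml.c and filled once inside ggml_init(). A minimal standalone sketch of that initialization, assuming ggml-impl.h is on the include path (the helper name init_fp16_table is illustrative, not ggml API, and in a real build the table definition would come from ggml.c, not from this sketch):

#include "ggml-impl.h"  // brings in ggml_fp16_t, GGML_COMPUTE_FP16_TO_FP32, <string.h>, <stdint.h> via ggml.h

float ggml_table_f32_f16[1 << 16];  // 65536 entries * 4 bytes = 256 KB, matches the extern above

static void init_fp16_table(void) {
    for (uint32_t i = 0; i < (1u << 16); ++i) {
        uint16_t    u = (uint16_t) i;
        ggml_fp16_t h;
        memcpy(&h, &u, sizeof(u));                        // reinterpret the raw 16 bits as fp16
        ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(h);
    }
}

After this runs, GGML_FP16_TO_FP32 on the lookup path is a single indexed load instead of a bit-manipulating conversion; on NEON and POWER9 the macros above bypass the table entirely.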
data/ext/llama_cpp/src/ggml-metal.m

@@ -62,6 +62,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);

@@ -209,6 +210,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
             NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            if (sourcePath == nil) {
+                GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+                sourcePath = @"ggml-metal.metal";
+            }
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
             NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
             if (error) {

@@ -233,14 +238,17 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // load kernels
     {
         NSError * error = nil;
-#define GGML_METAL_ADD_KERNEL(name) \
-        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+
+        /*
         GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
+        */
+#define GGML_METAL_ADD_KERNEL(name) \
+        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
         if (error) { \
-            GGML_METAL_LOG_ERROR("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 

@@ -249,6 +257,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);

@@ -347,6 +356,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul);
    GGML_METAL_DEL_KERNEL(mul_row);
    GGML_METAL_DEL_KERNEL(scale);
+   GGML_METAL_DEL_KERNEL(scale_4);
    GGML_METAL_DEL_KERNEL(silu);
    GGML_METAL_DEL_KERNEL(relu);
    GGML_METAL_DEL_KERNEL(gelu);

@@ -923,15 +933,20 @@ void ggml_metal_graph_compute(
 
                         const float scale = *(const float *) src1->data;
 
-                        [encoder setComputePipelineState:ctx->pipeline_scale];
+                        int64_t n = ggml_nelements(dst);
+
+                        if (n % 4 == 0) {
+                            n /= 4;
+                            [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                        }
+
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                        const int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_UNARY:
                     switch (ggml_get_unary_op(gf->nodes[i])) {
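Note: the change above removes the hard requirement that the scaled tensor have a multiple-of-4 element count. When it does, the vectorized kernel_scale_4 runs over n/4 threadgroups; otherwise the scalar kernel_scale runs over n threadgroups, each a single thread as in the dispatch call. A small C sketch of that decision (plan_scale_dispatch and the struct are illustrative names, not ggml API):

#include <stdbool.h>
#include <stdint.h>

typedef struct {
    bool    use_scale_4;   // true -> encode the float4 kernel
    int64_t n_groups;      // threadgroups to dispatch (one thread each)
} scale_dispatch;

static scale_dispatch plan_scale_dispatch(int64_t n_elements) {
    scale_dispatch d;
    d.use_scale_4 = (n_elements % 4 == 0);
    d.n_groups    = d.use_scale_4 ? n_elements / 4 : n_elements;
    return d;
}

/* e.g. n_elements = 4096 -> scale_4 over 1024 groups; n_elements = 4097 -> scale over 4097 groups. */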

@@ -986,11 +1001,15 @@ void ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SOFT_MAX:
                     {
-                        const int nth = MIN(32, ne00);
+                        int nth = 32; // SIMD width
 
                         if (ne00%4 == 0) {
                             [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                         } else {
+                            do {
+                                nth *= 2;
+                            } while (nth <= ne00 && nth <= 1024);
+                            nth /= 2;
                             [encoder setComputePipelineState:ctx->pipeline_soft_max];
                         }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -998,8 +1017,9 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                     {
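Note: the scalar soft_max path above now sizes its threadgroup dynamically: starting from the 32-thread SIMD width, it doubles until exceeding either the row length ne00 or Metal's 1024-thread threadgroup limit, then steps back one doubling, and reserves nth/32 floats of threadgroup memory (one slot per 32-thread SIMD group, presumably for partial reductions). A standalone sketch of the same size computation (the function name is illustrative):

#include <stdint.h>

static int soft_max_threads_per_group(int64_t ne00) {
    int nth = 32;               // SIMD width
    do {
        nth *= 2;
    } while (nth <= ne00 && nth <= 1024);
    return nth / 2;             // largest power of two tried that still fit
}

/* ne00 = 100 -> 64 threads; ne00 = 512 -> 512 threads; ne00 = 4096 -> 1024 threads (capped). */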

@@ -1380,14 +1400,18 @@ void ggml_metal_graph_compute(
 
                         const int nth = MIN(1024, ne00);
 
-                        const int n_past
-                        const int n_dims
-                        const int mode
+                        const int n_past     = ((int32_t *) dst->op_params)[0];
+                        const int n_dims     = ((int32_t *) dst->op_params)[1];
+                        const int mode       = ((int32_t *) dst->op_params)[2];
+                        const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 
-                        float freq_base;
-                        float
-                        memcpy(&
-                        memcpy(&
+                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                        memcpy(&freq_base,   (int32_t *) dst->op_params + 5,  sizeof(float));
+                        memcpy(&freq_scale,  (int32_t *) dst->op_params + 6,  sizeof(float));
+                        memcpy(&ext_factor,  (int32_t *) dst->op_params + 7,  sizeof(float));
+                        memcpy(&attn_factor, (int32_t *) dst->op_params + 8,  sizeof(float));
+                        memcpy(&beta_fast,   (int32_t *) dst->op_params + 9,  sizeof(float));
+                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                         switch (src0->type) {
                             case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;

@@ -1395,30 +1419,35 @@ void ggml_metal_graph_compute(
                             default: GGML_ASSERT(false);
                         };
 
-                        [encoder setBuffer:id_src0
-                        [encoder setBuffer:id_src1
-                        [encoder setBuffer:id_dst
-                        [encoder setBytes:&ne00
-                        [encoder setBytes:&ne01
-                        [encoder setBytes:&ne02
-                        [encoder setBytes:&ne03
-                        [encoder setBytes:&nb00
-                        [encoder setBytes:&nb01
-                        [encoder setBytes:&nb02
-                        [encoder setBytes:&nb03
-                        [encoder setBytes:&ne0
-                        [encoder setBytes:&ne1
-                        [encoder setBytes:&ne2
-                        [encoder setBytes:&ne3
-                        [encoder setBytes:&nb0
-                        [encoder setBytes:&nb1
-                        [encoder setBytes:&nb2
-                        [encoder setBytes:&nb3
-                        [encoder setBytes:&n_past
-                        [encoder setBytes:&n_dims
-                        [encoder setBytes:&mode
-                        [encoder setBytes:&
-                        [encoder setBytes:&
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+                        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+                        [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+                        [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+                        [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+                        [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+                        [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+                        [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+                        [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+                        [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+                        [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;