cui-llama.rn 1.3.5 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -20
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +108 -37
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +1982 -1965
  10. package/cpp/common.h +665 -657
  11. package/cpp/ggml-backend-reg.cpp +5 -0
  12. package/cpp/ggml-backend.cpp +5 -2
  13. package/cpp/ggml-cpp.h +1 -0
  14. package/cpp/ggml-cpu-aarch64.cpp +6 -1
  15. package/cpp/ggml-cpu-quants.c +5 -1
  16. package/cpp/ggml-cpu.c +14122 -14122
  17. package/cpp/ggml-cpu.cpp +627 -627
  18. package/cpp/ggml-impl.h +11 -16
  19. package/cpp/ggml-metal-impl.h +288 -0
  20. package/cpp/ggml-metal.m +2 -2
  21. package/cpp/ggml-opt.cpp +854 -0
  22. package/cpp/ggml-opt.h +216 -0
  23. package/cpp/ggml.c +0 -1276
  24. package/cpp/ggml.h +0 -140
  25. package/cpp/gguf.cpp +1325 -0
  26. package/cpp/gguf.h +202 -0
  27. package/cpp/llama-adapter.cpp +346 -0
  28. package/cpp/llama-adapter.h +73 -0
  29. package/cpp/llama-arch.cpp +1434 -0
  30. package/cpp/llama-arch.h +395 -0
  31. package/cpp/llama-batch.cpp +368 -0
  32. package/cpp/llama-batch.h +88 -0
  33. package/cpp/llama-chat.cpp +567 -0
  34. package/cpp/llama-chat.h +51 -0
  35. package/cpp/llama-context.cpp +1771 -0
  36. package/cpp/llama-context.h +128 -0
  37. package/cpp/llama-cparams.cpp +1 -0
  38. package/cpp/llama-cparams.h +37 -0
  39. package/cpp/llama-cpp.h +30 -0
  40. package/cpp/llama-grammar.cpp +1 -0
  41. package/cpp/llama-grammar.h +3 -1
  42. package/cpp/llama-hparams.cpp +71 -0
  43. package/cpp/llama-hparams.h +140 -0
  44. package/cpp/llama-impl.cpp +167 -0
  45. package/cpp/llama-impl.h +16 -136
  46. package/cpp/llama-kv-cache.cpp +718 -0
  47. package/cpp/llama-kv-cache.h +218 -0
  48. package/cpp/llama-mmap.cpp +589 -0
  49. package/cpp/llama-mmap.h +67 -0
  50. package/cpp/llama-model-loader.cpp +1011 -0
  51. package/cpp/llama-model-loader.h +158 -0
  52. package/cpp/llama-model.cpp +2202 -0
  53. package/cpp/llama-model.h +391 -0
  54. package/cpp/llama-sampling.cpp +117 -4
  55. package/cpp/llama-vocab.cpp +21 -28
  56. package/cpp/llama-vocab.h +13 -1
  57. package/cpp/llama.cpp +12547 -23528
  58. package/cpp/llama.h +31 -6
  59. package/cpp/rn-llama.hpp +90 -87
  60. package/cpp/sgemm.cpp +776 -70
  61. package/cpp/sgemm.h +14 -14
  62. package/cpp/unicode.cpp +6 -0
  63. package/ios/RNLlama.mm +47 -0
  64. package/ios/RNLlamaContext.h +3 -1
  65. package/ios/RNLlamaContext.mm +71 -14
  66. package/jest/mock.js +15 -3
  67. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  68. package/lib/commonjs/index.js +33 -37
  69. package/lib/commonjs/index.js.map +1 -1
  70. package/lib/module/NativeRNLlama.js.map +1 -1
  71. package/lib/module/index.js +31 -35
  72. package/lib/module/index.js.map +1 -1
  73. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  74. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  75. package/lib/typescript/index.d.ts +21 -36
  76. package/lib/typescript/index.d.ts.map +1 -1
  77. package/llama-rn.podspec +4 -18
  78. package/package.json +2 -3
  79. package/src/NativeRNLlama.ts +32 -13
  80. package/src/index.ts +52 -47
package/cpp/ggml-impl.h CHANGED
@@ -3,6 +3,8 @@
3
3
  // GGML internal header
4
4
 
5
5
  #include "ggml.h"
6
+ #include "gguf.h"
7
+
6
8
  #include <assert.h>
7
9
  #include <math.h>
8
10
  #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
@@ -551,22 +553,15 @@ static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) {
551
553
  #define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x)
552
554
  #define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x)
553
555
 
554
- // expose GGUF internals for test code
555
-
556
- LM_GGML_API size_t lm_gguf_type_size(enum lm_gguf_type type);
557
-
558
- LM_GGML_API struct lm_gguf_context * lm_gguf_init_from_file_impl(FILE * file, struct lm_gguf_init_params params);
559
-
560
- struct lm_gguf_buf {
561
- void * data;
562
- size_t size;
563
- size_t offset;
564
- };
565
- LM_GGML_API struct lm_gguf_buf lm_gguf_buf_init(size_t size);
566
- LM_GGML_API void lm_gguf_buf_free(struct lm_gguf_buf buf);
567
-
568
- LM_GGML_API void lm_gguf_write_to_buf(const struct lm_gguf_context * ctx, struct lm_gguf_buf * buf, bool only_meta);
569
-
570
556
  #ifdef __cplusplus
571
557
  }
572
558
  #endif
559
+
560
+ #ifdef __cplusplus
561
+ #include <vector>
562
+
563
+ // expose GGUF internals for test code
564
+ LM_GGML_API size_t lm_gguf_type_size(enum lm_gguf_type type);
565
+ LM_GGML_API struct lm_gguf_context * lm_gguf_init_from_file_impl(FILE * file, struct lm_gguf_init_params params);
566
+ LM_GGML_API void lm_gguf_write_to_buf(const struct lm_gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta);
567
+ #endif // __cplusplus
@@ -0,0 +1,288 @@
1
+ #ifndef LM_GGML_METAL_IMPL
2
+ #define LM_GGML_METAL_IMPL
3
+
4
+ // kernel argument structs
5
+ //
6
+ // - element counters (e.g. ne00) typically use int32_t to reduce register usage
7
+ // however, be careful from int overflows when using those in the kernel implementation
8
+ //
9
+ // - strides (e.g. nb00) use uint64_t
10
+
11
+ typedef struct {
12
+ int32_t ne00;
13
+ int32_t ne01;
14
+ int32_t ne02;
15
+ int32_t ne03;
16
+ uint64_t nb00;
17
+ uint64_t nb01;
18
+ uint64_t nb02;
19
+ uint64_t nb03;
20
+ int32_t ne10;
21
+ int32_t ne11;
22
+ int32_t ne12;
23
+ int32_t ne13;
24
+ uint64_t nb10;
25
+ uint64_t nb11;
26
+ uint64_t nb12;
27
+ uint64_t nb13;
28
+ int32_t ne0;
29
+ int32_t ne1;
30
+ int32_t ne2;
31
+ int32_t ne3;
32
+ uint64_t nb0;
33
+ uint64_t nb1;
34
+ uint64_t nb2;
35
+ uint64_t nb3;
36
+ int32_t dim;
37
+ } lm_ggml_metal_kargs_concat;
38
+
39
+ typedef struct {
40
+ int32_t ne00;
41
+ int32_t ne01;
42
+ int32_t ne02;
43
+ int32_t ne03;
44
+ uint64_t nb00;
45
+ uint64_t nb01;
46
+ uint64_t nb02;
47
+ uint64_t nb03;
48
+ int32_t ne10;
49
+ int32_t ne11;
50
+ int32_t ne12;
51
+ int32_t ne13;
52
+ uint64_t nb10;
53
+ uint64_t nb11;
54
+ uint64_t nb12;
55
+ uint64_t nb13;
56
+ int32_t ne0;
57
+ int32_t ne1;
58
+ int32_t ne2;
59
+ int32_t ne3;
60
+ uint64_t nb0;
61
+ uint64_t nb1;
62
+ uint64_t nb2;
63
+ uint64_t nb3;
64
+ uint64_t offs;
65
+ } lm_ggml_metal_kargs_bin;
66
+
67
+ typedef struct {
68
+ int32_t ne00;
69
+ int32_t ne01;
70
+ int32_t ne02;
71
+ int32_t ne03;
72
+ uint64_t nb00;
73
+ uint64_t nb01;
74
+ uint64_t nb02;
75
+ uint64_t nb03;
76
+ int32_t ne0;
77
+ int32_t ne1;
78
+ int32_t ne2;
79
+ int32_t ne3;
80
+ uint64_t nb0;
81
+ uint64_t nb1;
82
+ uint64_t nb2;
83
+ uint64_t nb3;
84
+ } lm_ggml_metal_kargs_repeat;
85
+
86
+ typedef struct {
87
+ int64_t ne00;
88
+ int64_t ne01;
89
+ int64_t ne02;
90
+ int64_t ne03;
91
+ uint64_t nb00;
92
+ uint64_t nb01;
93
+ uint64_t nb02;
94
+ uint64_t nb03;
95
+ int64_t ne0;
96
+ int64_t ne1;
97
+ int64_t ne2;
98
+ int64_t ne3;
99
+ uint64_t nb0;
100
+ uint64_t nb1;
101
+ uint64_t nb2;
102
+ uint64_t nb3;
103
+ } lm_ggml_metal_kargs_cpy;
104
+
105
+ typedef struct {
106
+ int64_t ne10;
107
+ int64_t ne11;
108
+ int64_t ne12;
109
+ uint64_t nb10;
110
+ uint64_t nb11;
111
+ uint64_t nb12;
112
+ uint64_t nb13;
113
+ uint64_t nb1;
114
+ uint64_t nb2;
115
+ uint64_t nb3;
116
+ uint64_t offs;
117
+ bool inplace;
118
+ } lm_ggml_metal_kargs_set;
119
+
120
+ typedef struct {
121
+ int32_t ne00;
122
+ int32_t ne01;
123
+ int32_t ne02;
124
+ int32_t ne03;
125
+ uint64_t nb00;
126
+ uint64_t nb01;
127
+ uint64_t nb02;
128
+ uint64_t nb03;
129
+ int32_t ne0;
130
+ int32_t ne1;
131
+ int32_t ne2;
132
+ int32_t ne3;
133
+ uint64_t nb0;
134
+ uint64_t nb1;
135
+ uint64_t nb2;
136
+ uint64_t nb3;
137
+ int32_t n_past;
138
+ int32_t n_dims;
139
+ int32_t n_ctx_orig;
140
+ float freq_base;
141
+ float freq_scale;
142
+ float ext_factor;
143
+ float attn_factor;
144
+ float beta_fast;
145
+ float beta_slow;
146
+ } lm_ggml_metal_kargs_rope;
147
+
148
+ typedef struct {
149
+ int32_t ne01;
150
+ int32_t ne02;
151
+ int32_t ne03;
152
+ uint64_t nb01;
153
+ uint64_t nb02;
154
+ uint64_t nb03;
155
+ int32_t ne11;
156
+ int32_t ne_12_2; // assume K and V are same shape
157
+ int32_t ne_12_3;
158
+ uint64_t nb_12_1;
159
+ uint64_t nb_12_2;
160
+ uint64_t nb_12_3;
161
+ uint64_t nb31;
162
+ int32_t ne1;
163
+ int32_t ne2;
164
+ float scale;
165
+ float max_bias;
166
+ float m0;
167
+ float m1;
168
+ uint16_t n_head_log2;
169
+ float logit_softcap;
170
+ } lm_ggml_metal_kargs_flash_attn_ext;
171
+
172
+ typedef struct {
173
+ int32_t ne00;
174
+ int32_t ne02;
175
+ uint64_t nb01;
176
+ uint64_t nb02;
177
+ uint64_t nb03;
178
+ int32_t ne12;
179
+ uint64_t nb10;
180
+ uint64_t nb11;
181
+ uint64_t nb12;
182
+ uint64_t nb13;
183
+ int32_t ne0;
184
+ int32_t ne1;
185
+ int16_t r2;
186
+ int16_t r3;
187
+ } lm_ggml_metal_kargs_mul_mm;
188
+
189
+ typedef struct {
190
+ int32_t ne00;
191
+ int32_t ne01;
192
+ int32_t ne02;
193
+ uint64_t nb00;
194
+ uint64_t nb01;
195
+ uint64_t nb02;
196
+ uint64_t nb03;
197
+ int32_t ne10;
198
+ int32_t ne11;
199
+ int32_t ne12;
200
+ uint64_t nb10;
201
+ uint64_t nb11;
202
+ uint64_t nb12;
203
+ uint64_t nb13;
204
+ int32_t ne0;
205
+ int32_t ne1;
206
+ int16_t r2;
207
+ int16_t r3;
208
+ } lm_ggml_metal_kargs_mul_mv;
209
+
210
+ typedef struct {
211
+ int32_t ne00;
212
+ int32_t ne01;
213
+ int32_t ne02;
214
+ uint64_t nb00;
215
+ uint64_t nb01;
216
+ uint64_t nb02;
217
+ uint64_t nb03;
218
+ int32_t ne10;
219
+ int32_t ne11;
220
+ int32_t ne12;
221
+ uint64_t nb10;
222
+ uint64_t nb11;
223
+ uint64_t nb12;
224
+ uint64_t nb13;
225
+ int32_t ne0;
226
+ int32_t ne1;
227
+ int16_t r2;
228
+ int16_t r3;
229
+ int16_t nsg;
230
+ int16_t nxpsg;
231
+ int16_t r1ptg;
232
+ } lm_ggml_metal_kargs_mul_mv_ext;
233
+
234
+ typedef struct {
235
+ int32_t nei0;
236
+ int32_t nei1;
237
+ uint64_t nbi1;
238
+ int32_t ne00;
239
+ int32_t ne02;
240
+ uint64_t nb01;
241
+ uint64_t nb02;
242
+ int32_t ne11;
243
+ int32_t ne12;
244
+ int32_t ne13;
245
+ uint64_t nb10;
246
+ uint64_t nb11;
247
+ uint64_t nb12;
248
+ int32_t ne0;
249
+ int32_t ne1;
250
+ } lm_ggml_metal_kargs_mul_mm_id;
251
+
252
+ typedef struct {
253
+ int32_t nei0;
254
+ int32_t nei1;
255
+ uint64_t nbi1;
256
+ int32_t ne00;
257
+ int32_t ne01;
258
+ int32_t ne02;
259
+ uint64_t nb00;
260
+ uint64_t nb01;
261
+ uint64_t nb02;
262
+ int32_t ne10;
263
+ int32_t ne11;
264
+ int32_t ne12;
265
+ int32_t ne13;
266
+ uint64_t nb10;
267
+ uint64_t nb11;
268
+ uint64_t nb12;
269
+ int32_t ne0;
270
+ int32_t ne1;
271
+ uint64_t nb1;
272
+ } lm_ggml_metal_kargs_mul_mv_id;
273
+
274
+ typedef struct {
275
+ int32_t ne00;
276
+ int32_t ne00_4;
277
+ uint64_t nb01;
278
+ float eps;
279
+ } lm_ggml_metal_kargs_norm;
280
+
281
+ typedef struct {
282
+ int32_t ne00;
283
+ int32_t ne00_4;
284
+ uint64_t nb01;
285
+ float eps;
286
+ } lm_ggml_metal_kargs_rms_norm;
287
+
288
+ #endif // LM_GGML_METAL_IMPL
package/cpp/ggml-metal.m CHANGED
@@ -2067,8 +2067,8 @@ static void lm_ggml_metal_encode_node(
2067
2067
  LM_GGML_ASSERT(ne12 % ne02 == 0);
2068
2068
  LM_GGML_ASSERT(ne13 % ne03 == 0);
2069
2069
 
2070
- const uint r2 = ne12/ne02;
2071
- const uint r3 = ne13/ne03;
2070
+ const uint32_t r2 = ne12/ne02;
2071
+ const uint32_t r3 = ne13/ne03;
2072
2072
 
2073
2073
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2074
2074
  // to the matrix-vector kernel