@fugood/llama.node 1.2.3 → 1.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/package.json +14 -14
  2. package/scripts/llama.cpp.patch +33 -11
  3. package/src/llama.cpp/CMakeLists.txt +1 -0
  4. package/src/llama.cpp/common/CMakeLists.txt +46 -2
  5. package/src/llama.cpp/common/arg.cpp +484 -204
  6. package/src/llama.cpp/common/arg.h +0 -1
  7. package/src/llama.cpp/common/chat-parser.cpp +156 -15
  8. package/src/llama.cpp/common/chat-parser.h +3 -0
  9. package/src/llama.cpp/common/chat.cpp +217 -6
  10. package/src/llama.cpp/common/chat.h +5 -3
  11. package/src/llama.cpp/common/common.cpp +22 -6
  12. package/src/llama.cpp/common/common.h +6 -4
  13. package/src/llama.cpp/common/http.h +73 -0
  14. package/src/llama.cpp/common/json-partial.cpp +51 -0
  15. package/src/llama.cpp/ggml/CMakeLists.txt +7 -6
  16. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  17. package/src/llama.cpp/ggml/include/ggml-rpc.h +8 -9
  18. package/src/llama.cpp/ggml/include/ggml.h +22 -0
  19. package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
  21. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +12 -12
  23. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +100 -3
  24. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
  25. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1 -0
  26. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -0
  27. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +209 -96
  28. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +32 -44
  29. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +107 -83
  30. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +17 -17
  31. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +8 -8
  32. package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +103 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +1 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +66 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +11 -9
  40. package/src/llama.cpp/include/llama.h +8 -0
  41. package/src/llama.cpp/src/llama-arch.cpp +93 -0
  42. package/src/llama.cpp/src/llama-arch.h +22 -0
  43. package/src/llama.cpp/src/llama-chat.cpp +1 -1
  44. package/src/llama.cpp/src/llama-context.cpp +6 -0
  45. package/src/llama.cpp/src/llama-graph.cpp +57 -22
  46. package/src/llama.cpp/src/llama-graph.h +10 -1
  47. package/src/llama.cpp/src/llama-hparams.cpp +5 -1
  48. package/src/llama.cpp/src/llama-hparams.h +17 -2
  49. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +2 -2
  50. package/src/llama.cpp/src/llama-kv-cache.cpp +2 -5
  51. package/src/llama.cpp/src/llama-memory-hybrid.cpp +11 -9
  52. package/src/llama.cpp/src/llama-memory-recurrent.cpp +11 -3
  53. package/src/llama.cpp/src/llama-model-loader.cpp +2 -0
  54. package/src/llama.cpp/src/llama-model.cpp +572 -45
  55. package/src/llama.cpp/src/llama-model.h +18 -0
  56. package/src/llama.cpp/src/llama-sampling.cpp +5 -0
  57. package/src/llama.cpp/src/llama-vocab.cpp +7 -1
  58. package/src/llama.cpp/src/llama-vocab.h +41 -40
  59. package/src/llama.cpp/src/unicode.h +43 -0
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

@@ -8,6 +8,7 @@
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
+#include <string>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
@@ -87,17 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
-template<typename Ret, typename Variant, typename... Args>
-static Ret variant_call(const Variant & var, Args&&... args) {
-    return std::visit([&](auto&& func) -> Ret {
-        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
-            return func(std::forward<Args>(args)...);
-        } else {
-            throw std::runtime_error("Invalid function type in variant_call");
-        }
-    }, var);
-}
-
 namespace ggml::cpu::kleidiai {
 
 static size_t round_down(size_t x, size_t y) {
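Note on the hunk above: the removed variant_call helper dispatched each KleidiAI entry point through std::visit over a std::variant of function-pointer types, so a signature mismatch only surfaced as a runtime throw. The replacement used throughout the hunks below stores a single extended ("_ex") signature as a nullable function pointer that callers test before use. A self-contained sketch of the two shapes, with made-up signatures and struct names (only variant_call itself is quoted from the removed code):

// sketch.cpp - illustrative only; signatures and struct names are assumptions.
#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>

using packed_size_fp16_t = size_t (*)(size_t m, size_t k, size_t mr, size_t kr, size_t sr);
using packed_size_q4_t   = size_t (*)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);

// Old approach: one std::variant member dispatched through std::visit.
// A signature mismatch is only caught at runtime (the throw below).
template <typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args &&... args) {
    return std::visit([&](auto && func) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
            return func(std::forward<Args>(args)...);
        } else {
            throw std::runtime_error("Invalid function type in variant_call");
        }
    }, var);
}

// New approach: a single extended signature stored as a nullable pointer,
// so a missing kernel can be handled with `if (!ptr) return false;`.
struct lhs_packing_info_ex {
    size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) = nullptr;
};

static size_t demo_fp16_packed_size(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
    return (m + mr - 1) / mr * mr * k * sizeof(float) + kr * sr; // placeholder arithmetic
}
static size_t demo_ex_packed_size(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    (void) bl;
    return demo_fp16_packed_size(m, k, mr, kr, sr);
}

int main() {
    std::variant<packed_size_fp16_t, packed_size_q4_t> old_field = &demo_fp16_packed_size;
    const size_t a = variant_call<size_t>(old_field, size_t(8), size_t(64), size_t(4), size_t(4), size_t(1));

    lhs_packing_info_ex info { &demo_ex_packed_size };
    if (!info.packed_size_ex) return 1; // graceful fallback instead of GGML_ASSERT
    const size_t b = info.packed_size_ex(8, 64, 0, 4, 4, 1);

    std::printf("old=%zu new=%zu\n", a, b);
    return a == b ? 0 : 1;
}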
@@ -122,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             return false;
         }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
         bool is_gemv = op->src[1]->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
@@ -136,19 +128,23 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = kernel->get_sr();
 
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
+            if (!lhs_info->packed_size_ex) return false;
+            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
-                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
+            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
+            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
+                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
-            GGML_ASSERT(false);
+            return false;
         }
 
         return true;
     }
 
-
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -165,45 +161,52 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
-        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
-
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
 
-        bool is_gemv = src1->ne[1] == 1;
+        const bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
+        if (!kernels->rhs_info.pack_func_ex ||
+            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
+            return false;
+        }
 
         const int nth = params->nth;
         const int ith = params->ith;
 
         const int64_t lhs_batch_size0 = ne12;
         const int64_t rhs_batch_size0 = ne02;
-        const int64_t batch_size = rhs_batch_size0;
+        const int64_t batch_size = lhs_batch_size0;
 
+        GGML_ASSERT(rhs_batch_size0 > 0);
+        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
         const int64_t r = lhs_batch_size0 / rhs_batch_size0;
 
-        const int64_t m = ne11 * r;
-        const int64_t n = ne01;
-        const int64_t k = ne00;
+        const int64_t m_group = ne11;
+        const int64_t m = m_group;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
 
         const size_t lhs_stride = src1->nb[1];
         const size_t rhs_stride = src0->nb[1];
         const size_t dst_stride = dst->nb[1];
 
-        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
-        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
-        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
-        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+        const int64_t mr = (int64_t) kernel->get_mr();
+        const int64_t nr = (int64_t) kernel->get_nr();
+        const int64_t kr = (int64_t) kernel->get_kr();
+        const int64_t sr = (int64_t) kernel->get_sr();
 
-        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
-        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
+        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
+        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
         const size_t kxn_size = k * n * sizeof(float);
         const size_t bias_size = n * sizeof(float);
 
@@ -216,82 +219,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         uint8_t * bias = rhs_kxn + kxn_size;
 
         for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
-            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
-            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
+            const int64_t rhs_batch_idx = batch_idx / r;
+            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-            // LHS packing
+            // LHS packing (threaded over m, honoring mr alignment and KV groups)
             {
                 const int64_t m_roundup_mr = kai_roundup(m, mr);
                 const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
 
                 if (ith < num_threads) {
-                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                     const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
 
-                    const int64_t m_start = ith * num_m_per_thread0;
-                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    // Base packed offset (aligned) and per-row stride in bytes
+                    const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+                    const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
+                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
 
-                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                    const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
+                    int64_t remaining = m_count;
+                    int64_t cur = m_start;
 
-                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
-                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail = m_group - row_in_group;
+                        const int64_t take = std::min(avail, remaining);
 
-                    variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr = lhs_packed + dst_off;
+
+                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+
+                        cur += take;
+                        remaining -= take;
+                    }
                 }
             }
 
-            // RHS packing
-            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
-                // First thread to reach this point handles RHS packing
-                memset(bias, 0, n * sizeof(float));
-                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
-                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
+            // RHS packing (single thread), then synchronize
+            if (ith == 0) {
+                memset(bias, 0, (size_t)n * sizeof(float));
+                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+                                        reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
+                                        rhs_stride);
 
-                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
+                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
                                    rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }
 
             ggml_barrier(params->threadpool);
 
-            first_to_arrive.clear(std::memory_order_release);
-
-            // Perform the matmul
+            // Matmul (threaded over n)
             {
-                const int64_t m_to_process = m;
-                const int64_t m_start = 0;
-
-                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-                int64_t num_threads = KAI_MIN(n / n_step, nth);
-                if (num_threads <= 0) {
-                    num_threads = 1;
+                const int64_t n_step = (int64_t) kernel->get_n_step();
+                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+                if (num_threads_n <= 0) {
+                    num_threads_n = 1;
                 }
 
-                if (ith < num_threads) {
-                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
-                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+                if (ith < num_threads_n) {
+                    const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
 
                     const int64_t n_start = ith * num_n_per_thread0;
-                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
 
-                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
-                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
-                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+                    // LHS packed base at row 0 (consistent with packing above)
+                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+                    const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+                    const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
 
-                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
 
-                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                 }
             }
 
             if (batch_idx != batch_size - 1) {
-                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
-                // the work data buffer (params->wdata) is used as temporary storage which means that only
-                // a single batch can be processed at any given time. No barrier is needed for the last
-                // batch since GGML inserts a barrier between the execution of every operator.
                 ggml_barrier(params->threadpool);
             }
         }
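For orientation, the rewritten compute_forward_fp16 above now walks the LHS batch dimension, maps each LHS batch to its RHS batch through the broadcast ratio r, and packs the LHS in mr-aligned chunks split across threads. The standalone sketch below reproduces just that index arithmetic with toy numbers and plain loops in place of the ggml threadpool; none of it is the upstream API.

// partition_sketch.cpp - illustrative index arithmetic only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t round_up(int64_t x, int64_t y)   { return ((x + y - 1) / y) * y; }
static int64_t round_down(int64_t x, int64_t y) { return y ? x - (x % y) : x; }

int main() {
    // Broadcast mapping: with 8 LHS batches over 2 RHS batches, r = 4 and
    // LHS batch b reads RHS batch b / r (mirrors rhs_batch_idx = batch_idx / r).
    const int64_t lhs_batches = 8, rhs_batches = 2, r = lhs_batches / rhs_batches;
    for (int64_t b = 0; b < lhs_batches; ++b) {
        std::printf("lhs batch %lld -> rhs batch %lld\n", (long long) b, (long long) (b / r));
    }

    // Per-thread m partitioning: round m up to a multiple of mr, hand out
    // mr-aligned chunks, and let the last thread take whatever rows remain.
    const int64_t m = 37, mr = 8;
    const int64_t nth = 4;
    const int64_t m_roundup_mr = round_up(m, mr);
    const int64_t num_threads  = std::min<int64_t>(m_roundup_mr / mr, nth);
    for (int64_t ith = 0; ith < num_threads; ++ith) {
        const int64_t per_thread0 = round_down(m_roundup_mr / num_threads, mr);
        const int64_t per_threadN = m - (num_threads - 1) * per_thread0;
        const int64_t m_start = ith * per_thread0;
        const int64_t m_count = (ith == num_threads - 1) ? per_threadN : per_thread0;
        std::printf("thread %lld packs rows [%lld, %lld)\n",
                    (long long) ith, (long long) m_start, (long long) (m_start + m_count));
    }
    return 0;
}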
@@ -308,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_TENSOR_BINARY_OP_LOCALS
 
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
 
         bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         GGML_ASSERT(kernel);
+        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+            return false;
+        }
 
         const int ith = params->ith;
         const int nth_raw = params->nth;
@@ -356,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             // Transform LHS
             const size_t src_stride = src1->nb[1];
             const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
             void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
 
-            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
+            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
         }
 
         ggml_barrier(params->threadpool);
 
         // Perform the operation
         const size_t dst_stride = dst->nb[1];
-        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
-        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
         const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
         const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
         const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
         float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
 
         if (n_to_process > 0) {
-            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                sizeof(float), -FLT_MAX, FLT_MAX);
         }
 
@@ -383,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
         GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
+        if (!ctx.kernels) {
+            return false;
+        }
 
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
@@ -392,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
         rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
         kernel_info * kernel = &ctx.kernels->gemm;
+        if (!rhs_info->to_float || !kernel->get_nr) {
+            return false;
+        }
 
         const int64_t nc = ne00;
         const int64_t nr = ggml_nelements(src1);
@@ -434,7 +458,7 @@ public:
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
+        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
 
         return 0;
         GGML_UNUSED(data_size);
@@ -502,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_
     const size_t nr = ctx.kernels->gemm.get_nr();
     const size_t kr = ctx.kernels->gemm.get_kr();
 
-    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
 
     GGML_UNUSED(buft);
 }
package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp

@@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_ASSERT(eps >= 0.0f);
 
-    // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)x[i00];
-                }
-
+                float sum = 0.0;
+                ggml_vec_sum_f32(ne00, &sum, x);
                 float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                float variance = 0;
 
-                ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (ggml_float)(v*v);
-                }
+#ifdef GGML_USE_ACCELERATE
+                mean = -mean;
+                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+                vDSP_measqv(y, 1, &variance, ne00);
+#else
+                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //GGML_USE_ACCELERATE
 
-                float variance = sum2/ne00;
                 const float scale = 1.0f/sqrtf(variance + eps);
-
                 ggml_vec_scale_f32(ne00, y, scale);
             }
         }
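The norm change above replaces the scalar accumulation loops with ggml_vec_sum_f32 plus a centered-variance helper, ggml_vec_cvar_f32 (vec.cpp/vec.h grow by the corresponding lines in this release), and takes a vDSP path under GGML_USE_ACCELERATE. The reference below shows what the fallback computes, matching the scalar loop that was removed; the real helper is vectorized, and anything beyond the call shape visible in the hunk is an assumption.

// cvar_sketch.cpp - scalar reference for the centered-variance step.
#include <cstddef>
#include <cstdio>

// Write centered values x[i] - mean into y and return their mean square,
// i.e. the (biased) variance consumed by the 1/sqrt(variance + eps) scale.
static float vec_cvar_f32_ref(size_t n, float * y, const float * x, float mean) {
    double sum2 = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const float v = x[i] - mean;
        y[i] = v;
        sum2 += (double) v * v;
    }
    return n ? (float) (sum2 / (double) n) : 0.0f;
}

int main() {
    const float x[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    float y[4];
    const float mean = (x[0] + x[1] + x[2] + x[3]) / 4.0f;           // 2.5
    std::printf("variance = %f\n", vec_cvar_f32_ref(4, y, x, mean)); // 1.25
    return 0;
}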
@@ -8135,7 +8131,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         }
 
         // V /= S
-        const float S_inv = 1.0f/S;
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
@@ -8637,7 +8633,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const float dt_soft_plus = ggml_softplus(dt[h]);
                 const float dA = expf(dt_soft_plus * A[h]);
                 const int g = h / (nh / ng); // repeat_interleave
 
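The two ssm_scan hunks fold the inline expression into a ggml_softplus helper. Judging only from the line it replaces, it is the thresholded softplus sketched below; the helper's exact definition elsewhere in the release is an assumption.

// softplus_sketch.cpp - reference matching the removed inline expression.
#include <cmath>
#include <cstdio>

// Thresholded softplus: log1p(exp(x)) for x <= 20, identity above that,
// where exp(x) would overflow float and log1p(exp(x)) == x to float precision.
static inline float softplus_ref(float x) {
    return x <= 20.0f ? log1pf(expf(x)) : x;
}

int main() {
    const float xs[] = { -5.0f, 0.0f, 5.0f, 30.0f };
    for (float x : xs) {
        std::printf("softplus(%.1f) = %.6f\n", x, softplus_ref(x));
    }
    return 0;
}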
@@ -8734,7 +8730,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const float dt_soft_plus = ggml_softplus(dt[h]);
                 const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
@@ -8997,6 +8993,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h

@@ -998,9 +998,9 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
 #define GGML_F32_EPR 4
 
 #define GGML_F32x4 __m128
-#define GGML_F32x4_ZERO __lsx_vldi(0)
-#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
+#define GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
 #define GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
 #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
 #define GGML_F32x4_ADD __lsx_vfadd_s
@@ -1022,7 +1022,7 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
     __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
     tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
     tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+    const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
     tmp = __lsx_vsrli_d((__m128i) t0, 32); \
     tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
     tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
@@ -1052,7 +1052,7 @@ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
     tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]);
     tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]);
 
-    return __lsx_vld(tmp, 0);
+    return (__m128)__lsx_vld(tmp, 0);
 }
 
 static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
@@ -1067,9 +1067,9 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
 }
 
 #define GGML_F32Cx4 __m128
-#define GGML_F32Cx4_ZERO __lsx_vldi(0)
-#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
+#define GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
 #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
 #define GGML_F32Cx4_FMA GGML_F32x4_FMA
 #define GGML_F32Cx4_ADD __lsx_vfadd_s