@fugood/llama.node 1.1.11 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +18 -1
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +166 -396
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +50 -30
  27. package/src/llama.cpp/common/chat.cpp +111 -1
  28. package/src/llama.cpp/common/chat.h +3 -0
  29. package/src/llama.cpp/common/common.h +1 -1
  30. package/src/llama.cpp/common/log.cpp +53 -2
  31. package/src/llama.cpp/common/log.h +10 -4
  32. package/src/llama.cpp/common/sampling.cpp +23 -2
  33. package/src/llama.cpp/common/sampling.h +3 -1
  34. package/src/llama.cpp/common/speculative.cpp +1 -1
  35. package/src/llama.cpp/ggml/CMakeLists.txt +3 -2
  36. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  37. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  38. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  39. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +14 -13
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  41. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +0 -6
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  44. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +4 -9
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +218 -4
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  47. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +41 -37
  48. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +150 -28
  49. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +320 -73
  50. package/src/llama.cpp/include/llama.h +5 -6
  51. package/src/llama.cpp/src/llama-adapter.cpp +33 -0
  52. package/src/llama.cpp/src/llama-adapter.h +3 -0
  53. package/src/llama.cpp/src/llama-arch.cpp +27 -4
  54. package/src/llama.cpp/src/llama-arch.h +2 -0
  55. package/src/llama.cpp/src/llama-context.cpp +62 -56
  56. package/src/llama.cpp/src/llama-context.h +1 -1
  57. package/src/llama.cpp/src/llama-graph.cpp +54 -9
  58. package/src/llama.cpp/src/llama-graph.h +8 -0
  59. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  60. package/src/llama.cpp/src/llama-hparams.h +9 -3
  61. package/src/llama.cpp/src/llama-kv-cache.cpp +1 -23
  62. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  63. package/src/llama.cpp/src/llama-model.cpp +159 -1
  64. package/src/llama.cpp/src/llama-model.h +0 -1
  65. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  66. package/src/anyascii.c +0 -22223
  67. package/src/anyascii.h +0 -42
  68. package/src/tts_utils.cpp +0 -371
  69. package/src/tts_utils.h +0 -103
package/src/llama.cpp/src/llama-context.cpp
@@ -270,19 +270,7 @@ llama_context::llama_context(
  }
  }

- // resolve automatic Flash Attention use and reserve worst-case graph
  if (!hparams.vocab_only) {
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
- LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
- int n_splits_pp = -1;
- int n_nodes_pp = -1;
-
- int n_splits_tg = -1;
- int n_nodes_tg = -1;
-
  llama_memory_context_ptr mctx;
  if (memory) {
  LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -294,54 +282,69 @@ llama_context::llama_context(

  cross.v_embd.clear();

- // reserve pp (prompt processing) graph first so that buffers are only allocated once
- {
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+ // avoid reserving graphs with zero outputs
+ n_outputs = 1;
+
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+ // resolve automatic Flash Attention use
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+ auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
  if (!gf) {
- throw std::runtime_error("failed to allocate compute pp buffers");
+ throw std::runtime_error("failed to split graph for Flash Attention check");
  }

- if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
- bool fa_device_mismatch = false;
- for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
- ggml_tensor * n = ggml_graph_node(gf, i);
- if (n->op != GGML_OP_FLASH_ATTN_EXT) {
- continue;
- }
- ggml_backend_dev_t device_fa = ggml_backend_get_device(
- ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
- // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
- GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
- const int il = std::stoi(n->name + prefix_len);
- ggml_backend_dev_t device_kv = model.dev_layer(il);
- if (device_fa != device_kv) {
- LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
- "is assigned to device %s (usually due to missing support)\n",
- __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
- // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
- fa_device_mismatch = true;
- break;
- }
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+ bool fa_device_mismatch = false;
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+ ggml_tensor * n = ggml_graph_node(gf, i);
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+ continue;
  }
- if (fa_device_mismatch) {
- cparams.flash_attn = false;
- LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
- if (ggml_is_quantized(params.type_v)) {
- throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
- }
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
- if (!gf) {
- throw std::runtime_error("failed to allocate compute pp buffers");
- }
- } else {
- cparams.flash_attn = true;
- LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+ const int il = std::stoi(n->name + prefix_len);
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
+ if (device_fa != device_kv) {
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+ "is assigned to device %s (usually due to missing support)\n",
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+ fa_device_mismatch = true;
+ break;
  }
  }
+ if (fa_device_mismatch) {
+ cparams.flash_attn = false;
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+ if (ggml_is_quantized(params.type_v)) {
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+ }
+ } else {
+ cparams.flash_attn = true;
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+ }
+ }
+
+ // reserve worst-case graph
+ int n_splits_pp = -1;
+ int n_nodes_pp = -1;
+
+ int n_splits_tg = -1;
+ int n_nodes_tg = -1;
+
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
+ {
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }

  n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
  n_nodes_pp = ggml_graph_n_nodes(gf);
@@ -1366,8 +1369,9 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
  return static_cast<llm_graph_result *>(gf_res_reserve.get());
  }

- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+ GGML_ASSERT(n_outputs >= 1);

  if (n_tokens % n_seqs != 0) {
  n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
@@ -1401,7 +1405,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  this->n_outputs = save_n_outputs;

  // initialize scheduler with the specified graph
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ if (split_only) {
+ ggml_backend_sched_split_graph(sched.get(), gf);
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
  LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
  return nullptr;
  }
package/src/llama.cpp/src/llama-context.h
@@ -196,7 +196,7 @@ public:
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);

  // reserve a graph with a dummy ubatch of the specified size
- ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+ ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);

  private:
  llm_graph_params graph_params(
package/src/llama.cpp/src/llama-graph.cpp
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  }
  }

+ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+ const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
+ (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
+ (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
+ (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+ LLAMA_LOG_DEBUG(" ");
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ LLAMA_LOG_DEBUG("%2d", j);
+ }
+ LLAMA_LOG_DEBUG("\n");
+
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+ LLAMA_LOG_DEBUG(" %2d ", i);
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ float val = data[i * n_kv + j];
+ if (val == -INFINITY) {
+ LLAMA_LOG_DEBUG(" ∞");
+ } else {
+ LLAMA_LOG_DEBUG(" 0");
+ }
+ }
+ LLAMA_LOG_DEBUG("\n");
+ }
+ }
+
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  const int64_t n_kv = ubatch->n_tokens;
  const int64_t n_tokens = ubatch->n_tokens;
@@ -267,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

  float * data = (float *) kq_mask->data;

+ // [TAG_NO_CACHE_ISWA]
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
  for (int h = 0; h < 1; ++h) {
  for (int i1 = 0; i1 < n_tokens; ++i1) {
  const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -277,21 +310,33 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
  const llama_seq_id s0 = ubatch->seq_id[i0][0];

- // TODO: reimplement this like in llama_kv_cache
- if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
- }
- break;
+ if (s0 != s1) {
+ continue; // skip different sequences
  }
- }

+ if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+ continue; // skip future tokens for causal attention
+ }
+
+ // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+ //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+ // continue; // skip masked tokens for SWA
+ //}
+
+ // TODO: reimplement this like in llama_kv_cache_unified
+ if (hparams.use_alibi) {
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+ } else {
+ f = 0.0f;
+ }
+ }
  data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
  }
  }
  }
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
  }

  void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
package/src/llama.cpp/src/llama-graph.h
@@ -78,6 +78,11 @@ struct llm_graph_params;

  class llm_graph_input_i {
  public:
+ llm_graph_input_i() {
+ const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+ debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+ }
+
  virtual ~llm_graph_input_i() = default;

  virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
  GGML_UNUSED(params);
  return false;
  }
+ protected:
+ // env: LLAMA_GRAPH_INPUT_DEBUG
+ int debug = 0;
  };

  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
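Note: the `debug` member added to `llm_graph_input_i` above is driven entirely by the `LLAMA_GRAPH_INPUT_DEBUG` environment variable; when it is set to a non-zero integer, `llm_graph_input_attn_no_cache::set_input` (see the llama-graph.cpp hunks above) dumps the attention mask through `print_mask`. Below is a minimal standalone sketch of the same getenv-based toggle; the struct and variable names are illustrative, only the env-var name `LLAMA_GRAPH_INPUT_DEBUG` comes from the diff.

#include <cstdio>
#include <cstdlib>

// illustrative stand-in for the constructor logic added to llm_graph_input_i
struct graph_input_debug_example {
    int debug = 0;

    graph_input_debug_example() {
        // unset or "0" -> debugging off, any other integer -> on
        const char * env = std::getenv("LLAMA_GRAPH_INPUT_DEBUG");
        debug = env ? std::atoi(env) : 0;
    }
};

int main() {
    graph_input_debug_example in;
    std::printf("LLAMA_GRAPH_INPUT_DEBUG -> debug = %d\n", in.debug);
    return 0;
}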
package/src/llama.cpp/src/llama-hparams.cpp
@@ -1,6 +1,7 @@
  #include "llama-hparams.h"

  #include "ggml.h"
+ #include <cassert>

  void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
  if (dense_first) {
@@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const {

  return res;
  }
+
+ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+ assert(p0 >= 0 && p1 >= 0);
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE:
+ {
+ } break;
+ case LLAMA_SWA_TYPE_STANDARD:
+ {
+ if (p1 - p0 >= (int32_t) n_swa) {
+ return true;
+ }
+ } break;
+ case LLAMA_SWA_TYPE_CHUNKED:
+ {
+ const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+ if (p0 < pos_chunk_start) {
+ return true;
+ }
+ } break;
+ case LLAMA_SWA_TYPE_SYMMETRIC:
+ {
+ const int32_t half_n_swa = (int32_t) n_swa / 2;
+ const int32_t pos_diff = p1 - p0;
+
+ // Mask if outside the symmetric window
+ if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+ return true;
+ }
+ } break;
+ }
+
+ return false;
+ }
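Note: the `LLAMA_SWA_TYPE_SYMMETRIC` case added above masks a key position once it is more than `n_swa / 2` positions away from the query position in either direction, whereas `LLAMA_SWA_TYPE_STANDARD` only masks the distant past. A small self-contained sketch of that rule follows; the enum and harness are local to the example, not the library API.

#include <cstdint>
#include <cstdio>

// local stand-in for the two window types compared here
enum swa_type_example { SWA_STANDARD, SWA_SYMMETRIC };

// p0 = key/value position, p1 = query position (same convention as is_masked_swa above)
static bool is_masked_example(swa_type_example type, uint32_t n_swa, int32_t p0, int32_t p1) {
    const int32_t pos_diff = p1 - p0;
    if (type == SWA_STANDARD) {
        return pos_diff >= (int32_t) n_swa;                 // only the distant past is masked
    }
    const int32_t half_n_swa = (int32_t) n_swa / 2;
    return pos_diff < -half_n_swa || pos_diff > half_n_swa; // past and future beyond half the window
}

int main() {
    const uint32_t n_swa = 6;
    // for query position 10: STANDARD masks keys at position 4 and earlier,
    // SYMMETRIC masks everything outside [7, 13]
    for (int32_t p0 = 2; p0 <= 14; ++p0) {
        std::printf("p0=%2d standard=%d symmetric=%d\n", p0,
                    (int) is_masked_example(SWA_STANDARD,  n_swa, p0, 10),
                    (int) is_masked_example(SWA_SYMMETRIC, n_swa, p0, 10));
    }
    return 0;
}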
package/src/llama.cpp/src/llama-hparams.h
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
  };

  enum llama_swa_type {
- LLAMA_SWA_TYPE_NONE     = 0,
- LLAMA_SWA_TYPE_STANDARD = 1,
- LLAMA_SWA_TYPE_CHUNKED  = 2,
+ LLAMA_SWA_TYPE_NONE      = 0,
+ LLAMA_SWA_TYPE_STANDARD  = 1,
+ LLAMA_SWA_TYPE_CHUNKED   = 2,
+ LLAMA_SWA_TYPE_SYMMETRIC = 3,
  };

  struct llama_hparams_posnet {
@@ -227,6 +228,11 @@ struct llama_hparams {

  // number of layers for which has_kv() returns true
  uint32_t n_layer_kv() const;
+
+ // note that this function uses different SWA parameters from those in the hparams
+ // TODO: think of a better place for this function
+ // TODO: pack the SWA params in a struct?
+ static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
  };

  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
package/src/llama.cpp/src/llama-kv-cache.cpp
@@ -1393,29 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
  }

  bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
- assert(p0 >= 0 && p1 >= 0);
-
- switch (swa_type) {
- case LLAMA_SWA_TYPE_NONE:
- {
- } break;
- case LLAMA_SWA_TYPE_STANDARD:
- {
- if (p1 - p0 >= (int32_t) n_swa) {
- return true;
- }
- } break;
- case LLAMA_SWA_TYPE_CHUNKED:
- {
- const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
- if (p0 < pos_chunk_start) {
- return true;
- }
- } break;
- }
-
- return false;
+ return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
  }

  void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
package/src/llama.cpp/src/llama-kv-cache.h
@@ -212,6 +212,7 @@ private:
  // env: LLAMA_KV_CACHE_DEBUG
  int debug = 0;

+ // this is the SWA type of the cache - not to be confused with the model SWA type
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

  std::vector<ggml_context_ptr> ctxs;
package/src/llama.cpp/src/llama-model.cpp
@@ -1110,7 +1110,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

  switch (hparams.n_layer) {
- case 18: type = LLM_TYPE_537M; break;
+ case 18: type = LLM_TYPE_270M; break;
  case 26: type = LLM_TYPE_1B; break;
  case 34: type = LLM_TYPE_4B; break;
  case 48: type = LLM_TYPE_12B; break;
@@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  default: type = LLM_TYPE_UNKNOWN;
  }
  } break;
+ case LLM_ARCH_GEMMA_EMBEDDING:
+ {
+ hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+ hparams.set_swa_pattern(6);
+
+ hparams.causal_attn = false; // embeddings do not use causal attention
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+ switch (hparams.n_layer) {
+ case 24: type = LLM_TYPE_0_3B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+
+ } break;
  case LLM_ARCH_STARCODER2:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  }
  } break;
  case LLM_ARCH_GEMMA3:
+ case LLM_ARCH_GEMMA_EMBEDDING:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -11045,6 +11066,137 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
  }
  };

+ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+ llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k;
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+ auto * inp_attn = build_attn_inp_kv_iswa();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+ // norm
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = build_norm(sa_out,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+ };
+
  // TODO: move up next to build_starcoder
  struct llm_build_starcoder2 : public llm_graph_context {
  llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -18481,6 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
  case LLM_ARCH_NOMIC_BERT_MOE:
  case LLM_ARCH_NEO_BERT:
  case LLM_ARCH_WAVTOKENIZER_DEC:
+ //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
  case LLM_ARCH_DREAM:
  case LLM_ARCH_LLADA:
  {
@@ -18761,6 +18914,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
  {
  llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
  } break;
+ case LLM_ARCH_GEMMA_EMBEDDING:
+ {
+ llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+ } break;
  case LLM_ARCH_STARCODER2:
  {
  llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -19161,6 +19318,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
  case LLM_ARCH_GEMMA2:
  case LLM_ARCH_GEMMA3:
  case LLM_ARCH_GEMMA3N:
+ case LLM_ARCH_GEMMA_EMBEDDING:
  case LLM_ARCH_STARCODER2:
  case LLM_ARCH_OPENELM:
  case LLM_ARCH_GPTNEOX:
package/src/llama.cpp/src/llama-model.h
@@ -39,7 +39,6 @@ enum llm_type {
  LLM_TYPE_410M,
  LLM_TYPE_450M,
  LLM_TYPE_475M,
- LLM_TYPE_537M,
  LLM_TYPE_558M,
  LLM_TYPE_700M,
  LLM_TYPE_770M,