@fugood/llama.node 1.2.5 → 1.3.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,7 @@
  {
  "name": "@fugood/llama.node",
  "access": "public",
- "version": "1.2.5",
+ "version": "1.3.0-rc.0",
  "description": "An another Node binding of llama.cpp",
  "main": "lib/index.js",
  "scripts": {
@@ -72,19 +72,19 @@
  "CMakeLists.txt"
  ],
  "optionalDependencies": {
- "@fugood/node-llama-linux-x64": "1.2.5",
- "@fugood/node-llama-linux-x64-vulkan": "1.2.5",
- "@fugood/node-llama-linux-x64-cuda": "1.2.5",
- "@fugood/node-llama-linux-arm64": "1.2.5",
- "@fugood/node-llama-linux-arm64-vulkan": "1.2.5",
- "@fugood/node-llama-linux-arm64-cuda": "1.2.5",
- "@fugood/node-llama-win32-x64": "1.2.5",
- "@fugood/node-llama-win32-x64-vulkan": "1.2.5",
- "@fugood/node-llama-win32-x64-cuda": "1.2.5",
- "@fugood/node-llama-win32-arm64": "1.2.5",
- "@fugood/node-llama-win32-arm64-vulkan": "1.2.5",
- "@fugood/node-llama-darwin-x64": "1.2.5",
- "@fugood/node-llama-darwin-arm64": "1.2.5"
+ "@fugood/node-llama-linux-x64": "1.3.0-rc.0",
+ "@fugood/node-llama-linux-x64-vulkan": "1.3.0-rc.0",
+ "@fugood/node-llama-linux-x64-cuda": "1.3.0-rc.0",
+ "@fugood/node-llama-linux-arm64": "1.3.0-rc.0",
+ "@fugood/node-llama-linux-arm64-vulkan": "1.3.0-rc.0",
+ "@fugood/node-llama-linux-arm64-cuda": "1.3.0-rc.0",
+ "@fugood/node-llama-win32-x64": "1.3.0-rc.0",
+ "@fugood/node-llama-win32-x64-vulkan": "1.3.0-rc.0",
+ "@fugood/node-llama-win32-x64-cuda": "1.3.0-rc.0",
+ "@fugood/node-llama-win32-arm64": "1.3.0-rc.0",
+ "@fugood/node-llama-win32-arm64-vulkan": "1.3.0-rc.0",
+ "@fugood/node-llama-darwin-x64": "1.3.0-rc.0",
+ "@fugood/node-llama-darwin-arm64": "1.3.0-rc.0"
  },
  "devDependencies": {
  "@babel/preset-env": "^7.24.4",
@@ -168,6 +168,25 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  static_cast<napi_property_attributes>(napi_enumerable)),
  InstanceMethod<&LlamaContext::DecodeAudioTokens>(
  "decodeAudioTokens",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ // Parallel decoding methods
+ InstanceMethod<&LlamaContext::EnableParallelMode>(
+ "enableParallelMode",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::DisableParallelMode>(
+ "disableParallelMode",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::QueueCompletion>(
+ "queueCompletion",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::QueueEmbedding>(
+ "queueEmbedding",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::QueueRerank>(
+ "queueRerank",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::CancelRequest>(
+ "cancelRequest",
  static_cast<napi_property_attributes>(napi_enumerable))});
  Napi::FunctionReference *constructor = new Napi::FunctionReference();
  *constructor = Napi::Persistent(func);
@@ -217,6 +236,7 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
  params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
  params.n_ubatch = get_option<int32_t>(options, "n_ubatch", 512);
+ params.n_parallel = get_option<int32_t>(options, "n_parallel", 1); // Default to 1 for compatibility
  params.embedding = get_option<bool>(options, "embedding", false);
  if (params.embedding) {
  // For non-causal models, batch size must be equal to ubatch size
@@ -288,6 +308,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  }
  }
  }
+ // Initialize validity flag for async callback safety
+ _context_valid = std::make_shared<std::atomic<bool>>(true);
+
  // Use rn-llama context instead of direct session
  _rn_ctx = new llama_rn_context();
  if (!_rn_ctx->loadModel(params)) {
@@ -305,6 +328,11 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  }

  LlamaContext::~LlamaContext() {
+ // Invalidate the context to prevent use-after-free in async callbacks
+ if (_context_valid) {
+ _context_valid->store(false);
+ }
+
  // The DisposeWorker is responsible for cleanup of _rn_ctx
  // If _rn_ctx is still not null here, it means disposal was not properly initiated
  if (_rn_ctx) {
@@ -579,7 +607,7 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  // grammar: string
  result.Set("grammar", chatParams.grammar);
  // grammar_lazy: boolean
- result.Set("grammea_lazy", chatParams.grammar_lazy);
+ result.Set("grammar_lazy", chatParams.grammar_lazy);
  // grammar_triggers: [{ value: string, token: number }]
  Napi::Array grammar_triggers = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.grammar_triggers.size(); i++) {
@@ -1135,6 +1163,11 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  _wip->SetStop();
  }

+ // stop_processing_loop
+ if (_rn_ctx && _rn_ctx->slot_manager) {
+ _rn_ctx->slot_manager->stop_processing_loop();
+ }
+
  if (_rn_ctx == nullptr) {
  auto promise = Napi::Promise::Deferred(env);
  promise.Resolve(env.Undefined());
@@ -4,6 +4,10 @@
  #include "rn-llama/rn-llama.h"
  #include "rn-llama/rn-completion.h"
  #include "rn-llama/rn-tts.h"
+ #include "rn-llama/rn-slot.h"
+ #include "rn-llama/rn-slot-manager.h"
+ #include <atomic>
+ #include <memory>

  using namespace rnllama;

@@ -55,10 +59,22 @@ private:
  Napi::Value GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info);
  Napi::Value DecodeAudioTokens(const Napi::CallbackInfo &info);

+ // Parallel decoding methods
+ Napi::Value EnableParallelMode(const Napi::CallbackInfo &info);
+ void DisableParallelMode(const Napi::CallbackInfo &info);
+ Napi::Value QueueCompletion(const Napi::CallbackInfo &info);
+ Napi::Value QueueEmbedding(const Napi::CallbackInfo &info);
+ Napi::Value QueueRerank(const Napi::CallbackInfo &info);
+ void CancelRequest(const Napi::CallbackInfo &info);
+
  std::string _info;
  Napi::Object _meta;
  LlamaCompletionWorker *_wip = nullptr;

  // Use rn-llama context instead of direct llama.cpp types
  llama_rn_context *_rn_ctx = nullptr;
+
+ // Validity flag for async callbacks to prevent use-after-free
+ // Shared pointer ensures callbacks can safely check if context is still alive
+ std::shared_ptr<std::atomic<bool>> _context_valid;
  };
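
Note on the `_context_valid` flag introduced above: async workers can capture the shared atomic by value and check it before touching context state in their completion callbacks. The following is a minimal, self-contained C++ sketch of that lifetime pattern; the class and member names are illustrative, not the addon's actual worker code, and it assumes the flag check and the destructor run on the same (JS) thread so the load cannot race with destruction.

#include <atomic>
#include <chrono>
#include <cstdio>
#include <functional>
#include <memory>
#include <thread>

// Hypothetical stand-in for LlamaContext; only the lifetime pattern is shown.
class Context {
public:
    Context() : valid_(std::make_shared<std::atomic<bool>>(true)) {}
    ~Context() { valid_->store(false); } // mirrors ~LlamaContext invalidating the flag

    void run_async(std::function<void()> on_done) {
        auto valid = valid_; // capture the shared flag by value, never a bare `this`
        std::thread([valid, on_done]() {
            std::this_thread::sleep_for(std::chrono::milliseconds(10)); // pretend work
            if (valid->load()) {
                on_done(); // only deliver the result while the context is alive
            }
        }).detach();
    }

private:
    std::shared_ptr<std::atomic<bool>> valid_;
};

int main() {
    {
        Context ctx;
        ctx.run_async([] { std::puts("completed while context alive"); });
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    } // destructor flips the flag; late callbacks become no-ops
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    return 0;
}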
package/src/common.hpp CHANGED
@@ -16,11 +16,12 @@ static bool is_nil(const Napi::Value &value) {
  return value.IsNull() || value.IsUndefined();
  }

- static std::string json_stringify(const Napi::Object &obj) {
- Napi::Env env = obj.Env();
+ // Overload for Napi::Value to handle both arrays and objects
+ static std::string json_stringify(const Napi::Value &value) {
+ Napi::Env env = value.Env();
  Napi::Object json = env.Global().Get("JSON").As<Napi::Object>();
  Napi::Function stringify = json.Get("stringify").As<Napi::Function>();
- return stringify.Call(json, {obj}).As<Napi::String>().ToString();
+ return stringify.Call(json, {value}).As<Napi::String>().ToString();
  }

  static void console_log(Napi::Env env, const std::string &message) {
@@ -1760,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
  add_opt(common_arg(
  {"-t", "--threads"}, "N",
- string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
+ string_format("number of CPU threads to use during generation (default: %d)", params.cpuparams.n_threads),
  [](common_params & params, int value) {
  params.cpuparams.n_threads = value;
  if (params.cpuparams.n_threads <= 0) {
@@ -577,6 +577,10 @@ extern "C" {
  GGML_UNARY_OP_EXP,
  GGML_UNARY_OP_GELU_ERF,
  GGML_UNARY_OP_XIELU,
+ GGML_UNARY_OP_FLOOR,
+ GGML_UNARY_OP_CEIL,
+ GGML_UNARY_OP_ROUND,
+ GGML_UNARY_OP_TRUNC,

  GGML_UNARY_OP_COUNT,
  };
@@ -1151,6 +1155,46 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);

+ GGML_API struct ggml_tensor * ggml_floor(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_floor_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_ceil(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_ceil_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_round(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_round_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ /**
+ * Truncates the fractional part of each element in the tensor (towards zero).
+ * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
+ * Similar to std::trunc in C/C++.
+ */
+
+ GGML_API struct ggml_tensor * ggml_trunc(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_trunc_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+
+
  // xIELU activation function
  // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
  // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
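
Note on the new rounding operators: they are plain element-wise unary ops, declared above alongside their `_inplace` variants. A minimal C++ sketch of how they could be exercised with the CPU backend follows; it assumes the usual ggml one-shot context helpers (`ggml_graph_compute_with_ctx`, `ggml_get_f32_1d`) are available via `ggml.h`/`ggml-cpu.h`, and the sizes and values are arbitrary.

#include "ggml.h"
#include "ggml-cpu.h"

#include <cstdio>
#include <cstring>

int main() {
    // one-shot context; the size is arbitrary for this toy graph
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    const float vals[4] = { 3.7f, -2.9f, 0.5f, -0.5f };
    memcpy(x->data, vals, sizeof(vals));

    struct ggml_tensor * f = ggml_floor(ctx, x); // expected: 3, -3, 0, -1
    struct ggml_tensor * t = ggml_trunc(ctx, x); // expected: 3, -2, 0,  0

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, f);
    ggml_build_forward_expand(gf, t);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    for (int i = 0; i < 4; ++i) {
        printf("x=%5.2f floor=%5.2f trunc=%5.2f\n",
               vals[i], ggml_get_f32_1d(f, i), ggml_get_f32_1d(t, i));
    }

    ggml_free(ctx);
    return 0;
}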
@@ -68,7 +68,7 @@ struct ggml_compute_params {
  #endif // __VXE2__
  #endif // __s390x__ && __VEC__

- #if defined(__ARM_FEATURE_SVE)
+ #if defined(__ARM_FEATURE_SVE) && defined(__linux__)
  #include <sys/prctl.h>
  #endif

@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
  #endif

  static void ggml_init_arm_arch_features(void) {
- #if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+ #if defined(__linux__)
  ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+ #else
+ // TODO: add support of SVE for non-linux systems
+ #error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
+ #endif
  #endif
  }

@@ -2179,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  case GGML_UNARY_OP_HARDSWISH:
  case GGML_UNARY_OP_HARDSIGMOID:
  case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_FLOOR:
+ case GGML_UNARY_OP_CEIL:
+ case GGML_UNARY_OP_ROUND:
+ case GGML_UNARY_OP_TRUNC:
  {
  n_tasks = 1;
  } break;
@@ -3558,13 +3567,17 @@ void ggml_cpu_init(void) {
  #ifdef GGML_USE_OPENMP
  //if (!getenv("OMP_WAIT_POLICY")) {
  // // set the wait policy to active, so that OpenMP threads don't sleep
- // putenv("OMP_WAIT_POLICY=active");
+ // setenv("OMP_WAIT_POLICY", "active", 0)
  //}

  if (!getenv("KMP_BLOCKTIME")) {
  // set the time to wait before sleeping a thread
  // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
- putenv("KMP_BLOCKTIME=200"); // 200ms
+ #ifdef _WIN32
+ _putenv_s("KMP_BLOCKTIME", "200"); // 200ms
+ #else
+ setenv("KMP_BLOCKTIME", "200", 0); // 200ms
+ #endif
  }
  #endif
  }
@@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary(
  {
  ggml_compute_forward_exp(params, dst);
  } break;
+ case GGML_UNARY_OP_FLOOR:
+ {
+ ggml_compute_forward_floor(params, dst);
+ } break;
+ case GGML_UNARY_OP_CEIL:
+ {
+ ggml_compute_forward_ceil(params, dst);
+ } break;
+ case GGML_UNARY_OP_ROUND:
+ {
+ ggml_compute_forward_round(params, dst);
+ } break;
+ case GGML_UNARY_OP_TRUNC:
+ {
+ ggml_compute_forward_trunc(params, dst);
+ } break;
  case GGML_UNARY_OP_XIELU:
  {
  ggml_compute_forward_xielu(params, dst);
@@ -73,6 +73,22 @@ static inline float op_log(float x) {
  return logf(x);
  }

+ static inline float op_floor(float x) {
+ return floorf(x);
+ }
+
+ static inline float op_ceil(float x) {
+ return ceilf(x);
+ }
+
+ static inline float op_round(float x) {
+ return roundf(x);
+ }
+
+ static inline float op_trunc(float x) {
+ return truncf(x);
+ }
+
  template <float (*op)(float), typename src0_t, typename dst_t>
  static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
  constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
  unary_op<op_log>(params, dst);
  }

+ void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
+ unary_op<op_floor>(params, dst);
+ }
+
+ void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
+ unary_op<op_ceil>(params, dst);
+ }
+
+ void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
+ unary_op<op_round>(params, dst);
+ }
+
+ void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
+ unary_op<op_trunc>(params, dst);
+ }
+
  void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
  const float alpha_n = ggml_get_op_params_f32(dst, 1);
  const float alpha_p = ggml_get_op_params_f32(dst, 2);
@@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
  void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);

  #ifdef __cplusplus
@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
  #endif
  for (; i < n; ++i) {
  float val = x[i] - mean;
+ y[i] = val;
  val *= val;
  sum += (ggml_float)val;
- y[i] = val;
  }
  return sum/n;
  }
@@ -5,6 +5,7 @@
  #include <map>

  static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+ { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
  { LLM_ARCH_LLAMA, "llama" },
  { LLM_ARCH_LLAMA4, "llama4" },
  { LLM_ARCH_DECI, "deci" },
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
  };

  static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
+ {
+ LLM_ARCH_CLIP,
+ {},
+ },
  {
  LLM_ARCH_LLAMA,
  {
@@ -9,6 +9,7 @@
  //

  enum llm_arch {
+ LLM_ARCH_CLIP,
  LLM_ARCH_LLAMA,
  LLM_ARCH_LLAMA4,
  LLM_ARCH_DECI,
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  }
  }

- static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
  LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
- const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
- (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
- (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+ const char * swa_type_str = "unknown";
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+ };
+
  LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
  LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
  LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  const int64_t n_kv = ubatch->n_tokens;
  const int64_t n_tokens = ubatch->n_tokens;

- GGML_ASSERT(kq_mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
- float * data = (float *) kq_mask->data;
-
- // [TAG_NO_CACHE_ISWA]
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const llama_pos p1 = ubatch->pos[i1];

- for (int h = 0; h < 1; ++h) {
- for (int i1 = 0; i1 < n_tokens; ++i1) {
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;

- for (int i0 = 0; i0 < n_tokens; ++i0) {
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
  const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ const llama_pos p0 = ubatch->pos[i0];

+ // mask different sequences
  if (s0 != s1) {
- continue; // skip different sequences
+ continue;
  }

- if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
- continue; // skip future tokens for causal attention
+ // mask future tokens
+ if (cparams.causal_attn && p0 > p1) {
+ continue;
  }

- // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
- //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
- // continue; // skip masked tokens for SWA
- //}
-
- // TODO: reimplement this like in llama_kv_cache_unified
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
+ // apply SWA if any
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+ continue;
  }
+
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
  }
- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
  }
  }
+ };
+
+ {
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+ float * data = (float *) self_kq_mask->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+ }
  }
- if (debug) {
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+ float * data = (float *) self_kq_mask_swa->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
  }
  }

@@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);

- const auto n_kv = k->ne[1];
-
  ggml_tensor * cur;

- // TODO: replace hardcoded padding with ggml-provided padding
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+ if (cparams.flash_attn && kq_b == nullptr) {
  GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

  if (v_trans) {
@@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
- inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
- ggml_set_input(inp->kq_mask);
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask_swa);
+
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+ } else {
+ inp->self_kq_mask_swa = nullptr;
+ inp->self_kq_mask_swa_cnv = nullptr;
+ }

  return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  }
@@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_build_forward_expand(gf, k_cur);
  ggml_build_forward_expand(gf, v_cur);

- const auto & kq_mask = inp->get_kq_mask();
+ const bool is_swa = hparams.is_swa(il);
+
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

  // [TAG_NO_CACHE_PAD]
  // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
@@ -257,10 +257,14 @@ public:

  void set_input(const llama_ubatch * ubatch) override;

- ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+ ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

- ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
- ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
+ // n_tokens == n_batch
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]

  const llama_hparams hparams;
  const llama_cparams cparams;
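
Note on the SWA mask plumbing above: `self_kq_mask` and `self_kq_mask_swa` are filled by the shared `fill_mask` lambda earlier in this diff (different sequences, future tokens under causal attention, and out-of-window positions stay at -INFINITY). A standalone C++ sketch of that rule for a single sequence follows; the window condition `p1 - p0 >= n_swa` is an assumption standing in for `llama_hparams::is_masked_swa`, which the real code calls.

#include <cmath>
#include <cstdio>

// Assumption: a "standard" sliding window of size n_swa masks keys that fall
// outside the window, i.e. p1 - p0 >= n_swa (the real check lives in
// llama_hparams::is_masked_swa). Cross-sequence masking is omitted here.
static float mask_value(int p0, int p1, bool causal, int n_swa, bool use_alibi) {
    if (causal && p0 > p1)             return -INFINITY; // future token
    if (n_swa > 0 && p1 - p0 >= n_swa) return -INFINITY; // outside the window
    return use_alibi ? -std::abs(float(p0 - p1)) : 0.0f;
}

int main() {
    const int n = 6, n_swa = 3;
    for (int p1 = 0; p1 < n; ++p1) {       // rows    = query positions
        for (int p0 = 0; p0 < n; ++p0) {   // columns = key positions
            const float f = mask_value(p0, p1, /*causal*/ true, n_swa, /*use_alibi*/ false);
            std::printf("%4s", f == -INFINITY ? "inf" : "0");
        }
        std::printf("\n");
    }
    return 0;
}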