llama_cpp 0.14.0 → 0.14.1

This diff shows the changes between publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -213,6 +213,7 @@ enum llm_arch {
213
213
  LLM_ARCH_MINICPM,
214
214
  LLM_ARCH_GEMMA,
215
215
  LLM_ARCH_STARCODER2,
216
+ LLM_ARCH_MAMBA,
216
217
  LLM_ARCH_UNKNOWN,
217
218
  };
218
219
 
@@ -241,6 +242,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
241
242
  { LLM_ARCH_MINICPM, "minicpm" },
242
243
  { LLM_ARCH_GEMMA, "gemma" },
243
244
  { LLM_ARCH_STARCODER2, "starcoder2" },
245
+ { LLM_ARCH_MAMBA, "mamba" },
244
246
  { LLM_ARCH_UNKNOWN, "(unknown)" },
245
247
  };
246
248
 
@@ -256,6 +258,7 @@ enum llm_kv {
256
258
  LLM_KV_GENERAL_SOURCE_URL,
257
259
  LLM_KV_GENERAL_SOURCE_HF_REPO,
258
260
 
261
+ LLM_KV_VOCAB_SIZE,
259
262
  LLM_KV_CONTEXT_LENGTH,
260
263
  LLM_KV_EMBEDDING_LENGTH,
261
264
  LLM_KV_BLOCK_COUNT,
@@ -284,6 +287,11 @@ enum llm_kv {
284
287
  LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
285
288
  LLM_KV_ROPE_SCALING_FINETUNED,
286
289
 
290
+ LLM_KV_SSM_INNER_SIZE,
291
+ LLM_KV_SSM_CONV_KERNEL,
292
+ LLM_KV_SSM_STATE_SIZE,
293
+ LLM_KV_SSM_TIME_STEP_RANK,
294
+
287
295
  LLM_KV_TOKENIZER_MODEL,
288
296
  LLM_KV_TOKENIZER_LIST,
289
297
  LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -314,6 +322,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
314
322
  { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
315
323
  { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
316
324
 
325
+ { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
317
326
  { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
318
327
  { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
319
328
  { LLM_KV_BLOCK_COUNT, "%s.block_count" },
@@ -342,6 +351,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
342
351
  { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
343
352
  { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
344
353
 
354
+ { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
355
+ { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
356
+ { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
357
+ { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
358
+
345
359
  { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
346
360
  { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
347
361
  { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
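The "%s" placeholder in these key names is filled with the architecture name from LLM_ARCH_NAMES, so for Mamba the new entries resolve to "mamba.ssm.conv_kernel", "mamba.ssm.inner_size", "mamba.ssm.state_size" and "mamba.ssm.time_step_rank". A minimal sketch of that substitution, assuming a plain snprintf-style formatter (the helper name below is illustrative, not llama.cpp's internal one):

    #include <cstdio>
    #include <string>

    // resolve_kv_name("%s.ssm.conv_kernel", "mamba") -> "mamba.ssm.conv_kernel"
    static std::string resolve_kv_name(const char * fmt, const char * arch_name) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), fmt, arch_name);
        return buf;
    }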
@@ -399,6 +413,13 @@ enum llm_tensor {
399
413
  LLM_TENSOR_ATTN_Q_NORM,
400
414
  LLM_TENSOR_ATTN_K_NORM,
401
415
  LLM_TENSOR_LAYER_OUT_NORM,
416
+ LLM_TENSOR_SSM_IN,
417
+ LLM_TENSOR_SSM_CONV1D,
418
+ LLM_TENSOR_SSM_X,
419
+ LLM_TENSOR_SSM_DT,
420
+ LLM_TENSOR_SSM_A,
421
+ LLM_TENSOR_SSM_D,
422
+ LLM_TENSOR_SSM_OUT,
402
423
  };
403
424
 
404
425
  static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -801,6 +822,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
801
822
  { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
802
823
  },
803
824
  },
825
+ {
826
+ LLM_ARCH_MAMBA,
827
+ {
828
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
829
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
830
+ { LLM_TENSOR_OUTPUT, "output" },
831
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
832
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
833
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
834
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
835
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
836
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
837
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
838
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
839
+ },
840
+ },
804
841
  {
805
842
  LLM_ARCH_UNKNOWN,
806
843
  {
@@ -943,21 +980,6 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
943
980
  }
944
981
  }
945
982
 
946
- //
947
- // ggml helpers
948
- //
949
-
950
- static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
951
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
952
-
953
- if (plan.work_size > 0) {
954
- buf.resize(plan.work_size);
955
- plan.work_data = buf.data();
956
- }
957
-
958
- ggml_graph_compute(graph, &plan);
959
- }
960
-
961
983
  //
962
984
  // llama helpers
963
985
  //
@@ -1613,6 +1635,12 @@ struct llama_hparams {
1613
1635
  float rope_freq_scale_train;
1614
1636
  uint32_t n_yarn_orig_ctx;
1615
1637
 
1638
+ // for State Space Models
1639
+ uint32_t ssm_d_conv = 0;
1640
+ uint32_t ssm_d_inner = 0;
1641
+ uint32_t ssm_d_state = 0;
1642
+ uint32_t ssm_dt_rank = 0;
1643
+
1616
1644
  float f_clamp_kqv = 0.0f;
1617
1645
  float f_max_alibi_bias = 0.0f;
1618
1646
 
@@ -1641,6 +1669,11 @@ struct llama_hparams {
1641
1669
  if (this->rope_finetuned != other.rope_finetuned) return true;
1642
1670
  if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
1643
1671
 
1672
+ if (this->ssm_d_conv != other.ssm_d_conv) return true;
1673
+ if (this->ssm_d_inner != other.ssm_d_inner) return true;
1674
+ if (this->ssm_d_state != other.ssm_d_state) return true;
1675
+ if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
1676
+
1644
1677
  const float EPSILON = 1e-9f;
1645
1678
 
1646
1679
  if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1652,6 +1685,9 @@ struct llama_hparams {
1652
1685
  }
1653
1686
 
1654
1687
  uint32_t n_gqa() const {
1688
+ if (n_head_kv == 0) {
1689
+ return 0;
1690
+ }
1655
1691
  return n_head/n_head_kv;
1656
1692
  }
1657
1693
 
@@ -1662,11 +1698,24 @@ struct llama_hparams {
1662
1698
  uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
1663
1699
  return n_embd_head_v * n_head_kv;
1664
1700
  }
1701
+
1702
+ uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
1703
+ // corresponds to Mamba's conv_states size
1704
 + // TODO: maybe support convolution strides other than 1
1705
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
1706
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
1707
+ }
1708
+
1709
+ uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
1710
+ // corresponds to Mamba's ssm_states size
1711
+ return ssm_d_state * ssm_d_inner;
1712
+ }
1665
1713
  };
1666
1714
 
1667
1715
  struct llama_cparams {
1668
1716
  uint32_t n_ctx; // context size used during inference
1669
1717
  uint32_t n_batch;
1718
+ uint32_t n_ubatch;
1670
1719
  uint32_t n_threads; // number of threads to use for generation
1671
1720
  uint32_t n_threads_batch; // number of threads to use for batch processing
1672
1721
 
@@ -1683,6 +1732,7 @@ struct llama_cparams {
1683
1732
  float defrag_thold;
1684
1733
 
1685
1734
  bool embeddings;
1735
+ bool causal_attn;
1686
1736
  bool offload_kqv;
1687
1737
 
1688
1738
  enum llama_pooling_type pooling_type;
@@ -1739,11 +1789,27 @@ struct llama_layer {
1739
1789
  struct ggml_tensor * ffn_down_b; // b2
1740
1790
  struct ggml_tensor * ffn_up_b; // b3
1741
1791
  struct ggml_tensor * ffn_act;
1792
+
1793
+ // mamba proj
1794
+ struct ggml_tensor * ssm_in;
1795
+ struct ggml_tensor * ssm_x;
1796
+ struct ggml_tensor * ssm_dt;
1797
+ struct ggml_tensor * ssm_out;
1798
+
1799
+ // mamba
1800
+ struct ggml_tensor * ssm_conv1d;
1801
+ struct ggml_tensor * ssm_a;
1802
+ struct ggml_tensor * ssm_d;
1803
+
1804
+ // mamba bias
1805
+ struct ggml_tensor * ssm_conv1d_b;
1806
+ struct ggml_tensor * ssm_dt_b;
1742
1807
  };
1743
1808
 
1744
1809
  struct llama_kv_cell {
1745
1810
  llama_pos pos = -1;
1746
1811
  llama_pos delta = 0;
1812
+ int32_t src = 0; // used by recurrent state models to copy states
1747
1813
 
1748
1814
  std::set<llama_seq_id> seq_id;
1749
1815
 
@@ -1764,6 +1830,9 @@ struct llama_kv_cell {
1764
1830
  struct llama_kv_cache {
1765
1831
  bool has_shift = false;
1766
1832
  bool do_defrag = false;
1833
+ bool do_copy = false;
1834
+ // with recurrent state models, a cell can hold the state for more than one past token
1835
+ bool recurrent = false;
1767
1836
 
1768
1837
  // Note: The value of head isn't only used to optimize searching
1769
1838
  // for a free KV slot. llama_decode_internal also uses it, so it
@@ -1943,8 +2012,7 @@ struct llama_context {
1943
2012
  ggml_vk_free_cpu_assist();
1944
2013
  #endif
1945
2014
 
1946
- ggml_backend_buffer_free(buf_input);
1947
- ggml_free(ctx_input);
2015
+ ggml_backend_buffer_free(buf_output);
1948
2016
  }
1949
2017
 
1950
2018
  llama_cparams cparams;
@@ -1970,12 +2038,20 @@ struct llama_context {
1970
2038
  int64_t t_p_eval_us = 0;
1971
2039
  int64_t t_eval_us = 0;
1972
2040
 
2041
+ int64_t t_compute_start_us = 0;
2042
+ int64_t n_queued_tokens = 0;
2043
+
1973
2044
  int32_t n_sample = 0; // number of tokens sampled
1974
2045
  int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
1975
2046
  int32_t n_eval = 0; // number of eval calls
1976
2047
 
1977
- // logits output (2-dimensional array: [n_tokens][n_vocab])
1978
- std::vector<float> logits;
2048
+ // host buffer for the model output (logits and embeddings)
2049
+ ggml_backend_buffer_t buf_output = nullptr;
2050
+
2051
+ // decode output (2-dimensional array: [n_tokens][n_vocab])
2052
+ size_t logits_size = 0;
2053
+ float * logits = nullptr;
2054
+
1979
2055
  #ifndef NDEBUG
1980
2056
  // guard against access to unset logits
1981
2057
  std::vector<bool> logits_valid;
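With logits moved from a std::vector<float> into a backend host buffer (buf_output), user code keeps reading them through the existing accessors; only the backing storage changes. A hedged caller-side sketch, assuming the unchanged public llama.h API:

    #include <cstddef>
    #include "llama.h" // assumes the public API, which this diff does not alter

    // Row i of the [n_tokens][n_vocab] logits array, now backed by buf_output internally.
    static const float * logits_row(llama_context * ctx, int32_t i) {
        const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx));
        return llama_get_logits(ctx) + (size_t) i * n_vocab;
    }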
@@ -1984,7 +2060,8 @@ struct llama_context {
1984
2060
 
1985
2061
  // embeddings output (2-dimensional array: [n_tokens][n_embd])
1986
2062
  // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
1987
- std::vector<float> embd;
2063
+ size_t embd_size = 0;
2064
+ float * embd = nullptr;
1988
2065
 
1989
2066
  // sequence embeddings output (map of [n_embd] vectors)
1990
2067
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -1998,16 +2075,17 @@ struct llama_context {
1998
2075
  void * abort_callback_data = nullptr;
1999
2076
 
2000
2077
  // input tensors
2001
- ggml_backend_buffer_t buf_input = nullptr;
2002
- ggml_context * ctx_input = nullptr;
2003
2078
  struct ggml_tensor * inp_tokens; // I32 [n_batch]
2004
2079
  struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
2005
2080
  struct ggml_tensor * inp_pos; // I32 [n_batch]
2006
- struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
2007
- struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
2008
- struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
2081
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
2082
+ struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
2083
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
2009
2084
  struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
2010
2085
  struct ggml_tensor * inp_cls; // I32 [n_batch]
2086
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
2087
+ struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
2088
+ struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
2011
2089
 
2012
2090
  #ifdef GGML_USE_MPI
2013
2091
  ggml_mpi_context * ctx_mpi = NULL;
@@ -2023,25 +2101,42 @@ static bool llama_kv_cache_init(
2023
2101
  const llama_model & model,
2024
2102
  ggml_type type_k,
2025
2103
  ggml_type type_v,
2026
- uint32_t n_ctx,
2104
+ uint32_t kv_size,
2027
2105
  bool offload) {
2028
2106
  const struct llama_hparams & hparams = model.hparams;
2029
2107
 
2030
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
2031
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
2108
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
2109
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
2032
2110
  const int64_t n_layer = hparams.n_layer;
2033
2111
 
2034
2112
  cache.has_shift = false;
2035
2113
 
2114
+ // TODO: find a nicer way to add other recurrent model architectures
2115
+ cache.recurrent = model.arch == LLM_ARCH_MAMBA;
2116
+
2117
 + // TODO: support mixed recurrent Transformer architectures
2118
+ // NOTE: (!a || b) is a logical implication (a -> b)
2119
+ GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
2120
+ GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
2121
+ GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
2122
+ GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
2123
+
2036
2124
  cache.head = 0;
2037
- cache.size = n_ctx;
2125
+ cache.size = kv_size;
2038
2126
  cache.used = 0;
2039
2127
 
2040
2128
  cache.type_k = type_k;
2041
2129
  cache.type_v = type_v;
2042
2130
 
2043
2131
  cache.cells.clear();
2044
- cache.cells.resize(n_ctx);
2132
+ cache.cells.resize(kv_size);
2133
+
2134
+ if (cache.recurrent) {
2135
+ // init state copy sources
2136
+ for (uint32_t i = 0; i < cache.size; ++i) {
2137
+ cache.cells[i].src = i;
2138
+ }
2139
+ }
2045
2140
 
2046
2141
  #ifdef GGML_USE_CLBLAST
2047
2142
  offload = false;
@@ -2080,8 +2175,8 @@ static bool llama_kv_cache_init(
2080
2175
 
2081
2176
  for (int i = 0; i < (int) n_layer; i++) {
2082
2177
  struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
2083
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
2084
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
2178
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
2179
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
2085
2180
  ggml_format_name(k, "cache_k_l%d", i);
2086
2181
  ggml_format_name(v, "cache_v_l%d", i);
2087
2182
  cache.k_l.push_back(k);
@@ -2115,6 +2210,54 @@ static bool llama_kv_cache_find_slot(
2115
2210
  const uint32_t n_ctx = cache.size;
2116
2211
  const uint32_t n_tokens = batch.n_tokens;
2117
2212
 
2213
+ if (cache.recurrent) {
2214
+ // For recurrent state architectures (like Mamba),
2215
+ // each KV cache cell can store the state for a whole sequence.
2216
+
2217
+ llama_seq_id min = cache.size - 1;
2218
+ llama_seq_id max = 0;
2219
+
2220
+ for (uint32_t i = 0; i < n_tokens; ++i) {
2221
+ for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
2222
+ llama_seq_id seq_id = batch.seq_id[i][j];
2223
+ // make sure it's a valid seq_id
2224
+ if ((uint32_t) seq_id < cache.size) {
2225
+ if (seq_id > max) {
2226
+ max = seq_id;
2227
+ }
2228
+ if (seq_id < min) {
2229
+ min = seq_id;
2230
+ }
2231
+ // Assuming the tokens are in-order
2232
+ if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
2233
+ // What should happen when the pos backtracks or skips a value?
2234
+ // Clearing the state mid-batch would require special-casing which isn't done.
2235
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
2236
+ __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
2237
+ }
2238
+ if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
2239
+ cache.used += 1;
2240
+ }
2241
+ cache.cells[seq_id].pos = batch.pos[i];
2242
+ // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
2243
+ } else {
2244
+ // too big seq_id
2245
+ // TODO: would it be possible to resize the KV cache size instead?
2246
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
2247
+ return false;
2248
+ }
2249
+ }
2250
+ }
2251
+
2252
+ // allow getting the range of used cells, from head to head + n
2253
+ cache.head = min;
2254
+ cache.n = max - min + 1;
2255
+
2256
+ // sanity check
2257
+ return max >= min;
2258
+ }
2259
+ // otherwise, one cell per token.
2260
+
2118
2261
  if (n_tokens > n_ctx) {
2119
2262
  LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
2120
2263
  return false;
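In the recurrent path the cell index is the seq_id itself, and head/n are taken from the smallest and largest seq_id touched by the batch so the graph can address one contiguous range of cells. An illustrative (hypothetical) batch touching sequences {2, 5, 3} therefore gives head = 2 and n = 4, covering cells 2..5 even though cell 4 is unused:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    int main() {
        const std::vector<int32_t> seq_ids = {2, 5, 3}; // hypothetical batch
        const int32_t mn = *std::min_element(seq_ids.begin(), seq_ids.end());
        const int32_t mx = *std::max_element(seq_ids.begin(), seq_ids.end());
        const int32_t head = mn;          // 2, mirrors cache.head = min
        const int32_t n    = mx - mn + 1; // 4, mirrors cache.n = max - min + 1
        return (head == 2 && n == 4) ? 0 : 1;
    }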
@@ -2184,7 +2327,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
2184
2327
  cache.used = 0;
2185
2328
  }
2186
2329
 
2187
- static void llama_kv_cache_seq_rm(
2330
+ static bool llama_kv_cache_seq_rm(
2188
2331
  struct llama_kv_cache & cache,
2189
2332
  llama_seq_id seq_id,
2190
2333
  llama_pos p0,
@@ -2194,6 +2337,25 @@ static void llama_kv_cache_seq_rm(
2194
2337
  if (p0 < 0) p0 = 0;
2195
2338
  if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
2196
2339
 
2340
+ // models like Mamba can't have a state partially erased
2341
+ if (cache.recurrent) {
2342
+ if (seq_id >= (int64_t) cache.size) {
2343
+ // could be fatal
2344
+ return false;
2345
+ }
2346
+ if (0 <= seq_id) {
2347
+ // partial intersection is invalid
2348
+ if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
2349
+ return false;
2350
+ }
2351
+ } else {
2352
 + // the seq_id is negative, so the range should include everything or nothing
2353
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
2354
+ return false;
2355
+ }
2356
+ }
2357
+ }
2358
+
2197
2359
  for (uint32_t i = 0; i < cache.size; ++i) {
2198
2360
  if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
2199
2361
  if (seq_id < 0) {
@@ -2215,6 +2377,8 @@ static void llama_kv_cache_seq_rm(
2215
2377
 
2216
2378
  // If we freed up a slot, set head to it so searching can start there.
2217
2379
  if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
2380
+
2381
+ return true;
2218
2382
  }
2219
2383
 
2220
2384
  static void llama_kv_cache_seq_cp(
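Because a recurrent state cannot be partially erased, llama_kv_cache_seq_rm() now reports failure instead of silently corrupting it: a [p0, p1) range that cuts through a sequence's stored position is rejected, while clearing the whole range (or an empty one) is accepted. A hedged caller-side sketch, assuming the public wrapper mirrors the new bool return:

    #include "llama.h" // assumes the public wrapper follows the same rule

    static void trim_example(llama_context * ctx) {
        // Whole-sequence removal is always representable for a Mamba-style state.
        bool whole = llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/-1, /*p1=*/-1);
        // A mid-state cut can now come back as false instead of corrupting the state.
        bool part  = llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/10, /*p1=*/20);
        (void) whole; (void) part;
    }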
@@ -2226,6 +2390,29 @@ static void llama_kv_cache_seq_cp(
2226
2390
  if (p0 < 0) p0 = 0;
2227
2391
  if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
2228
2392
 
2393
+ if (cache.recurrent) {
2394
+ if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
2395
+ seq_id_src = cache.cells[seq_id_src].src;
2396
+ GGML_ASSERT((uint32_t) seq_id_src < cache.size);
2397
+ // intent to "copy from"
2398
+ // supports copy chains thanks to taking the source of the source
2399
+ cache.cells[seq_id_dst].src = seq_id_src;
2400
+
2401
+ // preserve the "keep or clear" status of the copied sequence
2402
+ if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
2403
+ cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
2404
+ } else {
2405
+ cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
2406
+ }
2407
+
2408
+ cache.do_copy = true;
2409
+
2410
+ cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
2411
+ }
2412
+ return;
2413
+ }
2414
+ // otherwise, this is the KV cache of a Transformer-like model
2415
+
2229
2416
  cache.head = 0;
2230
2417
 
2231
2418
  for (uint32_t i = 0; i < cache.size; ++i) {
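For recurrent models a sequence copy only records which cell to read from: the destination's src field is set to the source of the source, so chains stay flat until build_s_copy() materialises the states. Copying 0 -> 1 and then 1 -> 2 leaves both 1 and 2 pointing at cell 0, as in this small model of the rule:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct cell { int32_t src; };

    int main() {
        std::vector<cell> cells(4);
        for (int32_t i = 0; i < 4; ++i) cells[i].src = i; // init, as in llama_kv_cache_init

        cells[1].src = cells[0].src; // copy 0 -> 1
        cells[2].src = cells[1].src; // copy 1 -> 2: "source of the source" is still 0

        assert(cells[1].src == 0 && cells[2].src == 0);
        return 0;
    }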
@@ -2265,6 +2452,17 @@ static void llama_kv_cache_seq_add(
2265
2452
  if (p0 < 0) p0 = 0;
2266
2453
  if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
2267
2454
 
2455
+ if (cache.recurrent) {
2456
+ // for Mamba-like models, only the pos needs to be shifted
2457
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
2458
+ llama_kv_cell & cell = cache.cells[seq_id];
2459
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2460
+ cell.pos += delta;
2461
+ }
2462
+ }
2463
+ return;
2464
+ }
2465
+
2268
2466
  for (uint32_t i = 0; i < cache.size; ++i) {
2269
2467
  if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
2270
2468
  cache.has_shift = true;
@@ -2298,6 +2496,17 @@ static void llama_kv_cache_seq_div(
2298
2496
  if (p0 < 0) p0 = 0;
2299
2497
  if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
2300
2498
 
2499
+ if (cache.recurrent) {
2500
+ // for Mamba-like models, only the pos needs to be changed
2501
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
2502
+ llama_kv_cell & cell = cache.cells[seq_id];
2503
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2504
+ cell.pos /= d;
2505
+ }
2506
+ }
2507
+ return;
2508
+ }
2509
+
2301
2510
  for (uint32_t i = 0; i < cache.size; ++i) {
2302
2511
  if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
2303
2512
  cache.has_shift = true;
@@ -3035,10 +3244,11 @@ static const char * llama_model_type_name(e_model type) {
3035
3244
 
3036
3245
  static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
3037
3246
  switch (type) {
3038
- case LLAMA_VOCAB_TYPE_SPM: return "SPM";
3039
- case LLAMA_VOCAB_TYPE_BPE: return "BPE";
3040
- case LLAMA_VOCAB_TYPE_WPM: return "WPM";
3041
- default: return "unknown";
3247
+ case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
3248
+ case LLAMA_VOCAB_TYPE_SPM: return "SPM";
3249
+ case LLAMA_VOCAB_TYPE_BPE: return "BPE";
3250
+ case LLAMA_VOCAB_TYPE_WPM: return "WPM";
3251
+ default: return "unknown";
3042
3252
  }
3043
3253
  }
3044
3254
 
@@ -3070,14 +3280,14 @@ static void llm_load_hparams(
3070
3280
  ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
3071
3281
 
3072
3282
  // get hparams kv
3073
- ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
3074
- ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
3075
- ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
3076
- ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
3077
- ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
3078
- ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
3079
- ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
3080
- ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
3283
+ ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
3284
+ ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
3285
+ ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
3286
+ ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
3287
+ ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
3288
+ ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
3289
+ ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
3290
+ ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
3081
3291
 
3082
3292
  GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
3083
3293
  GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
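n_vocab is now taken from the new "%s.vocab_size" key when present; because the non-required get_key (trailing false) reports whether it found the key, the short-circuit || only falls back to counting tokenizer.ggml.tokens when the key is missing, which is what lets tokenizer-less Mamba GGUFs load. The same fallback pattern in isolation (the try_* helpers are stand-ins, not llama.cpp functions):

    #include <cstdint>

    static bool try_get_vocab_size(uint32_t & out)   { (void) out; return false; }  // pretend key absent
    static bool try_count_token_list(uint32_t & out) { out = 32000; return true; }  // pretend 32000 tokens

    static uint32_t load_n_vocab() {
        uint32_t n_vocab = 0;
        // mirrors: ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab);
        try_get_vocab_size(n_vocab) || try_count_token_list(n_vocab);
        return n_vocab; // 32000 via the fallback
    }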
@@ -3117,7 +3327,7 @@ static void llm_load_hparams(
3117
3327
 
3118
3328
  // sanity check for n_rot (optional)
3119
3329
  {
3120
- hparams.n_rot = hparams.n_embd / hparams.n_head;
3330
+ hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
3121
3331
 
3122
3332
  ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
3123
3333
 
@@ -3130,10 +3340,10 @@ static void llm_load_hparams(
3130
3340
  // gpt-j n_rot = rotary_dim
3131
3341
  }
3132
3342
 
3133
- hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
3343
+ hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
3134
3344
  ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
3135
3345
 
3136
- hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
3346
+ hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
3137
3347
  ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
3138
3348
 
3139
3349
  // arch-specific KVs
@@ -3383,6 +3593,36 @@ static void llm_load_hparams(
3383
3593
  default: model.type = e_model::MODEL_UNKNOWN;
3384
3594
  }
3385
3595
  } break;
3596
+ case LLM_ARCH_MAMBA:
3597
+ {
3598
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
3599
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
3600
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
3601
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
3602
+
3603
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
3604
+
3605
+ switch (hparams.n_layer) {
3606
+ case 24:
3607
+ switch (hparams.n_embd) {
3608
+ case 768: model.type = e_model::MODEL_SMALL; break;
3609
+ default: model.type = e_model::MODEL_UNKNOWN;
3610
+ } break;
3611
+ case 48:
3612
+ switch (hparams.n_embd) {
3613
+ case 1024: model.type = e_model::MODEL_MEDIUM; break;
3614
+ case 1536: model.type = e_model::MODEL_LARGE; break;
3615
+ case 2048: model.type = e_model::MODEL_XL; break;
3616
+ default: model.type = e_model::MODEL_UNKNOWN;
3617
+ } break;
3618
+ case 64:
3619
+ switch (hparams.n_embd) {
3620
+ case 2560: model.type = e_model::MODEL_3B; break;
3621
+ default: model.type = e_model::MODEL_UNKNOWN;
3622
+ } break;
3623
+ default: model.type = e_model::MODEL_UNKNOWN;
3624
+ }
3625
+ } break;
3386
3626
  default: (void)0;
3387
3627
  }
3388
3628
 
@@ -3408,30 +3648,25 @@ static void llm_load_vocab(
3408
3648
 
3409
3649
  const auto kv = LLM_KV(model.arch);
3410
3650
 
3411
- const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
3412
- if (token_idx == -1) {
3413
- throw std::runtime_error("cannot find tokenizer vocab in model file\n");
3414
- }
3415
-
3416
- const float * scores = nullptr;
3417
- const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
3418
- if (score_idx != -1) {
3419
- scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
3420
- }
3421
-
3422
- const int * toktypes = nullptr;
3423
- const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
3424
- if (toktype_idx != -1) {
3425
- toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
3426
- }
3427
-
3428
3651
  // determine vocab type
3429
3652
  {
3430
3653
  std::string tokenizer_name;
3431
3654
 
3432
3655
  ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
3433
3656
 
3434
- if (tokenizer_name == "llama") {
3657
+ if (tokenizer_name == "no_vocab") {
3658
+ vocab.type = LLAMA_VOCAB_TYPE_NONE;
3659
+
3660
+ // default special tokens
3661
+ vocab.special_bos_id = -1;
3662
+ vocab.special_eos_id = -1;
3663
+ vocab.special_unk_id = -1;
3664
+ vocab.special_sep_id = -1;
3665
+ vocab.special_pad_id = -1;
3666
+ vocab.linefeed_id = -1;
3667
+
3668
+ return;
3669
+ } else if (tokenizer_name == "llama") {
3435
3670
  vocab.type = LLAMA_VOCAB_TYPE_SPM;
3436
3671
 
3437
3672
  // default special tokens
@@ -3458,7 +3693,7 @@ static void llm_load_vocab(
3458
3693
 
3459
3694
  for (int i = 0; i < n_merges; i++) {
3460
3695
  const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
3461
- GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
3696
+ GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
3462
3697
 
3463
3698
  std::string first;
3464
3699
  std::string second;
@@ -3497,13 +3732,30 @@ static void llm_load_vocab(
3497
3732
  }
3498
3733
  }
3499
3734
 
3735
+ const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
3736
+ if (token_idx == -1) {
3737
+ throw std::runtime_error("cannot find tokenizer vocab in model file\n");
3738
+ }
3739
+
3740
+ const float * scores = nullptr;
3741
+ const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
3742
+ if (score_idx != -1) {
3743
+ scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
3744
+ }
3745
+
3746
+ const int * toktypes = nullptr;
3747
+ const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
3748
+ if (toktype_idx != -1) {
3749
+ toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
3750
+ }
3751
+
3500
3752
  const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
3501
3753
 
3502
3754
  vocab.id_to_token.resize(n_vocab);
3503
3755
 
3504
3756
  for (uint32_t i = 0; i < n_vocab; i++) {
3505
3757
  std::string word = gguf_get_arr_str(ctx, token_idx, i);
3506
- GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
3758
+ GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
3507
3759
 
3508
3760
  vocab.token_to_id[word] = i;
3509
3761
 
@@ -3695,6 +3947,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
3695
3947
  LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
3696
3948
  LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
3697
3949
  LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
3950
+ LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
3698
3951
  LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
3699
3952
  LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
3700
3953
  LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
@@ -3702,6 +3955,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
3702
3955
  LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
3703
3956
  LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
3704
3957
  LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
3958
+ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
3959
+ LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
3960
+ LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
3961
+ LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
3705
3962
  LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
3706
3963
  LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
3707
3964
  if (ml.n_elements >= 1e12) {
@@ -3755,6 +4012,7 @@ static bool llm_load_tensors(
3755
4012
 
3756
4013
  // there is very little benefit to offloading the input layer, so always keep it on the CPU
3757
4014
  model.buft_input = llama_default_buffer_type_cpu(true);
4015
+ //model.buft_input = llama_default_buffer_type_offload(main_gpu);
3758
4016
 
3759
4017
  model.buft_layer.resize(n_layer);
3760
4018
 
@@ -3888,7 +4146,13 @@ static bool llm_load_tensors(
3888
4146
  {
3889
4147
  model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
3890
4148
  if (model.arch != LLM_ARCH_MINICPM){
3891
- model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
4149
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
4150
+ // if output is NULL, init from the input tok embed
4151
+ if (model.output == NULL) {
4152
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4153
+ ml.n_created--; // artificial tensor
4154
+ ml.size_data += ggml_nbytes(model.output);
4155
+ }
3892
4156
  }
3893
4157
  }
3894
4158
 
@@ -4603,6 +4867,57 @@ static bool llm_load_tensors(
4603
4867
  layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff});
4604
4868
  }
4605
4869
  } break;
4870
+ case LLM_ARCH_MAMBA:
4871
+ {
4872
+ const int64_t d_conv = hparams.ssm_d_conv;
4873
+ const int64_t d_inner = hparams.ssm_d_inner;
4874
+ const int64_t d_state = hparams.ssm_d_state;
4875
+ const int64_t dt_rank = hparams.ssm_dt_rank;
4876
+ // only an expansion factor of 2 is supported for now
4877
+ GGML_ASSERT(2 * n_embd == d_inner);
4878
+
4879
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4880
+
4881
+ // output
4882
+ {
4883
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
4884
+
4885
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
4886
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
4887
+ if (model.output == NULL) {
4888
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4889
+ ml.n_created--; // artificial tensor
4890
+ ml.size_data += ggml_nbytes(model.output);
4891
+ }
4892
+ }
4893
+
4894
+ for (int i = 0; i < n_layer; ++i) {
4895
+ ggml_context * ctx_layer = ctx_for_layer(i);
4896
+ ggml_context * ctx_split = ctx_for_layer_split(i);
4897
+
4898
+ auto & layer = model.layers[i];
4899
+
4900
+ // norm
4901
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
4902
+
4903
+ layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
4904
+
4905
+ layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
4906
+ layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
4907
+
4908
+ layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
4909
+
4910
+ layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
4911
+ layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
4912
+
4913
+ // no "weight" suffix for these
4914
+ layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
4915
+ layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
4916
+
4917
+ // out_proj
4918
+ layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
4919
+ }
4920
+ } break;
4606
4921
  default:
4607
4922
  throw std::runtime_error("unknown architecture");
4608
4923
  }
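With the sizing asserted above (d_inner = 2*n_embd) and hypothetical 130M-like hyperparameters (n_embd = 768, d_conv = 4, d_state = 16, dt_rank = 48; assumed here for illustration), the per-layer tensors created above come out as ssm_in {768, 3072}, ssm_conv1d {4, 1536}, ssm_x {1536, 80}, ssm_dt {48, 1536}, ssm_a {16, 1536}, ssm_d {1536} and ssm_out {1536, 768}. A small check of that arithmetic:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t n_embd = 768, d_conv = 4, d_state = 16, dt_rank = 48; // hypothetical values
        const int64_t d_inner = 2 * n_embd;                                 // expansion factor of 2

        assert(d_inner == 1536);
        assert(dt_rank + 2 * d_state == 80);      // second dim of ssm_x
        assert(d_conv * d_inner == 4 * 1536);     // elements in ssm_conv1d per layer
        return 0;
    }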
@@ -4723,7 +5038,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
4723
5038
 
4724
5039
  llm_load_print_meta(ml, model);
4725
5040
 
4726
- if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
5041
+ if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
5042
+ model.hparams.n_vocab != model.vocab.id_to_token.size()) {
4727
5043
  throw std::runtime_error("vocab size mismatch");
4728
5044
  }
4729
5045
 
@@ -4787,29 +5103,32 @@ enum llm_norm_type {
4787
5103
 
4788
5104
  static struct ggml_tensor * llm_build_inp_embd(
4789
5105
  struct ggml_context * ctx,
5106
+ struct llama_context & lctx,
4790
5107
  const llama_hparams & hparams,
4791
5108
  const llama_batch & batch,
4792
5109
  struct ggml_tensor * tok_embd,
4793
- struct ggml_tensor * inp_tokens,
4794
- struct ggml_tensor * inp_embd,
4795
5110
  const llm_build_cb & cb) {
4796
5111
  const int64_t n_embd = hparams.n_embd;
4797
5112
 
4798
5113
  struct ggml_tensor * inpL;
4799
5114
 
4800
5115
  if (batch.token) {
4801
- struct ggml_tensor * inp_tokens_v = ggml_view_1d(ctx, inp_tokens, batch.n_tokens, 0);
4802
- cb(inp_tokens, "inp_tokens", -1);
5116
+ lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
5117
+ cb(lctx.inp_tokens, "inp_tokens", -1);
5118
+ ggml_set_input(lctx.inp_tokens);
4803
5119
 
4804
- inpL = ggml_get_rows(ctx, tok_embd, inp_tokens_v);
5120
+ inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
4805
5121
  } else {
4806
5122
  #ifdef GGML_USE_MPI
4807
5123
  GGML_ASSERT(false && "not implemented");
4808
5124
  #endif
4809
-
4810
- inpL = ggml_view_2d(ctx, inp_embd, n_embd, batch.n_tokens, inp_embd->nb[1], 0);
5125
+ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
5126
+ inpL = lctx.inp_embd;
5127
+ ggml_set_input(lctx.inp_embd);
4811
5128
  }
4812
5129
 
5130
+ cb(inpL, "inp_embd", -1);
5131
+
4813
5132
  return inpL;
4814
5133
  }
4815
5134
 
@@ -4828,6 +5147,8 @@ static void llm_build_kv_store(
4828
5147
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
4829
5148
  const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
4830
5149
 
5150
+ GGML_ASSERT(kv.size == n_ctx);
5151
+
4831
5152
  // compute the transposed [n_tokens, n_embd] V matrix
4832
5153
  struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
4833
5154
  //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
@@ -5037,6 +5358,8 @@ static struct ggml_tensor * llm_build_kqv(
5037
5358
  cb(kq, "kq_soft_max_ext", il);
5038
5359
  }
5039
5360
 
5361
+ GGML_ASSERT(kv.size == n_ctx);
5362
+
5040
5363
  // split cached v into n_head heads
5041
5364
  struct ggml_tensor * v =
5042
5365
  ggml_view_3d(ctx, kv.v_l[il],
@@ -5109,7 +5432,7 @@ static struct ggml_tensor * llm_build_kv(
5109
5432
 
5110
5433
  struct llm_build_context {
5111
5434
  const llama_model & model;
5112
- const llama_context & lctx;
5435
+ llama_context & lctx;
5113
5436
  const llama_hparams & hparams;
5114
5437
  const llama_cparams & cparams;
5115
5438
  const llama_batch & batch;
@@ -5184,8 +5507,8 @@ struct llm_build_context {
5184
5507
  norm_eps (hparams.f_norm_eps),
5185
5508
  norm_rms_eps (hparams.f_norm_rms_eps),
5186
5509
  n_tokens (batch.n_tokens),
5187
- n_kv (worst_case ? n_ctx : kv_self.n),
5188
- kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
5510
+ n_kv (worst_case ? kv_self.size : kv_self.n),
5511
+ kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
5189
5512
  n_orig_ctx (cparams.n_yarn_orig_ctx),
5190
5513
  pooling_type (cparams.pooling_type),
5191
5514
  rope_type (hparams.rope_type),
@@ -5202,6 +5525,18 @@ struct llm_build_context {
5202
5525
  };
5203
5526
 
5204
5527
  ctx0 = ggml_init(params);
5528
+
5529
+ lctx.inp_tokens = nullptr;
5530
+ lctx.inp_embd = nullptr;
5531
+ lctx.inp_pos = nullptr;
5532
+ lctx.inp_KQ_mask = nullptr;
5533
+ lctx.inp_KQ_pos = nullptr;
5534
+ lctx.inp_K_shift = nullptr;
5535
+ lctx.inp_mean = nullptr;
5536
+ lctx.inp_cls = nullptr;
5537
+ lctx.inp_s_copy = nullptr;
5538
+ lctx.inp_s_mask = nullptr;
5539
+ lctx.inp_s_seq = nullptr;
5205
5540
  }
5206
5541
 
5207
5542
  void free() {
@@ -5214,6 +5549,12 @@ struct llm_build_context {
5214
5549
  struct ggml_cgraph * build_k_shift() {
5215
5550
  struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
5216
5551
 
5552
+ GGML_ASSERT(kv_self.size == n_ctx);
5553
+
5554
+ lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
5555
+ cb(lctx.inp_K_shift, "K_shift", -1);
5556
+ ggml_set_input(lctx.inp_K_shift);
5557
+
5217
5558
  for (int il = 0; il < n_layer; ++il) {
5218
5559
  struct ggml_tensor * tmp =
5219
5560
  // we rotate only the first n_rot dimensions
@@ -5232,6 +5573,29 @@ struct llm_build_context {
5232
5573
  return gf;
5233
5574
  }
5234
5575
 
5576
+ struct ggml_cgraph * build_s_copy() {
5577
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
5578
+
5579
+ GGML_ASSERT(kv_self.recurrent);
5580
+
5581
+ struct ggml_tensor * state_copy = build_inp_s_copy();
5582
+
5583
+ for (int il = 0; il < n_layer; ++il) {
5584
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
5585
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
5586
+
5587
+ conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
5588
+ ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
5589
+
5590
+ // TODO: name the intermediate tensors with cb()
5591
+
5592
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
5593
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
5594
+ }
5595
+
5596
+ return gf;
5597
+ }
5598
+
5235
5599
  struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
5236
5600
  struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
5237
5601
 
@@ -5281,6 +5645,66 @@ struct llm_build_context {
5281
5645
  return gf;
5282
5646
  }
5283
5647
 
5648
+ struct ggml_tensor * build_inp_pos() {
5649
+ lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
5650
+ cb(lctx.inp_pos, "inp_pos", -1);
5651
+ ggml_set_input(lctx.inp_pos);
5652
+ return lctx.inp_pos;
5653
+ }
5654
+
5655
+ struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
5656
+ if (causal) {
5657
+ lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
5658
+ } else {
5659
+ lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
5660
+ }
5661
+ cb(lctx.inp_KQ_mask, "KQ_mask", -1);
5662
+ ggml_set_input(lctx.inp_KQ_mask);
5663
+ return lctx.inp_KQ_mask;
5664
+ }
5665
+
5666
+ struct ggml_tensor * build_inp_KQ_pos() {
5667
+ lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
5668
+ cb(lctx.inp_KQ_pos, "KQ_pos", -1);
5669
+ ggml_set_input(lctx.inp_KQ_pos);
5670
+ return lctx.inp_KQ_pos;
5671
+ }
5672
+
5673
+ struct ggml_tensor * build_inp_mean() {
5674
+ lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
5675
+ cb(lctx.inp_mean, "inp_mean", -1);
5676
+ ggml_set_input(lctx.inp_mean);
5677
+ return lctx.inp_mean;
5678
+ }
5679
+
5680
+ struct ggml_tensor * build_inp_cls() {
5681
+ lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
5682
+ cb(lctx.inp_cls, "inp_cls", -1);
5683
+ ggml_set_input(lctx.inp_cls);
5684
+ return lctx.inp_cls;
5685
+ }
5686
+
5687
+ struct ggml_tensor * build_inp_s_copy() {
5688
+ lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
5689
+ cb(lctx.inp_s_copy, "inp_s_copy", -1);
5690
+ ggml_set_input(lctx.inp_s_copy);
5691
+ return lctx.inp_s_copy;
5692
+ }
5693
+
5694
+ struct ggml_tensor * build_inp_s_mask() {
5695
+ lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
5696
+ cb(lctx.inp_s_mask, "inp_s_mask", -1);
5697
+ ggml_set_input(lctx.inp_s_mask);
5698
+ return lctx.inp_s_mask;
5699
+ }
5700
+
5701
+ struct ggml_tensor * build_inp_s_seq() {
5702
+ lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
5703
+ cb(lctx.inp_s_seq, "inp_s_seq", -1);
5704
+ ggml_set_input(lctx.inp_s_seq);
5705
+ return lctx.inp_s_seq;
5706
+ }
5707
+
5284
5708
  struct ggml_cgraph * build_llama() {
5285
5709
  struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
5286
5710
 
@@ -5291,16 +5715,13 @@ struct llm_build_context {
5291
5715
  struct ggml_tensor * cur;
5292
5716
  struct ggml_tensor * inpL;
5293
5717
 
5294
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
5295
- cb(inpL, "inp_embd", -1);
5718
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
5296
5719
 
5297
5720
  // inp_pos - contains the positions
5298
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
5299
- cb(inp_pos, "inp_pos", -1);
5721
+ struct ggml_tensor * inp_pos = build_inp_pos();
5300
5722
 
5301
5723
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5302
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
5303
- cb(KQ_mask, "KQ_mask", -1);
5724
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
5304
5725
 
5305
5726
  for (int il = 0; il < n_layer; ++il) {
5306
5727
  struct ggml_tensor * inpSA = inpL;
@@ -5352,7 +5773,6 @@ struct llm_build_context {
5352
5773
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5353
5774
  model.layers[il].wo, model.layers[il].bo,
5354
5775
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5355
- cb(cur, "kqv_out", il);
5356
5776
  }
5357
5777
 
5358
5778
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -5470,20 +5890,16 @@ struct llm_build_context {
5470
5890
  struct ggml_tensor * cur;
5471
5891
  struct ggml_tensor * inpL;
5472
5892
 
5473
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
5474
- cb(inpL, "inp_embd", -1);
5893
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
5475
5894
 
5476
5895
  // inp_pos - contains the positions
5477
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
5478
- cb(inp_pos, "inp_pos", -1);
5896
+ struct ggml_tensor * inp_pos = build_inp_pos();
5479
5897
 
5480
5898
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5481
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
5482
- cb(KQ_mask, "KQ_mask", -1);
5899
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
5483
5900
 
5484
5901
  // positions of the tokens in the KV cache
5485
- struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
5486
- cb(KQ_pos, "KQ_pos", -1);
5902
+ struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
5487
5903
 
5488
5904
  for (int il = 0; il < n_layer; ++il) {
5489
5905
  struct ggml_tensor * inpSA = inpL;
@@ -5531,7 +5947,6 @@ struct llm_build_context {
5531
5947
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5532
5948
  model.layers[il].wo, NULL,
5533
5949
  Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5534
- cb(cur, "kqv_out", il);
5535
5950
  }
5536
5951
 
5537
5952
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -5587,16 +6002,13 @@ struct llm_build_context {
5587
6002
  struct ggml_tensor * cur;
5588
6003
  struct ggml_tensor * inpL;
5589
6004
 
5590
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
5591
- cb(inpL, "inp_embd", -1);
6005
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
5592
6006
 
5593
6007
  // inp_pos - contains the positions
5594
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
5595
- cb(inp_pos, "inp_pos", -1);
6008
+ struct ggml_tensor * inp_pos = build_inp_pos();
5596
6009
 
5597
6010
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5598
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
5599
- cb(KQ_mask, "KQ_mask", -1);
6011
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
5600
6012
 
5601
6013
  for (int il = 0; il < n_layer; ++il) {
5602
6014
  struct ggml_tensor * attn_norm;
@@ -5650,7 +6062,6 @@ struct llm_build_context {
5650
6062
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5651
6063
  model.layers[il].wo, NULL,
5652
6064
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5653
- cb(cur, "kqv_out", il);
5654
6065
  }
5655
6066
 
5656
6067
  struct ggml_tensor * ffn_inp = cur;
@@ -5701,21 +6112,17 @@ struct llm_build_context {
5701
6112
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
5702
6113
 
5703
6114
  struct ggml_tensor * cur;
5704
- struct ggml_tensor * pos;
5705
6115
  struct ggml_tensor * inpL;
5706
6116
 
5707
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
5708
- cb(inpL, "inp_embd", -1);
6117
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
5709
6118
 
5710
6119
  // inp_pos - contains the positions
5711
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
5712
- cb(inp_pos, "inp_pos", -1);
6120
+ struct ggml_tensor * inp_pos = build_inp_pos();
5713
6121
 
5714
6122
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5715
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
5716
- cb(KQ_mask, "KQ_mask", -1);
6123
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
5717
6124
 
5718
- pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
6125
+ struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
5719
6126
  cb(pos, "pos_embd", -1);
5720
6127
 
5721
6128
  inpL = ggml_add(ctx0, inpL, pos);
@@ -5749,7 +6156,6 @@ struct llm_build_context {
5749
6156
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5750
6157
  model.layers[il].wo, model.layers[il].bo,
5751
6158
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5752
- cb(cur, "kqv_out", il);
5753
6159
  }
5754
6160
 
5755
6161
  // add the input
@@ -5801,16 +6207,13 @@ struct llm_build_context {
5801
6207
  struct ggml_tensor * cur;
5802
6208
  struct ggml_tensor * inpL;
5803
6209
 
5804
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
5805
- cb(inpL, "inp_embd", -1);
6210
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
5806
6211
 
5807
6212
  // inp_pos - contains the positions
5808
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
5809
- cb(inp_pos, "inp_pos", -1);
6213
+ struct ggml_tensor * inp_pos = build_inp_pos();
5810
6214
 
5811
6215
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5812
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
5813
- cb(KQ_mask, "KQ_mask", -1);
6216
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
5814
6217
 
5815
6218
  for (int il = 0; il < n_layer; ++il) {
5816
6219
  struct ggml_tensor * residual = inpL;
@@ -5950,7 +6353,6 @@ struct llm_build_context {
5950
6353
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5951
6354
  model.layers[il].wo, model.layers[il].bo,
5952
6355
  Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5953
- cb(cur, "kqv_out", il);
5954
6356
  }
5955
6357
 
5956
6358
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
@@ -6004,16 +6406,13 @@ struct llm_build_context {
6004
6406
  struct ggml_tensor * cur;
6005
6407
  struct ggml_tensor * inpL;
6006
6408
 
6007
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6008
- cb(inpL, "inp_embd", -1);
6409
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6009
6410
 
6010
6411
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6011
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6012
- cb(KQ_mask, "KQ_mask", -1);
6412
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6013
6413
 
6014
6414
  // positions of the tokens in the KV cache
6015
- struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
6016
- cb(KQ_pos, "KQ_pos", -1);
6415
+ struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
6017
6416
 
6018
6417
  for (int il = 0; il < n_layer; ++il) {
6019
6418
  struct ggml_tensor * inpSA = inpL;
@@ -6043,7 +6442,6 @@ struct llm_build_context {
6043
6442
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6044
6443
  model.layers[il].wo, NULL,
6045
6444
  Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6046
- cb(cur, "kqv_out", il);
6047
6445
  }
6048
6446
 
6049
6447
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6099,15 +6497,12 @@ struct llm_build_context {
6099
6497
  struct ggml_tensor * cur;
6100
6498
  struct ggml_tensor * inpL;
6101
6499
 
6102
- // get input vectors with right size
6103
- const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
6104
-
6105
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6106
- struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
6107
- struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
6500
+ struct ggml_tensor * inp_pos = build_inp_pos();
6501
+ struct ggml_tensor * inp_mean = build_inp_mean();
6502
+ struct ggml_tensor * inp_cls = build_inp_cls();
6108
6503
 
6109
6504
  // construct input embeddings (token, type, position)
6110
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6505
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6111
6506
 
6112
6507
  // token types are hardcoded to zero ("Sentence A")
6113
6508
  struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
@@ -6122,8 +6517,7 @@ struct llm_build_context {
6122
6517
  cb(inpL, "inp_norm", -1);
6123
6518
 
6124
6519
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6125
- struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
6126
- cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]
6520
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
6127
6521
 
6128
6522
  // iterate layers
6129
6523
  for (int il = 0; il < n_layer; ++il) {
@@ -6285,16 +6679,13 @@ struct llm_build_context {
6285
6679
  struct ggml_tensor * cur;
6286
6680
  struct ggml_tensor * inpL;
6287
6681
 
6288
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6289
- cb(inpL, "inp_embd", -1);
6682
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6290
6683
 
6291
6684
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6292
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6293
- cb(KQ_mask, "KQ_mask", -1);
6685
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6294
6686
 
6295
6687
  // positions of the tokens in the KV cache
6296
- struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
6297
- cb(KQ_pos, "KQ_pos", -1);
6688
+ struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
6298
6689
 
6299
6690
  inpL = llm_build_norm(ctx0, inpL, hparams,
6300
6691
  model.tok_norm,
@@ -6330,7 +6721,6 @@ struct llm_build_context {
6330
6721
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6331
6722
  model.layers[il].wo, model.layers[il].bo,
6332
6723
  Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6333
- cb(cur, "kqv_out", il);
6334
6724
  }
6335
6725
 
6336
6726
  // Add the input
@@ -6382,16 +6772,13 @@ struct llm_build_context {
6382
6772
  struct ggml_tensor * cur;
6383
6773
  struct ggml_tensor * inpL;
6384
6774
 
6385
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6386
- cb(inpL, "inp_embd", -1);
6775
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6387
6776
 
6388
6777
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6389
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6390
- cb(KQ_mask, "KQ_mask", -1);
6778
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6391
6779
 
6392
6780
  // positions of the tokens in the KV cache
6393
- struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
6394
- cb(KQ_pos, "KQ_pos", -1);
6781
+ struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
6395
6782
 
6396
6783
  for (int il = 0; il < n_layer; ++il) {
6397
6784
  struct ggml_tensor * attn_norm;
@@ -6432,7 +6819,6 @@ struct llm_build_context {
6432
6819
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6433
6820
  model.layers[il].wo, model.layers[il].bo,
6434
6821
  Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6435
- cb(cur, "kqv_out", il);
6436
6822
  }
6437
6823
 
6438
6824
  // Add the input
@@ -6487,16 +6873,13 @@ struct llm_build_context {
6487
6873
  struct ggml_tensor * cur;
6488
6874
  struct ggml_tensor * inpL;
6489
6875
 
6490
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6491
- cb(inpL, "inp_embd", -1);
6876
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6492
6877
 
6493
6878
  // inp_pos - contains the positions
6494
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6495
- cb(inp_pos, "inp_pos", -1);
6879
+ struct ggml_tensor * inp_pos = build_inp_pos();
6496
6880
 
6497
6881
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6498
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6499
- cb(KQ_mask, "KQ_mask", -1);
6882
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6500
6883
 
6501
6884
  for (int il = 0; il < n_layer; ++il) {
6502
6885
  struct ggml_tensor * inpSA = inpL;
@@ -6549,7 +6932,6 @@ struct llm_build_context {
6549
6932
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6550
6933
  model.layers[il].wo, NULL,
6551
6934
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6552
- cb(cur, "kqv_out", il);
6553
6935
  }
6554
6936
 
6555
6937
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6605,16 +6987,13 @@ struct llm_build_context {
6605
6987
  struct ggml_tensor * cur;
6606
6988
  struct ggml_tensor * inpL;
6607
6989
 
6608
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6609
- cb(inpL, "inp_embd", -1);
6990
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6610
6991
 
6611
6992
  // inp_pos - contains the positions
6612
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6613
- cb(inp_pos, "inp_pos", -1);
6993
+ struct ggml_tensor * inp_pos = build_inp_pos();
6614
6994
 
6615
6995
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6616
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6617
- cb(KQ_mask, "KQ_mask", -1);
6996
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6618
6997
 
6619
6998
  for (int il = 0; il < n_layer; ++il) {
6620
6999
  struct ggml_tensor * inpSA = inpL;
@@ -6659,7 +7038,6 @@ struct llm_build_context {
6659
7038
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6660
7039
  model.layers[il].wo, NULL,
6661
7040
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6662
- cb(cur, "kqv_out", il);
6663
7041
  }
6664
7042
 
6665
7043
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6714,16 +7092,13 @@ struct llm_build_context {
6714
7092
  struct ggml_tensor * cur;
6715
7093
  struct ggml_tensor * inpL;
6716
7094
 
6717
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6718
- cb(inpL, "inp_embd", -1);
7095
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6719
7096
 
6720
7097
  // inp_pos - contains the positions
6721
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6722
- cb(inp_pos, "inp_pos", -1);
7098
+ struct ggml_tensor * inp_pos = build_inp_pos();
6723
7099
 
6724
7100
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6725
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6726
- cb(KQ_mask, "KQ_mask", -1);
7101
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6727
7102
 
6728
7103
  for (int il = 0; il < n_layer; ++il) {
6729
7104
  struct ggml_tensor * inpSA = inpL;
@@ -6775,7 +7150,6 @@ struct llm_build_context {
6775
7150
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6776
7151
  model.layers[il].wo, model.layers[il].bo,
6777
7152
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6778
- cb(cur, "kqv_out", il);
6779
7153
  }
6780
7154
 
6781
7155
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6830,16 +7204,13 @@ struct llm_build_context {
6830
7204
  struct ggml_tensor * ffn_output;
6831
7205
  struct ggml_tensor * inpL;
6832
7206
 
6833
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6834
- cb(inpL, "inp_embd", -1);
7207
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6835
7208
 
6836
7209
  // inp_pos - contains the positions
6837
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6838
- cb(inp_pos, "inp_pos", -1);
7210
+ struct ggml_tensor * inp_pos = build_inp_pos();
6839
7211
 
6840
7212
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6841
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6842
- cb(KQ_mask, "KQ_mask", -1);
7213
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6843
7214
 
6844
7215
  for (int il = 0; il < n_layer; ++il) {
6845
7216
  attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
@@ -6897,7 +7268,6 @@ struct llm_build_context {
6897
7268
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6898
7269
  model.layers[il].wo, model.layers[il].bo,
6899
7270
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
6900
- cb(cur, "kqv_out", il);
6901
7271
  }
6902
7272
 
6903
7273
  // FF
@@ -6947,16 +7317,13 @@ struct llm_build_context {
6947
7317
  struct ggml_tensor * cur;
6948
7318
  struct ggml_tensor * inpL;
6949
7319
 
6950
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
6951
- cb(inpL, "inp_embd", -1);
7320
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
6952
7321
 
6953
7322
  // inp_pos - contains the positions
6954
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
6955
- cb(inp_pos, "inp_pos", -1);
7323
+ struct ggml_tensor * inp_pos = build_inp_pos();
6956
7324
 
6957
7325
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6958
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
6959
- cb(KQ_mask, "KQ_mask", -1);
7326
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
6960
7327
 
6961
7328
  for (int il = 0; il < n_layer; ++il) {
6962
7329
 
@@ -6995,7 +7362,6 @@ struct llm_build_context {
6995
7362
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6996
7363
  model.layers[il].wo, NULL,
6997
7364
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6998
- cb(cur, "kqv_out", il);
6999
7365
  }
7000
7366
  struct ggml_tensor * sa_out = cur;
7001
7367
 
@@ -7049,16 +7415,13 @@ struct llm_build_context {
7049
7415
  struct ggml_tensor * pos;
7050
7416
  struct ggml_tensor * inpL;
7051
7417
 
7052
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7053
- cb(inpL, "inp_embd", -1);
7418
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7054
7419
 
7055
7420
  // inp_pos - contains the positions
7056
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7057
- cb(inp_pos, "inp_pos", -1);
7421
+ struct ggml_tensor * inp_pos = build_inp_pos();
7058
7422
 
7059
7423
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7060
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7061
- cb(KQ_mask, "KQ_mask", -1);
7424
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7062
7425
 
7063
7426
  pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
7064
7427
  cb(pos, "pos_embd", -1);
@@ -7094,7 +7457,6 @@ struct llm_build_context {
7094
7457
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7095
7458
  model.layers[il].wo, model.layers[il].bo,
7096
7459
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
7097
- cb(cur, "kqv_out", il);
7098
7460
  }
7099
7461
 
7100
7462
  // add the input
@@ -7147,16 +7509,13 @@ struct llm_build_context {
7147
7509
  struct ggml_tensor * cur;
7148
7510
  struct ggml_tensor * inpL;
7149
7511
 
7150
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7151
- cb(inpL, "inp_embd", -1);
7512
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7152
7513
 
7153
7514
  // inp_pos - contains the positions
7154
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7155
- cb(inp_pos, "inp_pos", -1);
7515
+ struct ggml_tensor * inp_pos = build_inp_pos();
7156
7516
 
7157
7517
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7158
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7159
- cb(KQ_mask, "KQ_mask", -1);
7518
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7160
7519
 
7161
7520
  for (int il = 0; il < n_layer; ++il) {
7162
7521
  cur = llm_build_norm(ctx0, inpL, hparams,
@@ -7198,7 +7557,6 @@ struct llm_build_context {
7198
7557
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7199
7558
  model.layers[il].wo, model.layers[il].bo,
7200
7559
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
7201
- cb(cur, "kqv_out", il);
7202
7560
  }
7203
7561
 
7204
7562
  // add the input
@@ -7250,16 +7608,13 @@ struct llm_build_context {
7250
7608
  struct ggml_tensor * cur;
7251
7609
  struct ggml_tensor * inpL;
7252
7610
 
7253
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7254
- cb(inpL, "inp_embd", -1);
7611
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7255
7612
 
7256
7613
  // inp_pos - contains the positions
7257
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7258
- cb(inp_pos, "inp_pos", -1);
7614
+ struct ggml_tensor * inp_pos = build_inp_pos();
7259
7615
 
7260
7616
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7261
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7262
- cb(KQ_mask, "KQ_mask", -1);
7617
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7263
7618
 
7264
7619
  for (int il = 0; il < n_layer; ++il) {
7265
7620
  struct ggml_tensor * inpSA = inpL;
@@ -7311,7 +7666,6 @@ struct llm_build_context {
7311
7666
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7312
7667
  model.layers[il].wo, NULL,
7313
7668
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
7314
- cb(cur, "kqv_out", il);
7315
7669
  }
7316
7670
 
7317
7671
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7364,16 +7718,13 @@ struct llm_build_context {
7364
7718
  struct ggml_tensor * cur;
7365
7719
  struct ggml_tensor * inpL;
7366
7720
 
7367
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7368
- cb(inpL, "inp_embd", -1);
7721
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7369
7722
 
7370
7723
  // inp_pos - contains the positions
7371
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7372
- cb(inp_pos, "inp_pos", -1);
7724
+ struct ggml_tensor * inp_pos = build_inp_pos();
7373
7725
 
7374
7726
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7375
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7376
- cb(KQ_mask, "KQ_mask", -1);
7727
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7377
7728
 
7378
7729
  for (int il = 0; il < n_layer; ++il) {
7379
7730
  struct ggml_tensor * inpSA = inpL;
@@ -7425,7 +7776,6 @@ struct llm_build_context {
7425
7776
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7426
7777
  model.layers[il].wo, model.layers[il].bo,
7427
7778
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
7428
- cb(cur, "kqv_out", il);
7429
7779
  }
7430
7780
 
7431
7781
  struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7487,20 +7837,17 @@ struct llm_build_context {
7487
7837
  struct ggml_tensor * cur;
7488
7838
  struct ggml_tensor * inpL;
7489
7839
 
7490
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7491
- cb(inpL, "inp_embd", -1);
7840
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7492
7841
 
7493
7842
  // scale the input embeddings
7494
7843
  inpL = ggml_scale(ctx0, inpL, scale_embd);
7495
7844
  cb(inpL, "inp_scaled", -1);
7496
7845
 
7497
7846
  // inp_pos - contains the positions
7498
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7499
- cb(inp_pos, "inp_pos", -1);
7847
+ struct ggml_tensor * inp_pos = build_inp_pos();
7500
7848
 
7501
7849
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7502
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7503
- cb(KQ_mask, "KQ_mask", -1);
7850
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7504
7851
 
7505
7852
  for (int il = 0; il < n_layer; ++il) {
7506
7853
  struct ggml_tensor * inpSA = inpL;
@@ -7552,7 +7899,6 @@ struct llm_build_context {
7552
7899
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7553
7900
  model.layers[il].wo, model.layers[il].bo,
7554
7901
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
7555
- cb(cur, "kqv_out", il);
7556
7902
  }
7557
7903
 
7558
7904
  // scale_res - scale the hidden states for residual connection
@@ -7619,22 +7965,18 @@ struct llm_build_context {
7619
7965
  struct ggml_tensor * cur;
7620
7966
  struct ggml_tensor * inpL;
7621
7967
 
7622
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7623
- cb(inpL, "inp_embd", -1);
7968
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7624
7969
 
7625
7970
  inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
7626
7971
  cb(inpL, "inp_scaled", -1);
7627
7972
 
7628
7973
  // inp_pos - contains the positions
7629
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7630
- cb(inp_pos, "inp_pos", -1);
7974
+ struct ggml_tensor * inp_pos = build_inp_pos();
7631
7975
 
7632
7976
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7633
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7634
- cb(KQ_mask, "KQ_mask", -1);
7977
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7635
7978
 
7636
7979
  for (int il = 0; il < n_layer; ++il) {
7637
-
7638
7980
  // norm
7639
7981
  cur = llm_build_norm(ctx0, inpL, hparams,
7640
7982
  model.layers[il].attn_norm, NULL,
@@ -7671,7 +8013,6 @@ struct llm_build_context {
7671
8013
  cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
7672
8014
  model.layers[il].wo, NULL,
7673
8015
  Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
7674
- cb(cur, "kqv_out", il);
7675
8016
  }
7676
8017
 
7677
8018
  struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
@@ -7726,16 +8067,13 @@ struct llm_build_context {
7726
8067
  struct ggml_tensor * cur;
7727
8068
  struct ggml_tensor * inpL;
7728
8069
 
7729
- inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
7730
- cb(inpL, "inp_embd", -1);
8070
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
7731
8071
 
7732
8072
  // inp_pos - contains the positions
7733
- struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
7734
- cb(inp_pos, "inp_pos", -1);
8073
+ struct ggml_tensor * inp_pos = build_inp_pos();
7735
8074
 
7736
8075
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7737
- struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
7738
- cb(KQ_mask, "KQ_mask", -1);
8076
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7739
8077
 
7740
8078
  for (int il = 0; il < n_layer; ++il) {
7741
8079
  struct ggml_tensor * inpSA = inpL;
@@ -7829,11 +8167,149 @@ struct llm_build_context {
7829
8167
 
7830
8168
  return gf;
7831
8169
  }
7832
- };
7833
8170
 
7834
- static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
7835
- llama_batch dummy;
7836
- dummy.n_tokens = 0;
8171
+ struct ggml_cgraph * build_mamba() {
8172
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
8173
+
8174
+ const int64_t d_model = n_embd;
8175
+ const int64_t d_conv = hparams.ssm_d_conv;
8176
+ const int64_t d_inner = hparams.ssm_d_inner;
8177
+ GGML_ASSERT(2 * d_model == d_inner);
8178
+ const int64_t d_state = hparams.ssm_d_state;
8179
+ const int64_t dt_rank = hparams.ssm_dt_rank;
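
For scale (reference Mamba defaults, not values taken from this diff): d_conv = 4, d_state = 16, expand = 2 and dt_rank = ceil(d_model/16), so a mamba-2.8b checkpoint with d_model = 2560 would give d_inner = 5120 and dt_rank = 160, consistent with the 2*d_model == d_inner assertion above.
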
8180
+
8181
+ struct ggml_tensor * cur;
8182
+ struct ggml_tensor * inpL;
8183
+
8184
+ // {n_embd, n_tokens}
8185
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
8186
+
8187
+ struct ggml_tensor * state_mask = build_inp_s_mask();
8188
+ struct ggml_tensor * state_seq = build_inp_s_seq();
8189
+
8190
+ for (int il = 0; il < n_layer; ++il) {
8191
+ // (ab)using the KV cache to store the states
8192
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
8193
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
8194
+
8195
+ // clear states of sequences which are starting at the beginning of this batch
8196
+ {
8197
+ conv_states = ggml_mul(ctx0,
8198
+ ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
8199
+ state_mask);
8200
+ ssm_states = ggml_mul(ctx0,
8201
+ ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
8202
+ state_mask);
8203
+ }
8204
+
8205
+ conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
8206
+ ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
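
The two reshapes above pin down how much state each cache cell carries per layer: (d_conv - 1)*d_inner floats of rolling convolution input in k_l[il] and d_state*d_inner floats of SSM state in v_l[il], which is presumably what hparams.n_embd_k_s() and hparams.n_embd_v_s() expand to (their definitions are not in this section). As a sketch of that bookkeeping only:

// sizes implied by the reshapes above (sketch; the n_embd_*_s() helpers are assumed to match)
const int64_t conv_state_per_cell = (d_conv - 1) * d_inner; // rolling window of past inputs, per channel
const int64_t ssm_state_per_cell  =  d_state     * d_inner; // one d_state-wide hidden state, per channel
// per layer: k_l[il] holds kv_self.size cells of conv_state_per_cell floats,
//            v_l[il] holds kv_self.size cells of ssm_state_per_cell floats
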
8207
+
8208
+ // norm
8209
+ cur = llm_build_norm(ctx0, inpL, hparams,
8210
+ model.layers[il].attn_norm, NULL,
8211
+ LLM_NORM_RMS, cb, il);
8212
+ cb(cur, "attn_norm", il);
8213
+
8214
+ // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
8215
+ struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
8216
+ // split the above in two
8217
+ // => {d_inner, n_tokens}
8218
+ struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
8219
+ struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
8220
+
8221
+ // conv
8222
+ {
8223
+ // Custom operator which is needed only to ease simultaneous sequence processing.
8224
+ // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
8225
+ // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
8226
+ // then element-wise multiply that with the conv1d weight,
8227
+ // then sum the elements of each row,
8228
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
8229
+ // then permute away the ne[0] dimension,
8230
+ // and then you're left with the resulting x tensor.
8231
+ // The new conv_states is the last (d_conv - 1) columns
8232
+ // of the last 3rd dimensional "layer" of the self-overlapping view.
8233
+ // For simultaneous sequences, it's more complicated.
8234
+ struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
8235
+
8236
+ // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
8237
+ ggml_build_forward_expand(gf,
8238
+ ggml_cpy(ctx0,
8239
+ ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
8240
+ ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
8241
+
8242
+ // extract x from x_conv
8243
+ x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
8244
+
8245
+ // bias
8246
+ x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
8247
+
8248
+ x = ggml_silu(ctx0, x);
8249
+ }
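
A naive single-sequence reference for the conv step the comment block above describes (illustrative sketch only, not the actual ggml_ssm_conv kernel; the function name and flat row-major layouts are assumptions, d_conv >= 2, requires <cstdint>, and the bias and SiLU are applied afterwards exactly as in the code above):

// sketch: causal depthwise conv1d over one sequence, with a rolling (d_conv - 1)-wide state
static void ssm_conv_reference(
        float       * conv_state, // [d_inner][d_conv - 1], last inputs per channel, oldest first
        const float * x,          // [d_inner][n_tokens]
        const float * w,          // [d_inner][d_conv], depthwise conv1d weight
        float       * out,        // [d_inner][n_tokens]
        int64_t d_inner, int64_t d_conv, int64_t n_tokens) {
    for (int64_t c = 0; c < d_inner; ++c) {
        for (int64_t t = 0; t < n_tokens; ++t) {
            // the causal window is the (d_conv - 1) remembered inputs followed by x[c][t]
            float acc = 0.0f;
            for (int64_t k = 0; k < d_conv - 1; ++k) {
                acc += conv_state[c*(d_conv - 1) + k] * w[c*d_conv + k];
            }
            acc += x[c*n_tokens + t] * w[c*d_conv + (d_conv - 1)];
            out[c*n_tokens + t] = acc;
            // roll the window: drop the oldest input, append the current one
            for (int64_t k = 0; k + 1 < d_conv - 1; ++k) {
                conv_state[c*(d_conv - 1) + k] = conv_state[c*(d_conv - 1) + k + 1];
            }
            conv_state[c*(d_conv - 1) + (d_conv - 2)] = x[c*n_tokens + t];
        }
    }
}
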
8250
+
8251
+ // ssm
8252
+ {
8253
+ // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
8254
+ struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
8255
+ // split
8256
+ struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
8257
+ struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
8258
+ struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
8259
+
8260
+ // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
8261
+ dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
8262
+ dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
8263
+
8264
+ // Custom operator to optimize the parallel associative scan
8265
+ // as described in Annex D of the Mamba paper.
8266
+ // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
8267
+ // because only a single tensor can be returned.
8268
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
8269
+
8270
+ // store last states (the second part of y_ssm_states)
8271
+ ggml_build_forward_expand(gf,
8272
+ ggml_cpy(ctx0,
8273
+ ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
8274
+ ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
8275
+
8276
+ struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
8277
+
8278
+ // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
8279
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
8280
+ y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
8281
+
8282
+ // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
8283
+ cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
8284
+ }
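
For intuition, the recurrence that ggml_ssm_scan evaluates (per the Mamba paper) can be written as a naive per-token loop. This is only an illustrative reference under the shapes used above (requires <cmath> and <cstdint>; the names and flat layouts are made up, and whether the softplus/discretization of dt happens inside the custom op or before it is not visible in this hunk):

// sketch: sequential selective scan for one sequence; dt is assumed already discretized
static void ssm_scan_reference(
        float       * h,   // [d_inner][d_state], recurrent state, updated in place
        const float * x,   // [d_inner][n_tokens]
        const float * dt,  // [d_inner][n_tokens]
        const float * A,   // [d_inner][d_state]
        const float * B,   // [d_state][n_tokens]
        const float * C,   // [d_state][n_tokens]
        float       * y,   // [d_inner][n_tokens]
        int64_t d_inner, int64_t d_state, int64_t n_tokens) {
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t c = 0; c < d_inner; ++c) {
            const float x_ct  = x [c*n_tokens + t];
            const float dt_ct = dt[c*n_tokens + t];
            float acc = 0.0f;
            for (int64_t s = 0; s < d_state; ++s) {
                // discretized state update: h = exp(dt*A)*h + dt*B*x
                float & h_cs = h[c*d_state + s];
                h_cs = expf(dt_ct * A[c*d_state + s]) * h_cs + dt_ct * B[s*n_tokens + t] * x_ct;
                // output projection: y = C . h
                acc += C[s*n_tokens + t] * h_cs;
            }
            // the D skip connection and the silu(z) gating are applied outside, as in the code above
            y[c*n_tokens + t] = acc;
        }
    }
}
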
8285
+
8286
+ // residual
8287
+ cur = ggml_add(ctx0, cur, inpL);
8288
+ cb(cur, "l_out", il);
8289
+
8290
+ // input for next layer
8291
+ inpL = cur;
8292
+ }
8293
+
8294
+ // final rmsnorm
8295
+ cur = llm_build_norm(ctx0, inpL, hparams,
8296
+ model.output_norm, NULL,
8297
+ LLM_NORM_RMS, cb, -1);
8298
+ cb(cur, "result_norm", -1);
8299
+
8300
+ // lm_head
8301
+ cur = ggml_mul_mat(ctx0, model.output, cur);
8302
+ cb(cur, "result_output", -1);
8303
+
8304
+ ggml_build_forward_expand(gf, cur);
8305
+
8306
+ return gf;
8307
+ }
8308
+ };
8309
+
8310
+ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
8311
+ llama_batch dummy;
8312
+ dummy.n_tokens = 0;
7837
8313
 
7838
8314
  llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
7839
8315
 
@@ -7865,6 +8341,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
7865
8341
  return result;
7866
8342
  }
7867
8343
 
8344
+ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
8345
+ llama_batch dummy;
8346
+ dummy.n_tokens = 0;
8347
+
8348
+ llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
8349
+
8350
+ struct llm_build_context llm(lctx, dummy, cb, false);
8351
+
8352
+ llm.init();
8353
+
8354
+ struct ggml_cgraph * result = llm.build_s_copy();
8355
+
8356
+ llm.free();
8357
+
8358
+ return result;
8359
+ }
8360
+
7868
8361
  static struct ggml_cgraph * llama_build_graph(
7869
8362
  llama_context & lctx,
7870
8363
  const llama_batch & batch,
@@ -7882,7 +8375,18 @@ static struct ggml_cgraph * llama_build_graph(
7882
8375
  if (!lctx.cparams.offload_kqv) {
7883
8376
  if (strcmp(name, "kqv_merged_cont") == 0) {
7884
8377
  // all nodes between the KV store and the attention output are run on the CPU
7885
- ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
8378
+ ggml_backend_sched_set_tensor_backend(lctx.sched, cur, lctx.backend_cpu);
8379
+ }
8380
+ }
8381
+
8382
+ // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
8383
+ // to fix this, we assign the norm layer manually to the backend of its layer
8384
+ if (il != -1 && strcmp(name, "norm") == 0) {
8385
+ for (auto * backend : lctx.backends) {
8386
+ if (ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
8387
+ ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
8388
+ break;
8389
+ }
7886
8390
  }
7887
8391
  }
7888
8392
  };
@@ -7979,6 +8483,10 @@ static struct ggml_cgraph * llama_build_graph(
7979
8483
  {
7980
8484
  result = llm.build_starcoder2();
7981
8485
  } break;
8486
+ case LLM_ARCH_MAMBA:
8487
+ {
8488
+ result = llm.build_mamba();
8489
+ } break;
7982
8490
  default:
7983
8491
  GGML_ASSERT(false);
7984
8492
  }
@@ -7989,19 +8497,29 @@ static struct ggml_cgraph * llama_build_graph(
7989
8497
  }
7990
8498
 
7991
8499
  static void llama_set_k_shift(llama_context & lctx) {
7992
- const auto & cparams = lctx.cparams;
7993
-
7994
- const int64_t n_ctx = cparams.n_ctx;
8500
+ const int64_t kv_size = lctx.kv_self.size;
7995
8501
 
7996
8502
  assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
7997
8503
 
7998
8504
  int32_t * data = (int32_t *) lctx.inp_K_shift->data;
7999
8505
 
8000
- for (int i = 0; i < n_ctx; ++i) {
8506
+ for (int i = 0; i < kv_size; ++i) {
8001
8507
  data[i] = lctx.kv_self.cells[i].delta;
8002
8508
  }
8003
8509
  }
8004
8510
 
8511
+ static void llama_set_s_copy(llama_context & lctx) {
8512
+ const int64_t kv_size = lctx.kv_self.size;
8513
+
8514
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
8515
+
8516
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
8517
+
8518
+ for (int i = 0; i < kv_size; ++i) {
8519
+ data[i] = lctx.kv_self.cells[i].src;
8520
+ }
8521
+ }
8522
+
8005
8523
  static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8006
8524
  //
8007
8525
  // set input data
@@ -8024,58 +8542,74 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8024
8542
  ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
8025
8543
  }
8026
8544
 
8027
- if (batch.pos) {
8545
+ if (batch.pos && lctx.inp_pos) {
8028
8546
  const int64_t n_tokens = batch.n_tokens;
8029
8547
 
8030
8548
  ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
8031
8549
  }
8032
8550
 
8033
- if (hparams.causal_attn) {
8034
- const int64_t n_kv = kv_self.n;
8035
- const int64_t n_tokens = batch.n_tokens;
8551
+ GGML_ASSERT(
8552
+ (hparams.causal_attn || !cparams.causal_attn) &&
8553
+ "non-causal attention with generative models is not supported"
8554
+ );
8036
8555
 
8037
- assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
8556
+ if (lctx.inp_KQ_mask) {
8557
+ // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
8558
+ if (cparams.causal_attn) {
8559
+ const int64_t n_kv = kv_self.n;
8560
+ const int64_t n_tokens = batch.n_tokens;
8038
8561
 
8039
- float * data = (float *) lctx.inp_KQ_mask->data;
8562
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
8040
8563
 
8041
- for (int h = 0; h < 1; ++h) {
8042
- for (int j = 0; j < n_tokens; ++j) {
8043
- const llama_pos pos = batch.pos[j];
8044
- const llama_seq_id seq_id = batch.seq_id[j][0];
8564
+ float * data = (float *) lctx.inp_KQ_mask->data;
8045
8565
 
8046
- for (int i = 0; i < n_kv; ++i) {
8047
- float f;
8048
- if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
8049
- f = -INFINITY;
8050
- } else {
8051
- f = 0.0f;
8566
+ // For causal attention, use only the previous KV cells
8567
+ // of the correct sequence for each token of the batch.
8568
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
8569
+ for (int h = 0; h < 1; ++h) {
8570
+ for (int j = 0; j < n_tokens; ++j) {
8571
+ const llama_pos pos = batch.pos[j];
8572
+ const llama_seq_id seq_id = batch.seq_id[j][0];
8573
+
8574
+ for (int i = 0; i < n_kv; ++i) {
8575
+ float f;
8576
+ if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
8577
+ f = -INFINITY;
8578
+ } else {
8579
+ f = 0.0f;
8580
+ }
8581
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
8052
8582
  }
8053
- data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
8054
8583
  }
8055
8584
  }
8056
- }
8057
- } else {
8058
- // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
8059
- const int64_t n_tokens = batch.n_tokens;
8585
+ } else {
8586
+ // when using kv cache, the mask needs to match the kv cache size
8587
+ const int64_t n_tokens = batch.n_tokens;
8588
+ const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
8060
8589
 
8061
- assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
8590
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
8062
8591
 
8063
- float * data = (float *) lctx.inp_KQ_mask->data;
8592
+ float * data = (float *) lctx.inp_KQ_mask->data;
8064
8593
 
8065
- for (int h = 0; h < 1; ++h) {
8066
- for (int j = 0; j < n_tokens; ++j) {
8067
- const llama_seq_id seq_id = batch.seq_id[j][0];
8594
+ for (int h = 0; h < 1; ++h) {
8595
+ for (int j = 0; j < n_tokens; ++j) {
8596
+ const llama_seq_id seq_id = batch.seq_id[j][0];
8068
8597
 
8069
- for (int i = 0; i < n_tokens; ++i) {
8070
- float f = -INFINITY;
8071
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
8072
- if (batch.seq_id[i][s] == seq_id) {
8073
- f = 0.0f;
8074
- break;
8598
+ for (int i = 0; i < n_tokens; ++i) {
8599
+ float f = -INFINITY;
8600
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
8601
+ if (batch.seq_id[i][s] == seq_id) {
8602
+ f = 0.0f;
8603
+ break;
8604
+ }
8075
8605
  }
8606
+
8607
+ data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
8076
8608
  }
8077
8609
 
8078
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
8610
+ for (int i = n_tokens; i < n_stride; ++i) {
8611
+ data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
8612
+ }
8079
8613
  }
8080
8614
  }
8081
8615
  }
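
A concrete illustration of the two branches above (made-up numbers): in the causal branch, with n_kv = 4, cells 0..3 holding sequence 0 at positions 0..3, and a batch of two sequence-0 tokens at positions 2 and 3, the mask row for the token at position 2 is {0, 0, 0, -INFINITY} and the row for position 3 is {0, 0, 0, 0}. In the other branch, rows are laid out with stride n_stride and the columns from n_tokens up to n_stride are padded with -INFINITY, so cells beyond the batch can never be attended.
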
@@ -8084,7 +8618,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8084
8618
  if (hparams.need_kq_pos) {
8085
8619
  const int64_t n_kv = kv_self.n;
8086
8620
 
8087
- assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
8621
+ GGML_ASSERT(lctx.inp_KQ_pos);
8622
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
8088
8623
 
8089
8624
  float * data = (float *) lctx.inp_KQ_pos->data;
8090
8625
 
@@ -8096,6 +8631,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8096
8631
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
8097
8632
  const int64_t n_tokens = batch.n_tokens;
8098
8633
 
8634
+ GGML_ASSERT(lctx.inp_mean);
8099
8635
  GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
8100
8636
 
8101
8637
  float * data = (float *) lctx.inp_mean->data;
@@ -8127,6 +8663,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8127
8663
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
8128
8664
  const int64_t n_tokens = batch.n_tokens;
8129
8665
 
8666
+ GGML_ASSERT(lctx.inp_cls);
8130
8667
  GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
8131
8668
 
8132
8669
  uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@@ -8143,6 +8680,53 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
8143
8680
  }
8144
8681
  }
8145
8682
  }
8683
+
8684
+ if (kv_self.recurrent) {
8685
+ const int64_t n_kv = kv_self.n;
8686
+
8687
+ if (lctx.inp_s_mask) {
8688
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
8689
+ float * data = (float *) lctx.inp_s_mask->data;
8690
+
8691
+ // states which are not affected by the current batch are left untouched
8692
+ for (int i = 0; i < n_kv; ++i) {
8693
+ llama_seq_id seq_id = i + lctx.kv_self.head;
8694
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
8695
+ bool has_self_seq = kv_cell.has_seq_id(seq_id);
8696
+
8697
+ data[i] = (float) has_self_seq;
8698
+
8699
+ // ensure current sequences will be kept
8700
+ if (!has_self_seq && kv_cell.pos >= 0) {
8701
+ kv_cell.seq_id.insert(seq_id);
8702
+ }
8703
+ }
8704
+ }
8705
+ // For Mamba (and other recurrent architectures),
8706
+ // update the correct state(s)/sequence(s) for each token of the batch.
8707
+ // Like with the KQ_mask, if a token in the batch has multiple sequences,
8708
+ // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
8709
+ if (lctx.inp_s_seq) {
8710
+ const int64_t n_tokens = batch.n_tokens;
8711
+
8712
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
8713
+ int32_t * data = (int32_t *) lctx.inp_s_seq->data;
8714
+
8715
+ for (int j = 0; j < n_tokens; ++j) {
8716
+ const int32_t n_seq = batch.n_seq_id[j];
8717
+ GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
8718
+
8719
+ for (int i = 0; i < n_kv; ++i) {
8720
+ if (i < n_seq) {
8721
+ // for this type of model, the head is the minimum seq_id of the batch
8722
+ data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
8723
+ } else {
8724
+ data[j*n_kv + i] = -1;
8725
+ }
8726
+ }
8727
+ }
8728
+ }
8729
+ }
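
A small worked example of the two inputs above (made-up numbers, assuming each sequence's state lives in the cell with the same index, as the comment about the head suggests): with kv_self.head = 5, n_kv = 2 and a batch of two tokens, the first continuing sequence 5 and the second starting a brand-new sequence 6, inp_s_mask comes out as {1, 0} — so the stale state in cell 6 is zeroed by the ggml_mul in build_mamba — and inp_s_seq as {0, -1, 1, -1}, mapping each token to its state slot relative to the head.
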
8146
8730
  }
8147
8731
 
8148
8732
  static void llama_graph_compute(
@@ -8165,7 +8749,7 @@ static void llama_graph_compute(
8165
8749
  ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
8166
8750
  }
8167
8751
 
8168
- ggml_backend_sched_graph_compute(lctx.sched, gf);
8752
+ ggml_backend_sched_graph_compute_async(lctx.sched, gf);
8169
8753
 
8170
8754
  // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
8171
8755
 
@@ -8185,10 +8769,11 @@ static void llama_graph_compute(
8185
8769
  //
8186
8770
  static int llama_decode_internal(
8187
8771
  llama_context & lctx,
8188
- llama_batch batch) {
8189
- const uint32_t n_tokens = batch.n_tokens;
8772
+ llama_batch batch_all) { // TODO: rename back to batch
8773
+
8774
+ const uint32_t n_tokens_all = batch_all.n_tokens;
8190
8775
 
8191
- if (n_tokens == 0) {
8776
+ if (n_tokens_all == 0) {
8192
8777
  LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
8193
8778
  return -1;
8194
8779
  }
@@ -8197,14 +8782,16 @@ static int llama_decode_internal(
8197
8782
  const auto & hparams = model.hparams;
8198
8783
  const auto & cparams = lctx.cparams;
8199
8784
 
8200
- const auto n_batch = cparams.n_batch;
8785
+ GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
8201
8786
 
8202
- GGML_ASSERT(n_tokens <= n_batch);
8203
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
8787
+ GGML_ASSERT(n_tokens_all <= cparams.n_batch);
8204
8788
 
8205
- int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
8789
+ GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
8206
8790
 
8207
- const int64_t t_start_us = ggml_time_us();
8791
+ if (lctx.t_compute_start_us == 0) {
8792
+ lctx.t_compute_start_us = ggml_time_us();
8793
+ }
8794
+ lctx.n_queued_tokens += n_tokens_all;
8208
8795
 
8209
8796
  #ifdef GGML_USE_MPI
8210
8797
  // TODO: needs fix after #3228
@@ -8212,258 +8799,274 @@ static int llama_decode_internal(
8212
8799
  //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
8213
8800
  #endif
8214
8801
 
8215
- GGML_ASSERT(n_threads > 0);
8216
-
8217
8802
  auto & kv_self = lctx.kv_self;
8218
8803
 
8219
8804
  const int64_t n_embd = hparams.n_embd;
8220
8805
  const int64_t n_vocab = hparams.n_vocab;
8221
8806
 
8222
- // helpers for smoother batch API transition
8223
- // after deprecating the llama_eval calls, these will be removed
8224
- std::vector<llama_pos> pos;
8225
8807
 
8808
+ auto * logits_out = lctx.logits;
8809
+
8810
+ #ifndef NDEBUG
8811
+ auto & logits_valid = lctx.logits_valid;
8812
+ logits_valid.clear();
8813
+ logits_valid.resize(n_tokens_all);
8814
+
8815
+ memset(logits_out, 0, lctx.logits_size*sizeof(float));
8816
+ #endif
8817
+
8818
+ const auto n_ubatch = cparams.n_ubatch;
8819
+
8820
+ std::vector<llama_pos> pos;
8226
8821
  std::vector<int32_t> n_seq_id;
8227
8822
  std::vector<llama_seq_id *> seq_id_arr;
8228
8823
  std::vector<std::vector<llama_seq_id>> seq_id;
8229
8824
 
8230
- if (batch.pos == nullptr) {
8231
- pos.resize(n_tokens);
8232
- for (uint32_t i = 0; i < n_tokens; i++) {
8233
- pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
8825
+ for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
8826
+ const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
8827
+ llama_batch u_batch = {
8828
+ /* .n_tokens = */ (int32_t) n_tokens,
8829
+ /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
8830
+ /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
8831
+ /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
8832
+ /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
8833
+ /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
8834
+ /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
8835
+ /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
8836
+ /* .all_pos_1 = */ batch_all.all_pos_1,
8837
+ /* .all_seq_id = */ batch_all.all_seq_id,
8838
+ };
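
For example, with n_tokens_all = 100 and cparams.n_ubatch = 32, this loop runs four times with u_batch.n_tokens = 32, 32, 32 and 4; each u_batch is just a window into batch_all's arrays at offset cur_token (and cur_token*n_embd for the embeddings pointer), so no token data is copied.
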
8839
+
8840
+ int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
8841
+ GGML_ASSERT(n_threads > 0);
8842
+
8843
+ // helpers for smoother batch API transition
8844
+ // after deprecating the llama_eval calls, these will be removed
8845
+ if (u_batch.pos == nullptr) {
8846
+ pos.resize(n_tokens);
8847
+ for (uint32_t i = 0; i < n_tokens; i++) {
8848
+ pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
8849
+ }
8850
+
8851
+ u_batch.pos = pos.data();
8234
8852
  }
8235
8853
 
8236
- batch.pos = pos.data();
8237
- }
8854
+ if (u_batch.seq_id == nullptr) {
8855
+ n_seq_id.resize(n_tokens);
8856
+ seq_id.resize(n_tokens);
8857
+ seq_id_arr.resize(n_tokens);
8858
+ for (uint32_t i = 0; i < n_tokens; i++) {
8859
+ n_seq_id[i] = 1;
8860
+ seq_id[i].resize(1);
8861
+ seq_id[i][0] = u_batch.all_seq_id;
8862
+ seq_id_arr[i] = seq_id[i].data();
8863
+ }
8238
8864
 
8239
- if (batch.seq_id == nullptr) {
8240
- n_seq_id.resize(n_tokens);
8241
- seq_id.resize(n_tokens);
8242
- seq_id_arr.resize(n_tokens);
8243
- for (uint32_t i = 0; i < n_tokens; i++) {
8244
- n_seq_id[i] = 1;
8245
- seq_id[i].resize(1);
8246
- seq_id[i][0] = batch.all_seq_id;
8247
- seq_id_arr[i] = seq_id[i].data();
8865
+ u_batch.n_seq_id = n_seq_id.data();
8866
+ u_batch.seq_id = seq_id_arr.data();
8248
8867
  }
8249
8868
 
8250
- batch.n_seq_id = n_seq_id.data();
8251
- batch.seq_id = seq_id_arr.data();
8252
- }
8869
+ // non-causal masks do not use the KV cache
8870
+ if (hparams.causal_attn) {
8871
+ llama_kv_cache_update(&lctx);
8253
8872
 
8254
- // non-causal masks do not use the KV cache
8255
- if (hparams.causal_attn) {
8256
- llama_kv_cache_update(&lctx);
8873
+ // if we have enough unused cells before the current head ->
8874
+ // better to start searching from the beginning of the cache, hoping to fill it
8875
+ if (kv_self.head > kv_self.used + 2*n_tokens) {
8876
+ kv_self.head = 0;
8877
+ }
8257
8878
 
8258
- // if we have enough unused cells before the current head ->
8259
- // better to start searching from the beginning of the cache, hoping to fill it
8260
- if (kv_self.head > kv_self.used + 2*n_tokens) {
8261
- kv_self.head = 0;
8262
- }
8879
+ if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
8880
+ return 1;
8881
+ }
8263
8882
 
8264
- if (!llama_kv_cache_find_slot(kv_self, batch)) {
8265
- return 1;
8883
+ if (!kv_self.recurrent) {
8884
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
8885
+ // after enough generations, the benefit from this heuristic disappears
8886
+ // if we start defragmenting the cache, the benefit from this will be more important
8887
+ kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
8888
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
8889
+ }
8266
8890
  }
8267
8891
 
8268
- // a heuristic, to avoid attending the full cache if it is not yet utilized
8269
- // after enough generations, the benefit from this heuristic disappears
8270
- // if we start defragmenting the cache, the benefit from this will be more important
8271
- kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
8272
- //kv_self.n = llama_kv_cache_cell_max(kv_self);
8273
- }
8274
-
8275
- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
8892
+ //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
8276
8893
 
8277
- ggml_backend_sched_reset(lctx.sched);
8278
- ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
8894
+ ggml_backend_sched_reset(lctx.sched);
8895
+ ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
8279
8896
 
8280
- ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
8897
+ ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
8281
8898
 
8282
- // the output is always the last tensor in the graph
8283
- struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
8284
- struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
8899
+ // the output is always the last tensor in the graph
8900
+ struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
8901
+ struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
8285
8902
 
8286
- if (!hparams.causal_attn) {
8287
- res = nullptr; // do not extract logits for embedding models such as BERT
8903
+ if (!hparams.causal_attn) {
8904
+ res = nullptr; // do not extract logits for embedding models such as BERT
8288
8905
 
8289
- // token or sequence embeddings
8290
- embd = gf->nodes[gf->n_nodes - 1];
8906
+ // token or sequence embeddings
8907
+ embd = gf->nodes[gf->n_nodes - 1];
8291
8908
 
8292
- GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
8293
- } else {
8294
- if (strcmp(res->name, "result_output") == 0) {
8295
- // the token embeddings could be the second to last tensor, or the third to last tensor
8296
- if (strcmp(embd->name, "result_norm") != 0) {
8297
- embd = gf->nodes[gf->n_nodes - 3];
8298
- GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
8299
- }
8909
+ GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
8300
8910
  } else {
8301
- GGML_ASSERT(false && "missing result_output tensor");
8911
+ if (strcmp(res->name, "result_output") == 0) {
8912
+ // the token embeddings could be the second to last tensor, or the third to last tensor
8913
+ if (strcmp(embd->name, "result_norm") != 0) {
8914
+ embd = gf->nodes[gf->n_nodes - 3];
8915
+ GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
8916
+ }
8917
+ } else {
8918
+ GGML_ASSERT(false && "missing result_output tensor");
8919
+ }
8302
8920
  }
8303
- }
8304
-
8305
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
8921
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
8306
8922
 
8307
- // for big prompts, if BLAS is enabled, it is better to use only one thread
8308
- // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
8309
- // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
8310
- // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
8311
- // with the BLAS calls. need a better solution
8312
- // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
8313
- // being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
8314
- if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
8315
- n_threads = std::min(4, n_threads);
8316
- }
8317
-
8318
- llama_set_inputs(lctx, batch);
8319
-
8320
- llama_graph_compute(lctx, gf, n_threads);
8923
+ // for big prompts, if BLAS is enabled, it is better to use only one thread
8924
+ // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
8925
+ // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
8926
+ // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
8927
+ // with the BLAS calls. need a better solution
8928
+ // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
8929
+ // being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
8930
+ if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
8931
+ n_threads = std::min(4, n_threads);
8932
+ }
8321
8933
 
8322
- // update the kv ring buffer
8323
- {
8324
- kv_self.head += n_tokens;
8934
+ ggml_backend_sched_alloc_graph(lctx.sched, gf);
8325
8935
 
8326
- // Ensure kv cache head points to a valid index.
8327
- if (kv_self.head >= kv_self.size) {
8328
- kv_self.head = 0;
8329
- }
8330
- }
8936
+ llama_set_inputs(lctx, u_batch);
8331
8937
 
8332
- // decide if we need to defrag the kv cache
8333
- if (cparams.defrag_thold >= 0.0f) {
8334
- const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
8938
+ llama_graph_compute(lctx, gf, n_threads);
8335
8939
 
8336
- // queue defragmentation for next llama_kv_cache_update
8337
- if (fragmentation > cparams.defrag_thold) {
8338
- //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
8940
+ // update the kv ring buffer
8941
+ {
8942
+ kv_self.head += n_tokens;
8339
8943
 
8340
- llama_kv_cache_defrag(kv_self);
8944
+ // Ensure kv cache head points to a valid index.
8945
+ if (kv_self.head >= kv_self.size) {
8946
+ kv_self.head = 0;
8947
+ }
8341
8948
  }
8342
- }
8343
8949
 
8344
8950
  #ifdef GGML_PERF
8345
- // print timing information per ggml operation (for debugging purposes)
8346
- // requires GGML_PERF to be defined
8347
- ggml_graph_print(gf);
8951
+ // print timing information per ggml operation (for debugging purposes)
8952
+ // requires GGML_PERF to be defined
8953
+ ggml_graph_print(gf);
8348
8954
  #endif
8349
8955
 
8350
- // plot the computation graph in dot format (for debugging purposes)
8351
- //if (n_past%100 == 0) {
8352
- // ggml_graph_dump_dot(gf, NULL, "llama.dot");
8353
- //}
8354
-
8355
- // extract logits
8356
- // TODO: do not compute and extract logits if only embeddings are needed
8357
- // need to update the graphs to skip "result_output"
8358
- if (res) {
8359
- auto & logits_out = lctx.logits;
8360
-
8956
+ // plot the computation graph in dot format (for debugging purposes)
8957
+ //if (n_past%100 == 0) {
8958
+ // ggml_graph_dump_dot(gf, NULL, "llama.dot");
8959
+ //}
8960
+
8961
+ // extract logits
8962
+ // TODO: do not compute and extract logits if only embeddings are needed
8963
+ // update the graphs to skip "result_output" if logits are not needed
8964
+ if (res) {
8965
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
8966
+ GGML_ASSERT(backend_res != nullptr);
8967
+ if (u_batch.logits) {
8968
+ int32_t i_first = -1;
8969
+ for (uint32_t i = 0; i < n_tokens; i++) {
8970
+ if (u_batch.logits[i] && i_first == -1) {
8971
+ i_first = (int32_t) i;
8972
+ }
8973
+ if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
8974
+ if (i_first != -1) {
8975
+ int i_last = u_batch.logits[i] == 0 ? i : i + 1;
8976
+ // extract logits for the range [i_first, i_last)
8977
+ // group the requests to minimize the number of calls to the backend
8978
+ ggml_backend_tensor_get_async(backend_res, res,
8979
+ logits_out + n_vocab*(cur_token + i_first),
8980
+ i_first*n_vocab*sizeof(float),
8981
+ (i_last - i_first)*n_vocab*sizeof(float));
8982
+ i_first = -1;
8983
+ }
8984
+ }
8361
8985
  #ifndef NDEBUG
8362
- auto & logits_valid = lctx.logits_valid;
8363
- logits_valid.clear();
8364
- logits_valid.resize(n_tokens);
8365
-
8366
- logits_out.clear();
8986
+ logits_valid[cur_token + i] = u_batch.logits[i] != 0;
8367
8987
  #endif
8368
-
8369
- ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
8370
- GGML_ASSERT(backend_res != nullptr);
8371
-
8372
- if (batch.logits) {
8373
- logits_out.resize(n_vocab * n_tokens);
8374
- for (uint32_t i = 0; i < n_tokens; i++) {
8375
- if (batch.logits[i] == 0) {
8376
- continue;
8377
8988
  }
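
For instance, with u_batch.logits = {1, 1, 0, 1} the loop above issues two ggml_backend_tensor_get_async calls, one for rows [0, 2) and one for rows [3, 4), whereas the removed code issued one call per requested row.
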
8378
- ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
8379
- #ifndef NDEBUG
8380
- logits_valid[i] = true;
8381
- #endif
8382
- }
8383
- } else if (lctx.logits_all) {
8384
- logits_out.resize(n_vocab * n_tokens);
8385
- ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
8989
+ } else if (lctx.logits_all) {
8990
+ ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
8386
8991
  #ifndef NDEBUG
8387
- std::fill(logits_valid.begin(), logits_valid.end(), true);
8992
+ std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
8388
8993
  #endif
8389
- } else {
8390
- logits_out.resize(n_vocab);
8391
- ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
8994
+ } else {
8995
+ if (cur_token + n_tokens >= n_tokens_all) {
8996
+ ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
8392
8997
  #ifndef NDEBUG
8393
- logits_valid[0] = true;
8998
+ logits_valid[0] = true;
8394
8999
  #endif
9000
+ }
9001
+ }
8395
9002
  }
8396
- ggml_backend_synchronize(backend_res);
8397
- }
8398
-
8399
- // extract embeddings
8400
- if (cparams.embeddings && embd) {
8401
- ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
8402
- GGML_ASSERT(backend_embd != nullptr);
8403
9003
 
8404
- switch (cparams.pooling_type) {
8405
- case LLAMA_POOLING_TYPE_NONE:
8406
- {
8407
- // extract token embeddings
8408
- auto & embd_out = lctx.embd;
9004
+ // extract embeddings
9005
+ if (cparams.embeddings && embd) {
9006
+ ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
9007
+ GGML_ASSERT(backend_embd != nullptr);
8409
9008
 
8410
- if (batch.logits) {
8411
- embd_out.resize(n_embd * n_tokens);
8412
- for (uint32_t i = 0; i < n_tokens; i++) {
8413
- if (batch.logits[i] == 0) {
8414
- continue;
9009
+ switch (cparams.pooling_type) {
9010
+ case LLAMA_POOLING_TYPE_NONE:
9011
+ {
9012
+ // extract token embeddings
9013
+ auto & embd_out = lctx.embd;
9014
+
9015
+ if (u_batch.logits) {
9016
+ //embd_out.resize(n_embd * n_tokens);
9017
+ for (uint32_t i = 0; i < n_tokens; i++) {
9018
+ if (u_batch.logits[i] == 0) {
9019
+ continue;
9020
+ }
9021
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
8415
9022
  }
8416
-
8417
- ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
8418
9023
  }
8419
- }
8420
- } break;
8421
- case LLAMA_POOLING_TYPE_CLS:
8422
- case LLAMA_POOLING_TYPE_MEAN:
8423
- {
8424
- GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
9024
+ } break;
9025
+ case LLAMA_POOLING_TYPE_CLS:
9026
+ case LLAMA_POOLING_TYPE_MEAN:
9027
+ {
9028
+ GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
8425
9029
 
8426
- // extract sequence embeddings
8427
- auto & embd_seq_out = lctx.embd_seq;
8428
- embd_seq_out.clear();
9030
+ // extract sequence embeddings
9031
+ auto & embd_seq_out = lctx.embd_seq;
9032
+ embd_seq_out.clear();
8429
9033
 
8430
- for (uint32_t i = 0; i < n_tokens; i++) {
8431
- const llama_seq_id seq_id = batch.seq_id[i][0];
8432
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
8433
- continue;
9034
+ for (uint32_t i = 0; i < n_tokens; i++) {
9035
+ const llama_seq_id seq_id = u_batch.seq_id[i][0];
9036
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
9037
+ continue;
9038
+ }
9039
+ embd_seq_out[seq_id].resize(n_embd);
9040
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
8434
9041
  }
8435
- embd_seq_out[seq_id].resize(n_embd);
8436
- ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
8437
- }
8438
- } break;
8439
- case LLAMA_POOLING_TYPE_UNSPECIFIED:
8440
- {
8441
- GGML_ASSERT(false && "unknown pooling type");
8442
- } break;
9042
+ } break;
9043
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
9044
+ {
9045
+ GGML_ASSERT(false && "unknown pooling type");
9046
+ } break;
9047
+ }
8443
9048
  }
8444
- ggml_backend_synchronize(backend_embd);
8445
9049
  }
8446
9050
 
8447
- // measure the performance only for the single-token evals
8448
- if (n_tokens == 1) {
8449
- lctx.t_eval_us += ggml_time_us() - t_start_us;
8450
- lctx.n_eval++;
8451
- }
8452
- else if (n_tokens > 1) {
8453
- lctx.t_p_eval_us += ggml_time_us() - t_start_us;
8454
- lctx.n_p_eval += n_tokens;
8455
- }
9051
+ // wait for the computation to finish (automatically done when obtaining the model output)
9052
+ //llama_synchronize(&lctx);
8456
9053
 
8457
- // get a more accurate load time, upon first eval
8458
- // TODO: fix this
8459
- if (!lctx.has_evaluated_once) {
8460
- lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
8461
- lctx.has_evaluated_once = true;
9054
+ // decide if we need to defrag the kv cache
9055
+ if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
9056
+ const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
9057
+
9058
+ // queue defragmentation for next llama_kv_cache_update
9059
+ if (fragmentation > cparams.defrag_thold) {
9060
+ //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
9061
+
9062
+ llama_kv_cache_defrag(kv_self);
9063
+ }
8462
9064
  }
8463
9065
 
8464
9066
  return 0;
8465
9067
  }
8466
9068
 
9069
+
8467
9070
  // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
8468
9071
  static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8469
9072
  auto & kv_self = lctx.kv_self;
@@ -8482,6 +9085,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8482
9085
  // number of cells moved
8483
9086
  uint32_t n_moves = 0;
8484
9087
 
9088
+ // each move requires 6*n_layer tensors (see build_defrag)
9089
+ // - source view, destination view, copy operation
9090
+ // - x2 for keys and values
9091
+ const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
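
For a sense of scale (assuming LLAMA_MAX_NODES is still 8192, as defined earlier in this file): a 32-layer model gives max_moves = 8192 / (6*32) = 42, so one defrag graph can relocate at most 42 contiguous blocks of cells per pass.
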
9092
+
8485
9093
  // determine which KV cells to move where
8486
9094
  //
8487
9095
  // cell i moves to ids[i]
@@ -8508,15 +9116,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8508
9116
  nh++;
8509
9117
  }
8510
9118
 
8511
- // each move requires 6*n_layer tensors (see build_defrag)
8512
- // - source view, destination view, copy operation
8513
- // - x2 for keys and values
8514
- //
8515
- if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
8516
- // the graph is too big, we cannot move more cells
8517
- break;
8518
- }
8519
-
8520
9119
  uint32_t nf = 0;
8521
9120
  uint32_t is = n_kv - 1;
8522
9121
 
@@ -8546,11 +9145,19 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8546
9145
  // are we moving a continuous block of memory?
8547
9146
  bool cont = false;
8548
9147
 
9148
+ // should we stop searching for the next move?
9149
+ bool stop = false;
9150
+
8549
9151
  // go back and move the nf cells to the hole
8550
9152
  for (; i1 < n_kv; ++i1) {
8551
9153
  auto & cell1 = kv_self.cells[i1];
8552
9154
 
8553
9155
  if (cell1.is_empty() || ids[i1] != n_kv) {
9156
+ if (n_moves == max_moves) {
9157
+ stop = true;
9158
+ break;
9159
+ }
9160
+
8554
9161
  cont = false;
8555
9162
  continue;
8556
9163
  }
@@ -8577,6 +9184,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8577
9184
  }
8578
9185
  }
8579
9186
 
9187
+ if (stop || n_moves == max_moves) {
9188
+ break;
9189
+ }
9190
+
8580
9191
  //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
8581
9192
 
8582
9193
  i0 += nh - 1;
@@ -8663,6 +9274,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8663
9274
  #else
8664
9275
  // ggml_graph defrag
8665
9276
 
9277
+ ggml_backend_sched_reset(lctx.sched);
9278
+
8666
9279
  ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
8667
9280
 
8668
9281
  llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
@@ -8674,14 +9287,22 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8674
9287
  }
8675
9288
 
8676
9289
  static void llama_kv_cache_update_internal(struct llama_context & lctx) {
9290
+ bool need_reserve = false;
9291
+
8677
9292
  // apply K-shift if needed
8678
9293
  if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
8679
- llama_set_k_shift(lctx);
8680
-
8681
9294
  {
9295
+ ggml_backend_sched_reset(lctx.sched);
9296
+
8682
9297
  ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
8683
9298
 
9299
+ ggml_backend_sched_alloc_graph(lctx.sched, gf);
9300
+
9301
+ llama_set_k_shift(lctx);
9302
+
8684
9303
  llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
9304
+
9305
+ need_reserve = true;
8685
9306
  }
8686
9307
 
8687
9308
  {
@@ -8695,12 +9316,56 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
8695
9316
  }
8696
9317
  }
8697
9318
 
9319
+ if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
9320
+ {
9321
+ ggml_backend_sched_reset(lctx.sched);
9322
+
9323
+ ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
9324
+
9325
+ ggml_backend_sched_alloc_graph(lctx.sched, gf);
9326
+
9327
+ llama_set_s_copy(lctx);
9328
+
9329
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
9330
+
9331
+ need_reserve = true;
9332
+ }
9333
+
9334
+ {
9335
+ auto & kv_self = lctx.kv_self;
9336
+
9337
+ kv_self.do_copy = false;
9338
+
9339
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
9340
+ kv_self.cells[i].src = i;
9341
+ }
9342
+ }
9343
+ }
9344
+
8698
9345
  // defragment the KV cache if needed
8699
9346
  if (lctx.kv_self.do_defrag) {
8700
9347
  llama_kv_cache_defrag_internal(lctx);
8701
9348
 
9349
+ need_reserve = true;
9350
+
8702
9351
  lctx.kv_self.do_defrag = false;
8703
9352
  }
9353
+
9354
+ // reserve a worst case graph again
9355
+ if (need_reserve) {
9356
+ // TODO: extract to a function
9357
+ // build worst-case graph
9358
+ int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
9359
+ int n_past = lctx.cparams.n_ctx - n_tokens;
9360
+ llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
9361
+ ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
9362
+
9363
+ // initialize scheduler with the worst-case graph
9364
+ ggml_backend_sched_reset(lctx.sched);
9365
+ if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
9366
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
9367
+ }
9368
+ }
8704
9369
  }
8705
9370
 
8706
9371
  //
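
Every auxiliary graph in the update path above (K-shift, recurrent-state copy, defrag) now follows the same sequence: reset the scheduler, build the small graph, allocate it, only then write its inputs, compute, and flag that the worst-case decode graph must be reserved again because the previous schedule was invalidated. A condensed, self-contained sketch of that sequence, with the llama.cpp internals replaced by callbacks (every name here is illustrative only):

    #include <functional>

    // Stand-ins for the real steps: ggml_backend_sched_reset, llama_build_graph_*,
    // ggml_backend_sched_alloc_graph, llama_set_* and llama_graph_compute.
    struct aux_graph_steps {
        std::function<void()>       reset_scheduler;
        std::function<void *()>     build_graph;
        std::function<void(void *)> alloc_graph;
        std::function<void()>       set_inputs;   // inputs may only be written after allocation
        std::function<void(void *)> compute;
    };

    // Runs one auxiliary graph; the return value mirrors `need_reserve` above.
    static bool run_aux_graph(const aux_graph_steps & s) {
        s.reset_scheduler();
        void * gf = s.build_graph();
        s.alloc_graph(gf);
        s.set_inputs();
        s.compute(gf);
        return true;
    }
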
@@ -8712,26 +9377,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
8712
9377
  }
8713
9378
 
8714
9379
  static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
9380
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
8715
9381
  return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
8716
9382
  }
8717
9383
 
8718
9384
  static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
9385
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
8719
9386
  return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
8720
9387
  }
8721
9388
 
8722
9389
  static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
9390
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
8723
9391
  return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
8724
9392
  }
8725
9393
 
8726
9394
  static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
9395
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
8727
9396
  return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
8728
9397
  }
8729
9398
 
8730
9399
  static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
9400
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
8731
9401
  return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
8732
9402
  }
8733
9403
 
8734
9404
  static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
9405
+ GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
8735
9406
  GGML_ASSERT(llama_is_byte_token(vocab, id));
8736
9407
  const auto& token_data = vocab.id_to_token.at(id);
8737
9408
  switch (llama_vocab_get_type(vocab)) {
@@ -8741,7 +9412,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
8741
9412
  }
8742
9413
  case LLAMA_VOCAB_TYPE_BPE: {
8743
9414
  GGML_ASSERT(false);
8744
- return unicode_to_bytes_bpe(token_data.text);
9415
+ return unicode_utf8_to_byte(token_data.text);
8745
9416
  }
8746
9417
  case LLAMA_VOCAB_TYPE_WPM: {
8747
9418
  GGML_ASSERT(false);
@@ -8752,6 +9423,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
8752
9423
  }
8753
9424
 
8754
9425
  static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
9426
+ GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
8755
9427
  static const char * hex = "0123456789ABCDEF";
8756
9428
  switch (llama_vocab_get_type(vocab)) {
8757
9429
  case LLAMA_VOCAB_TYPE_SPM: {
@@ -8766,7 +9438,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
8766
9438
  }
8767
9439
  case LLAMA_VOCAB_TYPE_WPM:
8768
9440
  case LLAMA_VOCAB_TYPE_BPE: {
8769
- return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
9441
+ return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
8770
9442
  }
8771
9443
  default:
8772
9444
  GGML_ASSERT(false);
@@ -9106,9 +9778,9 @@ private:
9106
9778
  bpe_words.reserve(text.size());
9107
9779
  bpe_encoded_words.reserve(text.size());
9108
9780
 
9109
- auto cps = codepoints_from_utf8(text);
9110
- for (size_t i = 0; i < cps.size(); ++i)
9111
- text_utf.emplace_back(codepoint_to_utf8(cps[i]));
9781
+ const auto cpts = unicode_cpts_from_utf8(text);
9782
+ for (size_t i = 0; i < cpts.size(); ++i)
9783
+ text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
9112
9784
 
9113
9785
  for (int i = 0; i < (int)text_utf.size(); i++) {
9114
9786
  const std::string & utf_char = text_utf[i];
@@ -9158,40 +9830,40 @@ private:
9158
9830
  }
9159
9831
 
9160
9832
  if (!split_condition && !collecting) {
9161
- if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
9833
+ if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
9162
9834
  collecting_letter = true;
9163
9835
  collecting = true;
9164
9836
  }
9165
- else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9837
+ else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9166
9838
  collecting_numeric = true;
9167
9839
  collecting = true;
9168
9840
  }
9169
9841
  else if (
9170
- ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
9171
- (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
9842
+ ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
9843
+ (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
9172
9844
  ) {
9173
9845
  collecting_special = true;
9174
9846
  collecting = true;
9175
9847
  }
9176
- else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
9848
+ else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
9177
9849
  collecting_whitespace_lookahead = true;
9178
9850
  collecting = true;
9179
9851
  }
9180
- else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
9852
+ else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
9181
9853
  split_condition = true;
9182
9854
  }
9183
9855
  }
9184
9856
  else if (!split_condition && collecting) {
9185
- if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
9857
+ if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
9186
9858
  split_condition = true;
9187
9859
  }
9188
- else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
9860
+ else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
9189
9861
  split_condition = true;
9190
9862
  }
9191
- else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
9863
+ else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
9192
9864
  split_condition = true;
9193
9865
  }
9194
- else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9866
+ else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9195
9867
  split_condition = true;
9196
9868
  }
9197
9869
  }
@@ -9220,7 +9892,7 @@ private:
9220
9892
  for (std::string & word : bpe_words) {
9221
9893
  std::string encoded_token = "";
9222
9894
  for (char & c : word) {
9223
- encoded_token += bytes_to_unicode_bpe(c);
9895
+ encoded_token += unicode_byte_to_utf8(c);
9224
9896
  }
9225
9897
  bpe_encoded_words.emplace_back(encoded_token);
9226
9898
  }
@@ -9294,25 +9966,13 @@ struct llm_tokenizer_wpm {
9294
9966
  }
9295
9967
 
9296
9968
  std::vector<std::string> preprocess(const std::string & text) {
9297
- // normalalization form D
9298
- std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
9299
- std::vector<uint32_t> nfd_codepoints;
9300
- for (uint32_t code : codepoints) {
9301
- auto it = nfd_map.equal_range(code);
9302
- if (it.first != it.second) {
9303
- for (auto jt = it.first; jt != it.second; jt++) {
9304
- nfd_codepoints.push_back(jt->second);
9305
- }
9306
- } else {
9307
- nfd_codepoints.push_back(code);
9308
- }
9309
- }
9969
+ std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
9310
9970
 
9311
9971
  // strip accents, strip control, uniformize whitespace,
9312
9972
  // to lowercase, pad chinese characters, pad punctuation
9313
9973
  std::string new_str = "";
9314
- for (uint32_t code : nfd_codepoints) {
9315
- int type = codepoint_type(code);
9974
+ for (uint32_t code : cpts_nfd) {
9975
+ int type = unicode_cpt_type(code);
9316
9976
  if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
9317
9977
  continue;
9318
9978
  }
@@ -9320,7 +9980,7 @@ struct llm_tokenizer_wpm {
9320
9980
  if (type == CODEPOINT_TYPE_WHITESPACE) {
9321
9981
  code = ' ';
9322
9982
  }
9323
- std::string s = codepoint_to_utf8(code);
9983
+ std::string s = unicode_cpt_to_utf8(code);
9324
9984
  if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
9325
9985
  new_str += " ";
9326
9986
  new_str += s;
@@ -9340,8 +10000,7 @@ struct llm_tokenizer_wpm {
9340
10000
  if (r > l) words.push_back(new_str.substr(l, (r - l)));
9341
10001
  l = r + 1;
9342
10002
  r = l;
9343
- }
9344
- else {
10003
+ } else {
9345
10004
  r += 1;
9346
10005
  }
9347
10006
  }
@@ -9365,17 +10024,17 @@ struct llm_tokenizer_wpm {
9365
10024
  return code < 256 && ispunct(code);
9366
10025
  }
9367
10026
 
9368
- bool is_chinese_char(uint32_t codepoint) {
9369
- if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
9370
- (codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
9371
- (codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
9372
- (codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
9373
- (codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
9374
- (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
9375
- (codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
9376
- (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
9377
- (codepoint >= 0x3000 && codepoint <= 0x303F) ||
9378
- (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
10027
+ bool is_chinese_char(uint32_t cpt) {
10028
+ if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
10029
+ (cpt >= 0x3400 && cpt <= 0x4DBF) ||
10030
+ (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
10031
+ (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
10032
+ (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
10033
+ (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
10034
+ (cpt >= 0xF900 && cpt <= 0xFAFF) ||
10035
+ (cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
10036
+ (cpt >= 0x3000 && cpt <= 0x303F) ||
10037
+ (cpt >= 0xFF00 && cpt <= 0xFFEF)) {
9379
10038
  return true; // NOLINT
9380
10039
  }
9381
10040
  return false;
@@ -9596,6 +10255,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
9596
10255
  }
9597
10256
  }
9598
10257
  } break;
10258
+ case LLAMA_VOCAB_TYPE_NONE:
10259
+ GGML_ASSERT(false);
9599
10260
  }
9600
10261
 
9601
10262
  return output;
@@ -9952,7 +10613,7 @@ struct llama_grammar * llama_grammar_init(
9952
10613
 
9953
10614
  // loop over alternates of start rule to build initial stacks
9954
10615
  std::vector<std::vector<const llama_grammar_element *>> stacks;
9955
- pos = rules[start_rule_index];
10616
+ pos = vec_rules[start_rule_index].data();
9956
10617
  do {
9957
10618
  std::vector<const llama_grammar_element *> stack;
9958
10619
  if (!llama_grammar_is_end_of_sequence(pos)) {
@@ -10967,6 +11628,9 @@ struct quantize_state_internal {
10967
11628
 
10968
11629
  bool has_imatrix = false;
10969
11630
 
11631
+ // used to figure out if a model shares tok_embd with the output weight
11632
+ bool has_output = false;
11633
+
10970
11634
  quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
10971
11635
  : model(model)
10972
11636
  , params(params)
@@ -11034,7 +11698,7 @@ static void llama_tensor_dequantize_internal(
11034
11698
  workers.clear();
11035
11699
  }
11036
11700
 
11037
- static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
11701
+ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
11038
11702
  const std::string name = ggml_get_name(tensor);
11039
11703
 
11040
11704
  // TODO: avoid hardcoded tensor names - use the TN_* constants
@@ -11064,8 +11728,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
11064
11728
 
11065
11729
  // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
11066
11730
  // with the quantization of the output tensor
11067
- if (name == tn(LLM_TENSOR_OUTPUT, "weight") ||
11068
- (LLM_TENSOR_NAMES.at(arch).find(LLM_TENSOR_OUTPUT) == LLM_TENSOR_NAMES.at(arch).end() && name == "token_embd.weight")) {
11731
+ if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
11069
11732
  int nx = tensor->ne[0];
11070
11733
  if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
11071
11734
  new_type = GGML_TYPE_Q8_0;
@@ -11314,17 +11977,16 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
11314
11977
  return new_type;
11315
11978
  }
11316
11979
 
11317
- static int32_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, int64_t * hist_cur, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
11980
+ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
11318
11981
  std::mutex mutex;
11319
11982
  int counter = 0;
11320
11983
  size_t new_size = 0;
11321
11984
  if (nthread < 2) {
11322
11985
  // single-thread
11323
- return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur, imatrix);
11986
+ return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
11324
11987
  }
11325
- auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size,
11988
+ auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size,
11326
11989
  nrows, n_per_row, imatrix]() {
11327
- std::array<int64_t, 1 << 4> local_hist = {};
11328
11990
  const int nrows_per_chunk = chunk_size / n_per_row;
11329
11991
  size_t local_size = 0;
11330
11992
  while (true) {
@@ -11332,17 +11994,13 @@ static int32_t llama_tensor_quantize_internal(enum ggml_type new_type, const flo
11332
11994
  int first_row = counter; counter += nrows_per_chunk;
11333
11995
  if (first_row >= nrows) {
11334
11996
  if (local_size > 0) {
11335
- for (int j=0; j<int(local_hist.size()); ++j) {
11336
- hist_cur[j] += local_hist[j];
11337
- }
11338
11997
  new_size += local_size;
11339
11998
  }
11340
11999
  break;
11341
12000
  }
11342
12001
  lock.unlock();
11343
12002
  const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
11344
- local_size += ggml_quantize_chunk(new_type, f32_data, new_data,
11345
- first_row * n_per_row, this_nrow, n_per_row, local_hist.data(), imatrix);
12003
+ local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
11346
12004
  }
11347
12005
  };
11348
12006
  for (int it = 0; it < nthread - 1; ++it) {
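
With the histogram bookkeeping gone, llama_tensor_quantize_internal reduces to the chunked worker pattern shown above: a mutex-guarded row counter hands out fixed-size chunks of rows, each worker accumulates the bytes it produced locally, and the totals are merged once the counter runs past the last row. A self-contained sketch of that pattern (process_rows stands in for ggml_quantize_chunk and is purely illustrative):

    #include <algorithm>
    #include <mutex>
    #include <thread>
    #include <vector>

    static size_t process_all_rows(int nrows, int nrows_per_chunk, int nthread,
                                   size_t (*process_rows)(int first_row, int nrow)) {
        std::mutex mutex;
        int counter = 0;
        size_t total_size = 0;
        auto worker = [&]() {
            size_t local_size = 0;
            while (true) {
                int first_row;
                {
                    std::lock_guard<std::mutex> lock(mutex);
                    first_row = counter;
                    counter  += nrows_per_chunk;
                    if (first_row >= nrows) {
                        total_size += local_size;  // merge under the lock, then exit
                        return;
                    }
                }
                local_size += process_rows(first_row, std::min(nrows - first_row, nrows_per_chunk));
            }
        };
        std::vector<std::thread> workers;
        for (int it = 0; it < nthread - 1; ++it) {
            workers.emplace_back(worker);
        }
        worker();                      // the calling thread works too
        for (auto & w : workers) {
            w.join();
        }
        return total_size;
    }
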
@@ -11355,40 +12013,40 @@ static int32_t llama_tensor_quantize_internal(enum ggml_type new_type, const flo
11355
12013
  }
11356
12014
 
11357
12015
  static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
11358
- ggml_type quantized_type;
12016
+ ggml_type default_type;
11359
12017
  llama_ftype ftype = params->ftype;
11360
12018
 
11361
12019
  switch (params->ftype) {
11362
- case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
11363
- case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
11364
- case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
11365
- case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
11366
- case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
11367
- case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break;
11368
- case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break;
12020
+ case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
12021
+ case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
12022
+ case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
12023
+ case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
12024
+ case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
12025
+ case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
12026
+ case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
11369
12027
 
11370
12028
  // K-quants
11371
12029
  case LLAMA_FTYPE_MOSTLY_Q2_K_S:
11372
- case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break;
11373
- case LLAMA_FTYPE_MOSTLY_IQ3_XS: quantized_type = GGML_TYPE_IQ3_S; break;
12030
+ case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
12031
+ case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
11374
12032
  case LLAMA_FTYPE_MOSTLY_Q3_K_S:
11375
12033
  case LLAMA_FTYPE_MOSTLY_Q3_K_M:
11376
- case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break;
12034
+ case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
11377
12035
  case LLAMA_FTYPE_MOSTLY_Q4_K_S:
11378
- case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break;
12036
+ case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
11379
12037
  case LLAMA_FTYPE_MOSTLY_Q5_K_S:
11380
- case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
11381
- case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
11382
- case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break;
11383
- case LLAMA_FTYPE_MOSTLY_IQ2_XS: quantized_type = GGML_TYPE_IQ2_XS; break;
11384
- case LLAMA_FTYPE_MOSTLY_IQ2_S: quantized_type = GGML_TYPE_IQ2_XS; break;
11385
- case LLAMA_FTYPE_MOSTLY_IQ2_M: quantized_type = GGML_TYPE_IQ2_S; break;
11386
- case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
11387
- case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break;
11388
- case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break;
11389
- case LLAMA_FTYPE_MOSTLY_IQ4_XS: quantized_type = GGML_TYPE_IQ4_XS; break;
11390
- case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break;
11391
- case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break;
12038
+ case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
12039
+ case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
12040
+ case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
12041
+ case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
12042
+ case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
12043
+ case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
12044
+ case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
12045
+ case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
12046
+ case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
12047
+ case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
12048
+ case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
12049
+ case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
11392
12050
 
11393
12051
  default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
11394
12052
  }
@@ -11454,6 +12112,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11454
12112
  else if (name.find("ffn_up") != std::string::npos) {
11455
12113
  ++qs.n_ffn_up;
11456
12114
  }
12115
+ else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
12116
+ qs.has_output = true;
12117
+ }
11457
12118
  }
11458
12119
  if (qs.n_attention_wv != qs.n_ffn_down || (uint32_t)qs.n_attention_wv != model.hparams.n_layer) {
11459
12120
  LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_ffn_down = %d, hparams.n_layer = %d\n",
@@ -11462,7 +12123,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11462
12123
 
11463
12124
  size_t total_size_org = 0;
11464
12125
  size_t total_size_new = 0;
11465
- std::vector<int64_t> hist_all(1 << 4, 0);
11466
12126
 
11467
12127
  std::vector<std::thread> workers;
11468
12128
  workers.reserve(nthread);
@@ -11524,20 +12184,29 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11524
12184
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
11525
12185
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
11526
12186
 
12187
+ // do not quantize Mamba's small yet 2D weights
12188
+ // NOTE: can't use LLM_TN here because the layer number is not known
12189
+ quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
12190
+ quantize &= name.find("ssm_x.weight") == std::string::npos;
12191
+ quantize &= name.find("ssm_dt.weight") == std::string::npos;
12192
+
11527
12193
  enum ggml_type new_type;
11528
12194
  void * new_data;
11529
12195
  size_t new_size;
11530
12196
 
11531
12197
  if (quantize) {
11532
- new_type = quantized_type;
11533
- if (!params->pure) {
11534
- new_type = get_k_quant_type(qs, new_type, tensor, ftype);
12198
+ new_type = default_type;
12199
+
12200
+ // get more optimal quantization type based on the tensor shape, layer, etc.
12201
+ if (!params->pure && ggml_is_quantized(default_type)) {
12202
+ new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
11535
12203
  }
11536
12204
 
11537
12205
  // If we've decided to quantize to the same type the tensor is already
11538
12206
  // in then there's nothing to do.
11539
12207
  quantize = tensor->type != new_type;
11540
12208
  }
12209
+
11541
12210
  if (!quantize) {
11542
12211
  new_type = tensor->type;
11543
12212
  new_data = tensor->data;
@@ -11583,14 +12252,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11583
12252
  f32_data = (float *) f32_conv_buf.data();
11584
12253
  }
11585
12254
 
11586
- LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type));
12255
+ LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
11587
12256
  fflush(stdout);
11588
12257
 
11589
12258
  if (work.size() < nelements * 4) {
11590
12259
  work.resize(nelements * 4); // upper bound on size
11591
12260
  }
11592
12261
  new_data = work.data();
11593
- std::array<int64_t, 1 << 4> hist_cur = {};
11594
12262
 
11595
12263
  const int n_per_row = tensor->ne[0];
11596
12264
  const int nrows = nelements / n_per_row;
@@ -11600,22 +12268,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11600
12268
 
11601
12269
  const int nchunk = (nelements + chunk_size - 1)/chunk_size;
11602
12270
  const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
11603
- new_size = llama_tensor_quantize_internal(new_type, f32_data, new_data, chunk_size, nrows, n_per_row, hist_cur.data(), imatrix, workers, nthread_use);
11604
-
11605
- LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
11606
- int64_t tot_count = 0;
11607
- for (size_t i = 0; i < hist_cur.size(); i++) {
11608
- hist_all[i] += hist_cur[i];
11609
- tot_count += hist_cur[i];
11610
- }
12271
+ new_size = llama_tensor_quantize_internal(new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix, workers, nthread_use);
11611
12272
 
11612
- if (tot_count > 0) {
11613
- LLAMA_LOG_INFO(" | hist: ");
11614
- for (size_t i = 0; i < hist_cur.size(); i++) {
11615
- LLAMA_LOG_INFO("%5.3f ", hist_cur[i] / float(nelements));
11616
- }
11617
- }
11618
- LLAMA_LOG_INFO("\n");
12273
+ LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
11619
12274
  }
11620
12275
  total_size_org += ggml_nbytes(tensor);
11621
12276
  total_size_new += new_size;
@@ -11644,24 +12299,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
11644
12299
  LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
11645
12300
  LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
11646
12301
 
11647
- // print histogram for all tensors
11648
- {
11649
- int64_t sum_all = 0;
11650
- for (size_t i = 0; i < hist_all.size(); i++) {
11651
- sum_all += hist_all[i];
11652
- }
11653
-
11654
- if (sum_all > 0) {
11655
- LLAMA_LOG_INFO("%s: hist: ", __func__);
11656
- for (size_t i = 0; i < hist_all.size(); i++) {
11657
- LLAMA_LOG_INFO("%5.3f ", hist_all[i] / float(sum_all));
11658
- }
11659
- LLAMA_LOG_INFO("\n");
11660
- }
11661
- }
11662
-
11663
12302
  if (qs.n_fallback > 0) {
11664
- LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) incompatible with k-quants and required fallback quantization\n",
12303
+ LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
11665
12304
  __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
11666
12305
  }
11667
12306
  }
@@ -11973,7 +12612,9 @@ struct llama_context_params llama_context_default_params() {
11973
12612
  struct llama_context_params result = {
11974
12613
  /*.seed =*/ LLAMA_DEFAULT_SEED,
11975
12614
  /*.n_ctx =*/ 512,
11976
- /*.n_batch =*/ 512,
12615
+ /*.n_batch =*/ 2048,
12616
+ /*.n_ubatch =*/ 512,
12617
+ /*.n_seq_max =*/ 1,
11977
12618
  /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
11978
12619
  /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
11979
12620
  /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
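
With the two new fields above, n_batch becomes the logical batch size (how many tokens llama_decode accepts at once) while n_ubatch is the physical micro-batch actually pushed through the compute graph, and n_seq_max bounds the number of parallel sequences. A hedged usage sketch of the updated parameters (model loading omitted; llama.h assumed to be included):

    #include "llama.h"

    static llama_context * make_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx     = 4096;  // context size
        cparams.n_batch   = 2048;  // logical batch: max tokens per llama_decode call
        cparams.n_ubatch  = 512;   // physical micro-batch per graph evaluation
        cparams.n_seq_max = 1;     // parallel sequences (also sizes Mamba's cache, later in this diff)
        return llama_new_context_with_model(model, cparams);
    }
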
@@ -12126,6 +12767,17 @@ struct llama_context * llama_new_context_with_model(
12126
12767
  struct llama_context_params params) {
12127
12768
 
12128
12769
  if (!model) {
12770
+ LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
12771
+ return nullptr;
12772
+ }
12773
+
12774
+ if (params.n_batch == 0 && params.n_ubatch == 0) {
12775
+ LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
12776
+ return nullptr;
12777
+ }
12778
+
12779
+ if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
12780
+ LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
12129
12781
  return nullptr;
12130
12782
  }
12131
12783
 
@@ -12134,7 +12786,7 @@ struct llama_context * llama_new_context_with_model(
12134
12786
  const auto & hparams = model->hparams;
12135
12787
  auto & cparams = ctx->cparams;
12136
12788
 
12137
- cparams.n_batch = params.n_batch;
12789
+ // TODO: maybe add n_seq_max here too
12138
12790
  cparams.n_threads = params.n_threads;
12139
12791
  cparams.n_threads_batch = params.n_threads_batch;
12140
12792
  cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -12150,6 +12802,11 @@ struct llama_context * llama_new_context_with_model(
12150
12802
  cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
12151
12803
  cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
12152
12804
 
12805
+ // with causal attention, the batch size is limited by the context size
12806
+ cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
12807
+ cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
12808
+
12809
+
12153
12810
  cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
12154
12811
  hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
12155
12812
  hparams.n_ctx_train;
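
The clamping above guarantees a consistent relationship between the three sizes: with causal attention the logical batch can never exceed the context, and the micro-batch can never exceed the logical batch (defaulting to it when left at zero). A tiny stand-alone sketch of the same arithmetic, using illustrative names:

    #include <algorithm>
    #include <cstdint>

    struct batch_sizes { uint32_t n_batch, n_ubatch; };

    static batch_sizes clamp_batch_sizes(uint32_t n_ctx, uint32_t n_batch_req,
                                         uint32_t n_ubatch_req, bool causal_attn) {
        batch_sizes s;
        // with causal attention, the batch size is limited by the context size
        s.n_batch  = causal_attn ? std::min(n_ctx, n_batch_req) : n_batch_req;
        // a micro-batch of 0 means "same as the logical batch"
        s.n_ubatch = std::min(s.n_batch, n_ubatch_req == 0 ? n_batch_req : n_ubatch_req);
        return s;
    }
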
@@ -12170,6 +12827,8 @@ struct llama_context * llama_new_context_with_model(
12170
12827
  cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
12171
12828
  }
12172
12829
 
12830
+ cparams.causal_attn = hparams.causal_attn;
12831
+
12173
12832
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
12174
12833
  if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
12175
12834
  cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -12183,6 +12842,8 @@ struct llama_context * llama_new_context_with_model(
12183
12842
  }
12184
12843
 
12185
12844
  LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
12845
+ LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
12846
+ LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
12186
12847
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
12187
12848
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
12188
12849
 
@@ -12192,8 +12853,18 @@ struct llama_context * llama_new_context_with_model(
12192
12853
  ctx->rng = std::mt19937(params.seed);
12193
12854
  ctx->logits_all = params.logits_all;
12194
12855
 
12195
- const ggml_type type_k = params.type_k;
12196
- const ggml_type type_v = params.type_v;
12856
+ uint32_t kv_size = cparams.n_ctx;
12857
+ ggml_type type_k = params.type_k;
12858
+ ggml_type type_v = params.type_v;
12859
+
12860
+ // Mamba only needs a constant number of KV cache cells per sequence
12861
+ if (model->arch == LLM_ARCH_MAMBA) {
12862
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
12863
+ kv_size = std::max((uint32_t) 1, params.n_seq_max);
12864
+ // it's probably best to keep as much precision as possible for the states
12865
+ type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
12866
+ type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
12867
+ }
12197
12868
 
12198
12869
  GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
12199
12870
  GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
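
The special case above reflects that a recurrent model like Mamba keeps a fixed-size state per sequence rather than one cell per past token, so the "KV" cache only needs n_seq_max cells, stored in full F32 precision for the conv and ssm states. A stand-alone sketch of the cache-size decision (names are illustrative; is_recurrent stands in for the model->arch == LLM_ARCH_MAMBA check):

    #include <algorithm>
    #include <cstdint>

    static uint32_t pick_kv_size(bool is_recurrent, uint32_t n_ctx, uint32_t n_seq_max) {
        if (is_recurrent) {
            // one state slot per sequence is enough, independent of the context length
            return std::max<uint32_t>(1, n_seq_max);
        }
        // attention models need one cell per cached token position
        return n_ctx;
    }
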
@@ -12293,7 +12964,7 @@ struct llama_context * llama_new_context_with_model(
12293
12964
  }
12294
12965
  ctx->backends.push_back(ctx->backend_cpu);
12295
12966
 
12296
- if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
12967
+ if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
12297
12968
  LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
12298
12969
  llama_free(ctx);
12299
12970
  return nullptr;
@@ -12317,44 +12988,31 @@ struct llama_context * llama_new_context_with_model(
12317
12988
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
12318
12989
  }
12319
12990
 
12320
- // resized during inference, reserve maximum
12321
- ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
12991
+ // graph outputs buffer
12992
+ {
12993
+ // resized during inference, reserve maximum
12994
+ ctx->logits_size = hparams.n_vocab*cparams.n_batch;
12995
+ ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
12322
12996
 
12323
- if (params.embeddings) {
12324
- ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
12325
- }
12997
+ const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
12326
12998
 
12327
- // graph inputs
12328
- {
12329
- ggml_init_params init_params = {
12330
- /* .mem_size */ ggml_tensor_overhead()*8,
12331
- /* .mem_buffer */ nullptr,
12332
- /* .no_alloc */ true,
12333
- };
12334
- ctx->ctx_input = ggml_init(init_params);
12335
-
12336
- ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
12337
- ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
12338
- ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
12339
- ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
12340
- ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
12341
- ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
12342
- ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
12343
- ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
12344
-
12345
- ggml_set_name(ctx->inp_tokens, "inp_tokens");
12346
- ggml_set_name(ctx->inp_embd, "inp_embd");
12347
- ggml_set_name(ctx->inp_pos, "inp_pos");
12348
- ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
12349
- ggml_set_name(ctx->inp_KQ_pos, "inp_KQ_pos");
12350
- ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
12351
- ggml_set_name(ctx->inp_mean, "inp_mean");
12352
- ggml_set_name(ctx->inp_cls, "inp_cls");
12353
-
12354
- ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
12355
- LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
12356
- ggml_backend_buffer_name(ctx->buf_input),
12357
- ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
12999
+ ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
13000
+ if (ctx->buf_output == nullptr) {
13001
+ LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
13002
+ llama_free(ctx);
13003
+ return nullptr;
13004
+ }
13005
+ ggml_backend_buffer_clear(ctx->buf_output, 0);
13006
+
13007
+
13008
+ ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
13009
+ if (params.embeddings) {
13010
+ ctx->embd = ctx->logits + ctx->logits_size;
13011
+ }
13012
+
13013
+ LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
13014
+ ggml_backend_buffer_name(ctx->buf_output),
13015
+ ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
12358
13016
  }
12359
13017
 
12360
13018
  // scheduler and compute buffers
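
Logits and embeddings now live in one flat backend buffer instead of two std::vectors: the logits occupy the first n_vocab*n_batch floats and, when embeddings are requested, an embedding block of n_embd*n_batch floats follows directly after. A small sketch of the sizing and pointer layout (plain host memory here, whereas the code above allocates through the backend buffer API):

    #include <cstddef>
    #include <vector>

    struct output_buffer {
        std::vector<float> data;  // backing storage (a ggml backend buffer in the real code)
        size_t logits_size = 0;   // in floats
        size_t embd_size   = 0;   // in floats
        float * logits     = nullptr;
        float * embd       = nullptr;
    };

    static output_buffer make_output_buffer(size_t n_vocab, size_t n_embd,
                                            size_t n_batch, bool want_embeddings) {
        output_buffer out;
        out.logits_size = n_vocab*n_batch;
        out.embd_size   = want_embeddings ? n_embd*n_batch : 0;
        out.data.resize(out.logits_size + out.embd_size, 0.0f);
        out.logits = out.data.data();
        out.embd   = want_embeddings ? out.logits + out.logits_size : nullptr;
        return out;
    }
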
@@ -12373,10 +13031,21 @@ struct llama_context * llama_new_context_with_model(
12373
13031
  // buffer used to store the computation graph and the tensor meta data
12374
13032
  ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
12375
13033
 
12376
- ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
13034
+ // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
13035
+ bool pipeline_parallel = llama_get_device_count() > 1 && model->n_gpu_layers > (int)model->hparams.n_layer && model->split_mode == LLAMA_SPLIT_MODE_LAYER;
13036
+ #ifndef GGML_USE_CUBLAS
13037
+ // pipeline parallelism requires support for async compute and events
13038
+ // currently this is only implemented in the CUDA backend
13039
+ pipeline_parallel = false;
13040
+ #endif
13041
+ ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
13042
+
13043
+ if (pipeline_parallel) {
13044
+ LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
13045
+ }
12377
13046
 
12378
13047
  // build worst-case graph
12379
- int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
13048
+ int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
12380
13049
  int n_past = cparams.n_ctx - n_tokens;
12381
13050
  llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
12382
13051
  ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
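
Pipeline parallelism is only worth its extra memory when there is more than one device, the whole model is offloaded, and the layer split mode is used; even then it currently needs the CUDA backend for async compute and events. A boolean sketch of that gate (illustrative names; the compile-time CUDA check is shown as a runtime flag here):

    static bool enable_pipeline_parallel(int device_count, int n_gpu_layers, int n_layer,
                                         bool split_mode_layer,
                                         bool backend_supports_async_events /* CUDA only today */) {
        return device_count > 1
            && n_gpu_layers > n_layer          // all layers (and the output) offloaded
            && split_mode_layer                // LLAMA_SPLIT_MODE_LAYER
            && backend_supports_async_events;  // replaces the #ifndef GGML_USE_CUBLAS guard above
    }
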
@@ -12399,7 +13068,7 @@ struct llama_context * llama_new_context_with_model(
12399
13068
 
12400
13069
  // note: the number of splits during measure is higher than during inference due to the kv shift
12401
13070
  int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
12402
- LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
13071
+ LLAMA_LOG_INFO("%s: graph splits: %d\n", __func__, n_splits);
12403
13072
  }
12404
13073
  }
12405
13074
 
@@ -12436,6 +13105,14 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
12436
13105
  return ctx->cparams.n_batch;
12437
13106
  }
12438
13107
 
13108
+ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
13109
+ return ctx->cparams.n_ubatch;
13110
+ }
13111
+
13112
+ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
13113
+ return ctx->kv_self.size;
13114
+ }
13115
+
12439
13116
  enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
12440
13117
  return model->vocab.type;
12441
13118
  }
@@ -12449,6 +13126,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
12449
13126
  case LLM_ARCH_MPT:
12450
13127
  case LLM_ARCH_REFACT:
12451
13128
  case LLM_ARCH_BLOOM:
13129
+ case LLM_ARCH_MAMBA:
12452
13130
  return LLAMA_ROPE_TYPE_NONE;
12453
13131
 
12454
13132
  // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -12485,7 +13163,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
12485
13163
  }
12486
13164
 
12487
13165
  int32_t llama_n_vocab(const struct llama_model * model) {
12488
- return model->vocab.id_to_token.size();
13166
+ return model->hparams.n_vocab;
12489
13167
  }
12490
13168
 
12491
13169
  int32_t llama_n_ctx_train(const struct llama_model * model) {
@@ -12595,10 +13273,10 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
12595
13273
  }
12596
13274
  }
12597
13275
 
12598
- struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
13276
+ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
12599
13277
  struct llama_kv_cache_view result = {
12600
13278
  /*.n_cells = */ 0,
12601
- /*.n_max_seq = */ n_max_seq,
13279
+ /*.n_seq_max = */ n_seq_max,
12602
13280
  /*.token_count = */ 0,
12603
13281
  /*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
12604
13282
  /*.max_contiguous = */ 0,
@@ -12626,7 +13304,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
12626
13304
  void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
12627
13305
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
12628
13306
  view->cells = (struct llama_kv_cache_view_cell *)p;
12629
- p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
13307
+ p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
12630
13308
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
12631
13309
  view->cells_sequences = (llama_seq_id *)p;
12632
13310
  }
@@ -12640,7 +13318,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
12640
13318
  uint32_t max_contig = 0;
12641
13319
  int32_t max_contig_idx = -1;
12642
13320
 
12643
- for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
13321
+ for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
12644
13322
  const size_t curr_size = kv_cells[i].seq_id.size();
12645
13323
  token_count += curr_size;
12646
13324
  c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -12657,7 +13335,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
12657
13335
 
12658
13336
  int seq_idx = 0;
12659
13337
  for (const llama_seq_id it : kv_cells[i].seq_id) {
12660
- if (seq_idx >= view->n_max_seq) {
13338
+ if (seq_idx >= view->n_seq_max) {
12661
13339
  break;
12662
13340
  }
12663
13341
  cs_curr[seq_idx] = it;
@@ -12666,7 +13344,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
12666
13344
  if (seq_idx != 0) {
12667
13345
  used_cells++;
12668
13346
  }
12669
- for (; seq_idx < view->n_max_seq; seq_idx++) {
13347
+ for (; seq_idx < view->n_seq_max; seq_idx++) {
12670
13348
  cs_curr[seq_idx] = -1;
12671
13349
  }
12672
13350
  }
@@ -12702,8 +13380,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
12702
13380
  llama_kv_cache_clear(ctx->kv_self);
12703
13381
  }
12704
13382
 
12705
- void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
12706
- llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
13383
+ bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
13384
+ return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
12707
13385
  }
12708
13386
 
12709
13387
  void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
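
llama_kv_cache_seq_rm now reports whether the removal actually happened; with a recurrent (Mamba-style) cache a partial range generally cannot be erased, so callers should be prepared to fall back to clearing the whole sequence. A hedged usage sketch (the fallback policy is the caller's choice, not something the library mandates):

    #include "llama.h"

    // Try to drop positions [p0, p1) of a sequence; if the cache cannot remove a
    // partial range (e.g. recurrent state), drop the whole sequence instead.
    static void remove_range_or_whole_seq(llama_context * ctx, llama_seq_id seq_id,
                                          llama_pos p0, llama_pos p1) {
        if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)) {
            llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);  // negative bounds mean the full range
        }
    }
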
@@ -12754,9 +13432,9 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
12754
13432
  const size_t s_rng = LLAMA_MAX_RNG_STATE;
12755
13433
  const size_t s_logits_size = sizeof(size_t);
12756
13434
  // assume worst case for logits although only currently set ones are serialized
12757
- const size_t s_logits = ctx->logits.capacity() * sizeof(float);
13435
+ const size_t s_logits = ctx->logits_size * sizeof(float);
12758
13436
  const size_t s_embedding_size = sizeof(size_t);
12759
- const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
13437
+ const size_t s_embedding = ctx->embd_size * sizeof(float);
12760
13438
  const size_t s_kv_buf_size = sizeof(size_t);
12761
13439
  const size_t s_kv_head = sizeof(uint32_t);
12762
13440
  const size_t s_kv_size = sizeof(uint32_t);
@@ -12854,23 +13532,23 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
12854
13532
 
12855
13533
  // copy logits
12856
13534
  {
12857
- const size_t logits_size = ctx->logits.size();
13535
+ const size_t logits_size = ctx->logits_size;
12858
13536
 
12859
13537
  data_ctx->write(&logits_size, sizeof(logits_size));
12860
13538
 
12861
13539
  if (logits_size) {
12862
- data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
13540
+ data_ctx->write(ctx->logits, logits_size * sizeof(float));
12863
13541
  }
12864
13542
  }
12865
13543
 
12866
13544
  // copy embeddings
12867
13545
  {
12868
- const size_t embeddings_size = ctx->embd.size();
13546
+ const size_t embeddings_size = ctx->embd_size;
12869
13547
 
12870
13548
  data_ctx->write(&embeddings_size, sizeof(embeddings_size));
12871
13549
 
12872
13550
  if (embeddings_size) {
12873
- data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
13551
+ data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
12874
13552
  }
12875
13553
  }
12876
13554
 
@@ -12880,8 +13558,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
12880
13558
  const auto & hparams = ctx->model.hparams;
12881
13559
 
12882
13560
  const uint32_t n_layer = hparams.n_layer;
12883
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
12884
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
13561
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
13562
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
12885
13563
 
12886
13564
  const size_t kv_buf_size = kv_self.total_size();
12887
13565
  const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -12902,6 +13580,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
12902
13580
  ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
12903
13581
  data_ctx->write(tmp_buf.data(), tmp_buf.size());
12904
13582
 
13583
+ if (kv_self.recurrent) {
13584
+ // v is contiguous for recurrent models
13585
+ // TODO: use other tensors for state models than k and v
13586
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
13587
+
13588
+ tmp_buf.resize(v_size);
13589
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
13590
+ data_ctx->write(tmp_buf.data(), tmp_buf.size());
13591
+ continue;
13592
+ }
13593
+
12905
13594
  // v is not contiguous, copy row by row
12906
13595
  const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
12907
13596
  const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@@ -12962,12 +13651,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
12962
13651
 
12963
13652
  memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
12964
13653
 
12965
- GGML_ASSERT(ctx->logits.capacity() >= logits_size);
13654
+ GGML_ASSERT(ctx->logits_size >= logits_size);
12966
13655
 
12967
13656
  if (logits_size) {
12968
- ctx->logits.resize(logits_size);
12969
-
12970
- memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
13657
+ memcpy(ctx->logits, inp, logits_size * sizeof(float));
12971
13658
  inp += logits_size * sizeof(float);
12972
13659
  }
12973
13660
  }
@@ -12978,12 +13665,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
12978
13665
 
12979
13666
  memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
12980
13667
 
12981
- GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
13668
+ GGML_ASSERT(ctx->embd_size == embeddings_size);
12982
13669
 
12983
13670
  if (embeddings_size) {
12984
- ctx->embd.resize(embeddings_size);
12985
-
12986
- memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
13671
+ memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
12987
13672
  inp += embeddings_size * sizeof(float);
12988
13673
  }
12989
13674
  }
@@ -12994,8 +13679,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
12994
13679
  const auto & hparams = ctx->model.hparams;
12995
13680
 
12996
13681
  const uint32_t n_layer = hparams.n_layer;
12997
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
12998
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
13682
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
13683
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
12999
13684
 
13000
13685
  size_t kv_buf_size;
13001
13686
  uint32_t kv_head;
@@ -13016,6 +13701,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
13016
13701
  ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
13017
13702
  inp += k_size;
13018
13703
 
13704
+ if (kv_self.recurrent) {
13705
+ // v is contiguous for recurrent models
13706
+ // TODO: use other tensors for state models than k and v
13707
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
13708
+
13709
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
13710
+ inp += v_size;
13711
+ continue;
13712
+ }
13713
+
13019
13714
  // v is not contiguous, copy row by row
13020
13715
  const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
13021
13716
  const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@@ -13158,6 +13853,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
13158
13853
  ctx->abort_callback_data = abort_callback_data;
13159
13854
  }
13160
13855
 
13856
+ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
13857
+ ctx->cparams.causal_attn = causal_attn;
13858
+ }
13859
+
13161
13860
  struct llama_batch llama_batch_get_one(
13162
13861
  llama_token * tokens,
13163
13862
  int32_t n_tokens,
@@ -13224,24 +13923,61 @@ int32_t llama_decode(
13224
13923
  return ret;
13225
13924
  }
13226
13925
 
13926
+ void llama_synchronize(struct llama_context * ctx) {
13927
+ ggml_backend_sched_synchronize(ctx->sched);
13928
+
13929
+ // FIXME: if multiple single tokens are evaluated without a synchronization,
13930
+ // the stats will be added to the prompt evaluation stats
13931
+ // this should only happen when using batch size 1 to evaluate a batch
13932
+
13933
+ // add the evaluation to the stats
13934
+ if (ctx->n_queued_tokens == 1) {
13935
+ ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
13936
+ ctx->n_eval++;
13937
+ } else if (ctx->n_queued_tokens > 1) {
13938
+ ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
13939
+ ctx->n_p_eval += ctx->n_queued_tokens;
13940
+ }
13941
+
13942
+ // get a more accurate load time, upon first eval
13943
+ if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
13944
+ ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
13945
+ ctx->has_evaluated_once = true;
13946
+ }
13947
+
13948
+ ctx->n_queued_tokens = 0;
13949
+ ctx->t_compute_start_us = 0;
13950
+ }
13951
+
13227
13952
  float * llama_get_logits(struct llama_context * ctx) {
13228
- return ctx->logits.data();
13953
+ llama_synchronize(ctx);
13954
+
13955
+ return ctx->logits;
13229
13956
  }
13230
13957
 
13231
13958
  float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
13232
13959
  assert(ctx->logits_valid.at(i));
13233
- return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
13960
+
13961
+ llama_synchronize(ctx);
13962
+
13963
+ return ctx->logits + i*ctx->model.hparams.n_vocab;
13234
13964
  }
13235
13965
 
13236
13966
  float * llama_get_embeddings(struct llama_context * ctx) {
13237
- return ctx->embd.data();
13967
+ llama_synchronize(ctx);
13968
+
13969
+ return ctx->embd;
13238
13970
  }
13239
13971
 
13240
13972
  float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
13241
- return ctx->embd.data() + i*ctx->model.hparams.n_embd;
13973
+ llama_synchronize(ctx);
13974
+
13975
+ return ctx->embd + i*ctx->model.hparams.n_embd;
13242
13976
  }
13243
13977
 
13244
13978
  float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
13979
+ llama_synchronize(ctx);
13980
+
13245
13981
  auto it = ctx->embd_seq.find(seq_id);
13246
13982
  if (it == ctx->embd_seq.end()) {
13247
13983
  return nullptr;
@@ -13251,14 +13987,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
13251
13987
  }
13252
13988
 
13253
13989
  const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
13990
+ GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
13254
13991
  return model->vocab.id_to_token[token].text.c_str();
13255
13992
  }
13256
13993
 
13257
13994
  float llama_token_get_score(const struct llama_model * model, llama_token token) {
13995
+ GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
13258
13996
  return model->vocab.id_to_token[token].score;
13259
13997
  }
13260
13998
 
13261
13999
  llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
14000
+ GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
13262
14001
  return model->vocab.id_to_token[token].type;
13263
14002
  }
13264
14003
 
@@ -13303,12 +14042,12 @@ int32_t llama_tokenize(
13303
14042
  const char * text,
13304
14043
  int32_t text_len,
13305
14044
  llama_token * tokens,
13306
- int32_t n_max_tokens,
14045
+ int32_t n_tokens_max,
13307
14046
  bool add_bos,
13308
14047
  bool special) {
13309
14048
  auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
13310
14049
 
13311
- if (n_max_tokens < (int) res.size()) {
14050
+ if (n_tokens_max < (int) res.size()) {
13312
14051
  // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
13313
14052
  return -((int) res.size());
13314
14053
  }
@@ -13322,9 +14061,9 @@ int32_t llama_tokenize(
13322
14061
 
13323
14062
  static std::string llama_decode_text(const std::string & text) {
13324
14063
  std::string decoded_text;
13325
- auto unicode_sequences = codepoints_from_utf8(text);
13326
- for (auto& unicode_sequence : unicode_sequences) {
13327
- decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
14064
+ auto unicode_sequences = unicode_cpts_from_utf8(text);
14065
+ for (auto & unicode_sequence : unicode_sequences) {
14066
+ decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(unicode_sequence));
13328
14067
  }
13329
14068
 
13330
14069
  return decoded_text;
@@ -13349,7 +14088,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
13349
14088
  } else if (llama_is_user_defined_token(model->vocab, token)) {
13350
14089
  std::string result = model->vocab.id_to_token[token].text;
13351
14090
  if (length < (int) result.length()) {
13352
- return -result.length();
14091
+ return -(int) result.length();
13353
14092
  }
13354
14093
  memcpy(buf, result.c_str(), result.length());
13355
14094
  return result.length();
@@ -13384,7 +14123,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
13384
14123
  } else if (llama_is_user_defined_token(model->vocab, token)) {
13385
14124
  std::string result = model->vocab.id_to_token[token].text;
13386
14125
  if (length < (int) result.length()) {
13387
- return -result.length();
14126
+ return -(int) result.length();
13388
14127
  }
13389
14128
  memcpy(buf, result.c_str(), result.length());
13390
14129
  return result.length();
@@ -13503,6 +14242,26 @@ static int32_t llama_chat_apply_template_internal(
13503
14242
  if (add_ass) {
13504
14243
  ss << "<start_of_turn>model\n";
13505
14244
  }
14245
+ } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
14246
+ // OrionStarAI/Orion-14B-Chat
14247
+ std::string system_prompt = "";
14248
+ for (auto message : chat) {
14249
+ std::string role(message->role);
14250
+ if (role == "system") {
14251
+ // there is no system message support, we will merge it with user prompt
14252
+ system_prompt = message->content;
14253
+ continue;
14254
+ } else if (role == "user") {
14255
+ ss << "Human: ";
14256
+ if (!system_prompt.empty()) {
14257
+ ss << system_prompt << "\n\n";
14258
+ system_prompt = "";
14259
+ }
14260
+ ss << message->content << "\n\nAssistant: </s>";
14261
+ } else {
14262
+ ss << message->content << "</s>";
14263
+ }
14264
+ }
13506
14265
  } else {
13507
14266
  // template not supported
13508
14267
  return -1;
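
The Orion branch above is selected either by the literal template name "orion" or by sniffing the characteristic "'\n\nAssistant: ' + eos_token" fragment of the Jinja template shipped with the model. A hedged usage sketch through the public wrapper (llama_chat_apply_template and llama_chat_message are assumed to keep their usual signatures in this release; passing nullptr for the template uses the model's own chat template):

    #include "llama.h"
    #include <string>
    #include <vector>

    static std::string render_chat(const llama_model * model,
                                   const std::vector<llama_chat_message> & msgs) {
        std::vector<char> buf(4096);
        int32_t n = llama_chat_apply_template(model, /*tmpl=*/nullptr,
                                              msgs.data(), msgs.size(),
                                              /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (n < 0) {
            return "";               // template not supported (the -1 branch above)
        }
        if (n > (int32_t) buf.size()) {
            buf.resize((size_t) n);  // output was truncated, retry with a large enough buffer
            n = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(),
                                          true, buf.data(), (int32_t) buf.size());
        }
        return std::string(buf.data(), (size_t) n);
    }
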