llama_cpp 0.14.2 → 0.14.3

@@ -214,6 +214,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
+    LLM_ARCH_COMMAND_R,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
+    { LLM_ARCH_COMMAND_R, "command-r" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -268,6 +270,7 @@ enum llm_kv {
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_POOLING_TYPE,
+    LLM_KV_LOGIT_SCALE,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -332,6 +335,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
     { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
     { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
+    { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
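
The LLM_KV_NAMES entries are printf-style templates: the "%s" placeholder is filled with the architecture name from LLM_ARCH_NAMES, so the new logit-scale hyperparameter is read from the GGUF key "command-r.logit_scale". A minimal illustration of that key expansion (plain C++; format_kv_key is a hypothetical helper, not the llama.cpp API):

    #include <cstdio>
    #include <string>

    // Hypothetical helper mirroring how llama.cpp expands an LLM_KV_NAMES
    // template with the architecture name before looking the key up in GGUF.
    static std::string format_kv_key(const char * tmpl, const char * arch_name) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), tmpl, arch_name);
        return buf;
    }

    int main() {
        // Prints "command-r.logit_scale" for the new architecture.
        std::printf("%s\n", format_kv_key("%s.logit_scale", "command-r").c_str());
        return 0;
    }
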
@@ -536,6 +540,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
         {
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
             { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
@@ -838,6 +843,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_COMMAND_R,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1597,6 +1617,7 @@ enum e_model {
     MODEL_20B,
     MODEL_30B,
     MODEL_34B,
+    MODEL_35B,
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
@@ -1643,6 +1664,7 @@ struct llama_hparams {
 
     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;
+    float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
@@ -1873,6 +1895,31 @@ struct llama_kv_cache {
     }
 };
 
+struct llama_control_vector {
+    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<struct ggml_context *> ctxs;
+    std::vector<ggml_backend_buffer_t> bufs;
+
+    int32_t layer_start = -1;
+    int32_t layer_end = -1;
+
+    ggml_tensor * tensor_for(int il) const {
+        if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
+            return nullptr;
+        }
+        return tensors[il];
+    }
+
+    ~llama_control_vector() {
+        for (struct ggml_context * ctx : ctxs) {
+            ggml_free(ctx);
+        }
+        for (ggml_backend_buffer_t buf : bufs) {
+            ggml_backend_buffer_free(buf);
+        }
+    }
+};
+
 struct llama_vocab {
     using id = int32_t;
     using token = std::string;
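
The new llama_control_vector holds one optional steering direction per layer; tensor_for returns nullptr for layers outside the active [layer_start, layer_end] window or without an allocated tensor, so callers add the direction only where one exists. A simplified, self-contained sketch of the same guard logic (plain C++, with float vectors standing in for the ggml tensors above):

    #include <cstdio>
    #include <vector>

    struct control_vector_sketch {
        std::vector<std::vector<float>> tensors; // slot 0 stays empty, like the nullptr for layer 0
        int layer_start = -1;
        int layer_end   = -1;

        // Same guard as llama_control_vector::tensor_for.
        const std::vector<float> * tensor_for(int il) const {
            if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
                return nullptr;
            }
            return &tensors[il];
        }
    };

    int main() {
        control_vector_sketch cvec;
        cvec.tensors.resize(4); // layers 0..3
        cvec.layer_start = 1;
        cvec.layer_end   = 2;
        for (int il = 0; il < 5; ++il) {
            std::printf("layer %d: %s\n", il, cvec.tensor_for(il) ? "direction applied" : "skipped");
        }
        return 0;
    }
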
@@ -1994,6 +2041,11 @@ struct llama_model {
             ggml_free(ctx);
         }
         for (ggml_backend_buffer_t buf : bufs) {
+#ifdef GGML_USE_CUBLAS
+            if (ggml_backend_buffer_get_type(buf) == ggml_backend_cpu_buffer_type()) {
+                ggml_backend_cuda_unregister_host_buffer(ggml_backend_buffer_get_base(buf));
+            }
+#endif
             ggml_backend_buffer_free(buf);
         }
     }
@@ -2087,6 +2139,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
 
+    // control vectors
+    struct llama_control_vector cvec;
+
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
@@ -3231,6 +3286,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_20B: return "20B";
         case MODEL_30B: return "30B";
         case MODEL_34B: return "34B";
+        case MODEL_35B: return "35B";
         case MODEL_40B: return "40B";
         case MODEL_65B: return "65B";
         case MODEL_70B: return "70B";
@@ -3623,6 +3679,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_35B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -3944,6 +4009,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
     LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
+    LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
     LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@@ -4235,9 +4301,9 @@ static bool llm_load_tensors(
                    {
                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
-                       if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_OUTPUT, "weight").c_str()) >= 0) {
-                           model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
-                       } else {
+
+                       model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                       if (!model.output) {
                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU
                            ml.n_created--; // artificial tensor
                            ml.size_data += ggml_nbytes(model.output);
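
This change makes output.weight optional: create_tensor is called with required = false, and when the tensor is missing the token-embedding matrix is reused as the LM head (tied embeddings), with the loader's bookkeeping adjusted because no extra data is read from the file. A generic sketch of that fallback pattern (the tensor_table type is a hypothetical stand-in, not the llama_model_loader API):

    #include <cstdio>
    #include <map>
    #include <string>

    // Hypothetical stand-in for a model file: tensor name -> data pointer.
    using tensor_table = std::map<std::string, const float *>;

    // Prefer a dedicated output matrix; otherwise reuse the token embeddings.
    static const float * load_output_weights(const tensor_table & t) {
        auto it = t.find("output.weight");
        if (it != t.end()) {
            return it->second;            // model ships a separate lm_head tensor
        }
        return t.at("token_embd.weight"); // tied embeddings: reuse the embedding matrix
    }

    int main() {
        static const float embd[4] = {};
        static const float head[4] = {};
        const tensor_table tied   = { { "token_embd.weight", embd } };
        const tensor_table untied = { { "token_embd.weight", embd }, { "output.weight", head } };
        std::printf("tied model reuses embeddings: %s\n", load_output_weights(tied)   == embd ? "yes" : "no");
        std::printf("untied model uses lm_head:    %s\n", load_output_weights(untied) == head ? "yes" : "no");
        return 0;
    }
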
@@ -4442,10 +4508,12 @@ static bool llm_load_tensors(
                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, false);
 
-                       // same as tok_embd, duplicated to allow offloading
-                       model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                       ml.n_created--; // artificial tensor
-                       ml.size_data += ggml_nbytes(model.output);
+                       model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                       if (!model.output) {
+                           model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU
+                           ml.n_created--; // artificial tensor
+                           ml.size_data += ggml_nbytes(model.output);
+                       }
                    }
 
                    for (int i = 0; i < n_layer; ++i) {
@@ -4918,6 +4986,37 @@ static bool llm_load_tensors(
                        layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                    }
                } break;
+            case LLM_ARCH_COMMAND_R:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        // init output from the input tok embed
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                    }
+                } break;
            default:
                throw std::runtime_error("unknown architecture");
        }
@@ -4942,6 +5041,13 @@ static bool llm_load_tensors(
             size_t first, last;
             ml.get_mapping_range(&first, &last, ctx);
             buf = ggml_backend_cpu_buffer_from_ptr((char *) ml.mapping->addr + first, last - first);
+#ifdef GGML_USE_CUBLAS
+            if (n_layer >= n_gpu_layers) {
+                ggml_backend_cuda_register_host_buffer(
+                    ggml_backend_buffer_get_base(buf),
+                    ggml_backend_buffer_get_size(buf));
+            }
+#endif
         }
 #ifdef GGML_USE_METAL
         else if (ml.use_mmap && buft == ggml_backend_metal_buffer_type()) {
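
When the model is mmapped and not fully offloaded (n_layer >= n_gpu_layers), the CPU-side weight buffer is now registered as pinned host memory so per-batch weight uploads to CUDA can use page-locked transfers; the matching unregister call was added to the llama_model destructor earlier in this diff. A rough sketch of the pairing, assuming a CUDA-enabled ggml build (GGML_USE_CUBLAS) and the ggml headers that declare these functions:

    // Sketch only: "ggml-backend.h" / "ggml-cuda.h" and the CUDA backend are assumed.
    #include "ggml-backend.h"
    #include "ggml-cuda.h"

    // Pin an existing CPU-side buffer for faster host-to-device transfers.
    static void pin_host_buffer(ggml_backend_buffer_t buf) {
        ggml_backend_cuda_register_host_buffer(
                ggml_backend_buffer_get_base(buf),
                ggml_backend_buffer_get_size(buf));
    }

    // Undo the registration before the buffer is freed (mirrors ~llama_model()).
    static void unpin_host_buffer(ggml_backend_buffer_t buf) {
        ggml_backend_cuda_unregister_host_buffer(ggml_backend_buffer_get_base(buf));
    }
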
@@ -5064,6 +5170,16 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
     }
 #endif
 
+#ifdef GGML_USE_SYCL
+    if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
+        ggml_backend_sycl_set_single_device_mode(params.main_gpu);
+        // SYCL uses device indices (0, 1, 2) directly; convert the user-supplied device id to a device index.
+        params.main_gpu = ggml_backend_sycl_get_device_index(params.main_gpu);
+    } else {
+        ggml_backend_sycl_set_mul_device_mode();
+    }
+#endif
+
     if (!llm_load_tensors(
         ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
         params.progress_callback, params.progress_callback_user_data
@@ -5858,6 +5974,12 @@ struct llm_build_context {
            }
 
            cur = ggml_add(ctx0, cur, ffn_inp);
+           cb(cur, "ffn_out", il);
+
+           ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+           if (layer_dir != nullptr) {
+               cur = ggml_add(ctx0, cur, layer_dir);
+           }
            cb(cur, "l_out", il);
 
            // input for next layer
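
With this hook, after each decoder block the hidden state gets the per-layer control-vector direction added, i.e. h <- h + c_l for layers inside the active window. A tiny numeric sketch of the update (plain C++, embedding size reduced to 4 for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        // Hidden state after the FFN residual add ("ffn_out" above) and the
        // control-vector direction for this layer, both of length n_embd.
        std::vector<float> ffn_out   = { 0.5f, -1.0f, 0.25f, 2.0f };
        std::vector<float> layer_dir = { 0.1f,  0.1f, -0.2f, 0.0f };

        // Equivalent of cur = ggml_add(ctx0, cur, layer_dir): an element-wise
        // add, broadcast over every token position in the batch.
        for (size_t i = 0; i < ffn_out.size(); ++i) {
            ffn_out[i] += layer_dir[i];
        }
        for (float v : ffn_out) std::printf("%.2f ", v); // 0.60 -0.90 0.05 2.00
        std::printf("\n");
        return 0;
    }
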
@@ -5893,7 +6015,7 @@ struct llm_build_context {
        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
        // inp_pos - contains the positions
-       struct ggml_tensor * inp_pos = build_inp_pos();
+       struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
 
        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -5943,7 +6065,6 @@ struct llm_build_context {
                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
 
-
                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                        model.layers[il].wo, NULL,
                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
@@ -8125,7 +8246,6 @@ struct llm_build_context {
                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                        model.layers[il].wo, model.layers[il].bo,
                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-               cb(cur, "kqv_out", il);
            }
 
            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -8305,6 +8425,121 @@ struct llm_build_context {
 
        return gf;
    }
+
+   struct ggml_cgraph * build_command_r() {
+
+       struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+       const int64_t n_embd_head = hparams.n_embd_head_v;
+       GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+       const float f_logit_scale = hparams.f_logit_scale;
+
+       struct ggml_tensor * cur;
+       struct ggml_tensor * inpL;
+
+       inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+       // inp_pos - contains the positions
+       struct ggml_tensor * inp_pos = build_inp_pos();
+
+       // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+       struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+       for (int il = 0; il < n_layer; ++il) {
+
+           // norm
+           cur = llm_build_norm(ctx0, inpL, hparams,
+                   model.layers[il].attn_norm, NULL,
+                   LLM_NORM, cb, il);
+           cb(cur, "attn_norm", il);
+           struct ggml_tensor * ffn_inp = cur;
+
+           // self-attention
+           {
+               // compute Q and K and RoPE them
+               struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+               cb(Qcur, "Qcur", il);
+               if (model.layers[il].bq) {
+                   Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                   cb(Qcur, "Qcur", il);
+               }
+
+               struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+               cb(Kcur, "Kcur", il);
+               if (model.layers[il].bk) {
+                   Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                   cb(Kcur, "Kcur", il);
+               }
+
+               struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+               cb(Vcur, "Vcur", il);
+               if (model.layers[il].bv) {
+                   Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                   cb(Vcur, "Vcur", il);
+               }
+
+               Qcur = ggml_rope_custom(
+                   ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                   n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                   ext_factor, attn_factor, beta_fast, beta_slow
+               );
+               cb(Qcur, "Qcur", il);
+
+               Kcur = ggml_rope_custom(
+                   ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                   n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                   ext_factor, attn_factor, beta_fast, beta_slow
+               );
+               cb(Kcur, "Kcur", il);
+
+               cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                       model.layers[il].wo, model.layers[il].bo,
+                       Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+           }
+
+           struct ggml_tensor * attn_out = cur;
+
+           // feed-forward network
+           {
+               cur = llm_build_ffn(ctx0, ffn_inp,
+                       model.layers[il].ffn_up, NULL,
+                       model.layers[il].ffn_gate, NULL,
+                       model.layers[il].ffn_down, NULL,
+                       NULL,
+                       LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+               cb(cur, "ffn_out", il);
+           }
+
+           // add together residual + FFN + self-attention
+           cur = ggml_add(ctx0, cur, inpL);
+           cur = ggml_add(ctx0, cur, attn_out);
+           cb(cur, "l_out", il);
+
+           // input for next layer
+           inpL = cur;
+       }
+
+       cur = inpL;
+
+       cur = llm_build_norm(ctx0, cur, hparams,
+               model.output_norm, NULL,
+               LLM_NORM, cb, -1);
+       cb(cur, "result_norm", -1);
+
+       // lm_head
+       cur = ggml_mul_mat(ctx0, model.output, cur);
+
+       if (f_logit_scale) {
+           cur = ggml_scale(ctx0, cur, f_logit_scale);
+       }
+
+       cb(cur, "result_output", -1);
+
+       ggml_build_forward_expand(gf, cur);
+
+       return gf;
+
+   }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
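
build_command_r differs from the default llama graph in two ways: a single layer norm feeds both the self-attention and the feed-forward branch, whose outputs are added to the residual stream together, and the final logits are multiplied by the logit_scale read from the GGUF metadata. A compact sketch of one block plus the output scaling (plain C++ with stub helpers; the real graph is built from the ggml ops shown above):

    #include <cstdio>
    #include <vector>

    using vec = std::vector<float>;

    // Stubs standing in for llm_build_norm / llm_build_kv / llm_build_ffn.
    static vec layer_norm(const vec & x) { return x; } // placeholder (real code uses LLM_NORM)
    static vec attention(const vec & x)  { return x; } // placeholder self-attention branch
    static vec ffn(const vec & x)        { return x; } // placeholder SiLU-gated FFN branch

    // One Command-R style block: both branches read the same normed input and
    // are added back to the residual in one step (the two ggml_add calls above).
    static vec command_r_block(const vec & inpL) {
        const vec normed  = layer_norm(inpL); // "attn_norm"
        const vec attn    = attention(normed);
        const vec ffn_out = ffn(normed);      // ffn_inp is the normed input, not the attention output

        vec out(inpL.size());
        for (size_t i = 0; i < out.size(); ++i) {
            out[i] = inpL[i] + attn[i] + ffn_out[i];
        }
        return out;                           // "l_out", input to the next layer
    }

    int main() {
        vec h = command_r_block({1.0f, -2.0f});

        // After the final norm and the lm_head matmul the logits are scaled,
        // mirroring `if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); }`.
        const float f_logit_scale = 0.0625f;  // placeholder value; read from "command-r.logit_scale"
        for (float & l : h) l *= f_logit_scale;
        std::printf("%.4f %.4f\n", h[0], h[1]);
        return 0;
    }
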
@@ -8380,12 +8615,15 @@ static struct ggml_cgraph * llama_build_graph(
        }
 
        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
-       // to fix this, we assign the norm layer manually to the backend of its layer
-       if (il != -1 && strcmp(name, "norm") == 0) {
-           for (auto * backend : lctx.backends) {
-               if (ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
-                   ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
-                   break;
+       // FIXME: fix in ggml_backend_sched
+       const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
+       if (batch.n_tokens < 32 || full_offload) {
+           if (il != -1 && strcmp(name, "norm") == 0) {
+               for (auto * backend : lctx.backends) {
+                   if (ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
+                       ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
+                       break;
+                   }
                }
            }
        }
@@ -8487,6 +8725,10 @@ static struct ggml_cgraph * llama_build_graph(
            {
                result = llm.build_mamba();
            } break;
+       case LLM_ARCH_COMMAND_R:
+           {
+               result = llm.build_command_r();
+           } break;
        default:
            GGML_ASSERT(false);
    }
@@ -12802,6 +13044,9 @@ struct llama_context * llama_new_context_with_model(
    cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
+   // this is necessary due to kv_self.n being padded later during inference
+   cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+
    // with causal attention, the batch size is limited by the context size
    cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
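
The added GGML_PAD call rounds the requested context size up to a multiple of 32, because kv_self.n is padded to the same granularity during inference; otherwise compute buffers could be sized for a smaller n_ctx than the KV cache actually uses. A quick check of the rounding (the PAD macro below reproduces the round-up-to-multiple behaviour; it is a stand-in, not the ggml header definition):

    #include <cstdio>

    // Same effect as GGML_PAD(x, 32): round x up to the next multiple of n.
    #define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        std::printf("%d\n", PAD(4096, 32)); // 4096 (already a multiple of 32)
        std::printf("%d\n", PAD(4097, 32)); // 4128 (rounded up)
        std::printf("%d\n", PAD(2000, 32)); // 2016
        return 0;
    }
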
@@ -12882,27 +13127,25 @@ struct llama_context * llama_new_context_with_model(
        ctx->backends.push_back(ctx->backend_metal);
    }
 #elif defined(GGML_USE_CUBLAS)
-   if (model->n_gpu_layers > 0) {
+   if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-       if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-           ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
+       ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
+       if (backend == nullptr) {
+           LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
+           llama_free(ctx);
+           return nullptr;
+       }
+       ctx->backends.push_back(backend);
+   } else {
+       // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+       for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
+           ggml_backend_t backend = ggml_backend_cuda_init(device);
            if (backend == nullptr) {
-               LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
+               LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
                llama_free(ctx);
                return nullptr;
            }
            ctx->backends.push_back(backend);
-       } else {
-           // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-           for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
-               ggml_backend_t backend = ggml_backend_cuda_init(device);
-               if (backend == nullptr) {
-                   LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
-                   llama_free(ctx);
-                   return nullptr;
-               }
-               ctx->backends.push_back(backend);
-           }
-       }
        }
    }
 #elif defined(GGML_USE_VULKAN)
@@ -12921,23 +13164,22 @@ struct llama_context * llama_new_context_with_model(
    if (model->n_gpu_layers > 0) {
        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-           int main_gpu_index = ggml_backend_sycl_get_device_index(model->main_gpu);
-           ggml_backend_t backend = ggml_backend_sycl_init(main_gpu_index);
+           ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
            if (backend == nullptr) {
-               LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d)backend\n", __func__, model->main_gpu, main_gpu_index);
+               int main_gpu_id = ggml_backend_sycl_get_device_id(model->main_gpu);
+               LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, main_gpu_id, model->main_gpu);
                llama_free(ctx);
                return nullptr;
            }
            ctx->backends.push_back(backend);
        } else {
            // LLAMA_SPLIT_LAYER requires a backend for each GPU
-           int id_list[GGML_SYCL_MAX_DEVICES];
-           ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
            for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
-               int device_id = id_list[i];
                ggml_backend_t backend = ggml_backend_sycl_init(i);
                if (backend == nullptr) {
-                   LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d)backend\n", __func__, device_id, i);
+                   int id_list[GGML_SYCL_MAX_DEVICES];
+                   ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
+                   LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, id_list[i], i);
                    llama_free(ctx);
                    return nullptr;
                }
@@ -13061,14 +13303,17 @@ struct llama_context * llama_new_context_with_model(
            ggml_backend_t backend = ctx->backends[i];
            ggml_backend_buffer_type_t buft = backend_buft[i];
            size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
-           LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                   ggml_backend_buft_name(buft),
-                   size / 1024.0 / 1024.0);
+           if (size > 1) {
+               LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                       ggml_backend_buft_name(buft),
+                       size / 1024.0 / 1024.0);
+           }
        }
 
        // note: the number of splits during measure is higher than during inference due to the kv shift
        int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-       LLAMA_LOG_INFO("%s: graph splits: %d\n", __func__, n_splits);
+       LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
+       LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
    }
 }
 
@@ -13138,6 +13383,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_ORION:
        case LLM_ARCH_INTERNLM2:
        case LLM_ARCH_MINICPM:
+       case LLM_ARCH_COMMAND_R:
            return LLAMA_ROPE_TYPE_NORM;
 
        // the pairs of head values are offset by n_rot/2
@@ -13174,6 +13420,10 @@ int32_t llama_n_embd(const struct llama_model * model) {
    return model->hparams.n_embd;
 }
 
+int32_t llama_n_layer(const struct llama_model * model) {
+   return model->hparams.n_layer;
+}
+
 float llama_rope_freq_scale_train(const struct llama_model * model) {
    return model->hparams.rope_freq_scale_train;
 }
@@ -13273,6 +13523,96 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
    }
 }
 
+static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
+   GGML_ASSERT(cvec.tensors.empty());
+   GGML_ASSERT(cvec.ctxs.empty());
+   GGML_ASSERT(cvec.bufs.empty());
+
+   // count layer buffer types
+   std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
+   for (int64_t i = 0; i < model.hparams.n_layer; i++) {
+       buft_layer_count[model.buft_layer[i].buft]++;
+   }
+
+   // allocate contexts
+   std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+   for (auto & it : buft_layer_count) {
+       int n_layers = it.second;
+       struct ggml_init_params params = {
+           /*.mem_size =*/ n_layers * ggml_tensor_overhead(),
+           /*.mem_buffer =*/ NULL,
+           /*.no_alloc =*/ true,
+       };
+       ggml_context * ctx = ggml_init(params);
+       if (!ctx) {
+           LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
+           return 1;
+       }
+       ctx_map[it.first] = ctx;
+   }
+
+   // make tensors
+   cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
+   for (size_t il = 1; il < model.hparams.n_layer; il++) {
+       struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
+       ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
+       cvec.tensors.push_back(tensor);
+   }
+
+   // allocate tensors / buffers and zero
+   for (auto it : ctx_map) {
+       ggml_backend_buffer_type_t buft = it.first;
+       ggml_context * ctx = it.second;
+       ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+       if (!buf) {
+           LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
+           return false;
+       }
+       ggml_backend_buffer_clear(buf, 0);
+       cvec.ctxs.push_back(ctx);
+       cvec.bufs.push_back(buf);
+   }
+
+   return true;
+}
+
+int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
+   const llama_model & model = lctx->model;
+   llama_control_vector & cvec = lctx->cvec;
+
+   if (data == nullptr) {
+       // disable the current control vector (but leave allocated for later)
+       cvec.layer_start = -1;
+       cvec.layer_end = -1;
+       return 0;
+   }
+
+   if (n_embd != (int) model.hparams.n_embd) {
+       LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
+       return 1;
+   }
+
+   if (cvec.tensors.empty()) {
+       if (!llama_control_vector_init(cvec, model)) {
+           return 1;
+       }
+   }
+
+   cvec.layer_start = il_start;
+   cvec.layer_end = il_end;
+
+   for (size_t il = 1; il < model.hparams.n_layer; il++) {
+       assert(cvec.tensors[il] != nullptr);
+
+       const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
+       if (off + n_embd <= len) {
+           ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
+       }
+   }
+
+   return 0;
+}
+
 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
    struct llama_kv_cache_view result = {
        /*.n_cells = */ 0,
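
The new public entry point llama_control_vector_apply takes a flat buffer of n_embd floats per layer (layer 0 is excluded) plus the inclusive layer window to steer, and passing a null data pointer disables the currently loaded vector without freeing it. A hedged usage sketch against the llama.h API, with the vector contents as placeholders (a real control vector would come from a trained file) and the model/context assumed to be created elsewhere:

    #include <vector>
    #include "llama.h"

    void steer(struct llama_context * ctx, const struct llama_model * model) {
        const int32_t n_embd  = llama_n_embd(model);
        const int32_t n_layer = llama_n_layer(model); // accessor added in this release

        // One direction per layer 1..n_layer-1; zeros here are placeholders.
        std::vector<float> data((size_t) n_embd * (n_layer - 1), 0.0f);

        // Apply to layers 1..n_layer-1 (il_end is inclusive); returns 0 on success.
        llama_control_vector_apply(ctx, data.data(), data.size(), n_embd, 1, n_layer - 1);

        // ... run decoding as usual ...

        // Passing nullptr disables the vector but keeps its buffers allocated.
        llama_control_vector_apply(ctx, nullptr, 0, n_embd, 0, 0);
    }
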