llama_cpp 0.12.4 → 0.12.5

@@ -8,24 +8,29 @@ extern "C" {
  #endif

  #define GGML_VK_NAME "Vulkan"
+ #define GGML_VK_MAX_DEVICES 16

- GGML_API void ggml_vk_init(void);
+ GGML_API void ggml_vk_init_cpu_assist(void);

- GGML_API void ggml_vk_preallocate_buffers_graph(struct ggml_tensor * node);
- GGML_API void ggml_vk_preallocate_buffers(void);
- GGML_API void ggml_vk_build_graph(struct ggml_tensor * node, bool last_node);
- GGML_API bool ggml_vk_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+ GGML_API void ggml_vk_preallocate_buffers_graph_cpu_assist(struct ggml_tensor * node);
+ GGML_API void ggml_vk_preallocate_buffers_cpu_assist(void);
+ GGML_API void ggml_vk_build_graph_cpu_assist(struct ggml_tensor * node, bool last_node);
+ GGML_API bool ggml_vk_compute_forward_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor);

  #ifdef GGML_VULKAN_CHECK_RESULTS
- void ggml_vk_check_results_1(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+ void ggml_vk_check_results_1_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor);
  #endif
- GGML_API void ggml_vk_graph_cleanup(void);
+ GGML_API void ggml_vk_graph_cleanup_cpu_assist(void);
+ GGML_API void ggml_vk_free_cpu_assist(void);

  // backend API
- GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(void);
+ GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);

  GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
+ GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void);
+ GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
+ GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);

- GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(void);
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);

  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
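The multi-device entry points declared above can be exercised directly from application code. A minimal sketch, assuming a build with the Vulkan backend enabled and that ggml-backend.h provides ggml_backend_t and ggml_backend_free; the 64-byte description buffer is an arbitrary choice for illustration:

    #include <cstdio>
    #include "ggml-backend.h"   // ggml_backend_t, ggml_backend_free
    #include "ggml-vulkan.h"    // declarations shown above

    void list_vk_devices(void) {
        const int n_dev = ggml_backend_vk_get_device_count();
        for (int i = 0; i < n_dev; ++i) {
            char desc[64];                              // buffer size is an arbitrary choice here
            size_t free_mem = 0, total_mem = 0;
            ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
            ggml_backend_vk_get_device_memory(i, &free_mem, &total_mem);
            std::printf("Vulkan device %d: %s (%zu of %zu bytes free)\n", i, desc, free_mem, total_mem);

            ggml_backend_t backend = ggml_backend_vk_init(i);   // now takes a device index
            if (backend != nullptr) {
                ggml_backend_free(backend);
            }
        }
    }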
 
@@ -2343,7 +2343,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  #elif defined(GGML_USE_CLBLAST)
  ggml_cl_init();
  #elif defined(GGML_USE_VULKAN)
- ggml_vk_init();
+ ggml_vk_init_cpu_assist();
  #elif defined(GGML_USE_SYCL)
  ggml_init_sycl();
  #endif
@@ -2470,7 +2470,8 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
  size_t max_size = 0;

  for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
- max_size = MAX(max_size, ggml_nbytes(tensor));
+ size_t bytes = ggml_nbytes(tensor);
+ max_size = MAX(max_size, bytes);
  }

  return max_size;
@@ -11887,8 +11888,10 @@ GGML_CALL void ggml_rope_yarn_corr_dims(
  int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
  ) {
  // start and end correction dims
- dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
- dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+ float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
+ float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
+ dims[0] = MAX(0, start);
+ dims[1] = MIN(n_dims - 1, end);
  }

  static void ggml_compute_forward_rope_f32(
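Both this hunk and the ggml_get_max_tensor_size change above hoist a function call out of a MAX/MIN argument. A small illustration of why that can matter, assuming a typical function-like macro definition; the expensive() helper below is a hypothetical stand-in, not a ggml function:

    #include <cstdio>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))   // typical function-like MAX macro

    static int n_calls = 0;
    static float expensive(void) { ++n_calls; return 3.0f; }   // hypothetical stand-in for ggml_nbytes()/ggml_rope_yarn_corr_dim()

    int main() {
        float m1 = MAX(0.0f, expensive());      // the macro repeats its argument, so the call runs twice
        std::printf("inline call:  m1 = %.1f, evaluations = %d\n", m1, n_calls);   // evaluations = 2

        n_calls = 0;
        float v  = expensive();                 // hoisted into a local, as in the hunks above
        float m2 = MAX(0.0f, v);
        std::printf("hoisted call: m2 = %.1f, evaluations = %d\n", m2, n_calls);   // evaluations = 1
        return 0;
    }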
@@ -14847,10 +14850,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
  GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
  #elif defined(GGML_USE_VULKAN)
- const bool skip_cpu = ggml_vk_compute_forward(params, tensor);
+ const bool skip_cpu = ggml_vk_compute_forward_cpu_assist(params, tensor);
  #ifdef GGML_VULKAN_CHECK_RESULTS
  if (skip_cpu) {
- ggml_vk_check_results_1(params, tensor);
+ ggml_vk_check_results_1_cpu_assist(params, tensor);
  }
  #endif
  if (skip_cpu) {
@@ -17266,12 +17269,12 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {

  #ifdef GGML_USE_VULKAN
  for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_preallocate_buffers_graph(cgraph->nodes[i]);
+ ggml_vk_preallocate_buffers_graph_cpu_assist(cgraph->nodes[i]);
  }
- ggml_vk_preallocate_buffers();
+ ggml_vk_preallocate_buffers_cpu_assist();

  for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(cgraph->nodes[i], i == cgraph->n_nodes - 1);
+ ggml_vk_build_graph_cpu_assist(cgraph->nodes[i], i == cgraph->n_nodes - 1);
  }
  #endif
@@ -17327,7 +17330,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
  }

  #ifdef GGML_USE_VULKAN
- ggml_vk_graph_cleanup();
+ ggml_vk_graph_cleanup_cpu_assist();
  #endif

  // performance stats (graph)
@@ -205,10 +205,11 @@ enum llm_arch {
  LLM_ARCH_CODESHELL,
  LLM_ARCH_ORION,
  LLM_ARCH_INTERNLM2,
+ LLM_ARCH_MINICPM,
  LLM_ARCH_UNKNOWN,
  };

- static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
+ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
  { LLM_ARCH_LLAMA, "llama" },
  { LLM_ARCH_FALCON, "falcon" },
  { LLM_ARCH_GPT2, "gpt2" },
@@ -228,6 +229,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
  { LLM_ARCH_CODESHELL, "codeshell" },
  { LLM_ARCH_ORION, "orion" },
  { LLM_ARCH_INTERNLM2, "internlm2" },
+ { LLM_ARCH_MINICPM, "minicpm" },
  };

  enum llm_kv {
@@ -285,7 +287,7 @@ enum llm_kv {
  LLM_KV_TOKENIZER_RWKV,
  };

- static std::map<llm_kv, std::string> LLM_KV_NAMES = {
+ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
  { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
  { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
  { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
@@ -346,7 +348,7 @@ struct LLM_KV {
  llm_arch arch;

  std::string operator()(llm_kv kv) const {
- return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str());
+ return ::format(LLM_KV_NAMES[kv], LLM_ARCH_NAMES[arch]);
  }
  };
@@ -690,6 +692,29 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
  { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
  },
  },
+ {
+ LLM_ARCH_MINICPM,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+ },
+ },
  {
  LLM_ARCH_UNKNOWN,
  {
@@ -747,13 +772,13 @@ struct LLM_TN {
  // gguf helpers
  //

- static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
+ static std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
  { LLAMA_ROPE_SCALING_NONE, "none" },
  { LLAMA_ROPE_SCALING_LINEAR, "linear" },
  { LLAMA_ROPE_SCALING_YARN, "yarn" },
  };

- static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
+ static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
  for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
  if (kv.second == name) {
  return kv.first;
@@ -1330,7 +1355,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
  #elif defined(GGML_USE_CUBLAS)
  buft = ggml_backend_cuda_buffer_type(gpu);
  #elif defined(GGML_USE_VULKAN)
- buft = ggml_backend_vk_buffer_type();
+ buft = ggml_backend_vk_buffer_type(gpu);
  #elif defined(GGML_USE_SYCL)
  buft = ggml_backend_sycl_buffer_type(gpu);
  #elif defined(GGML_USE_CLBLAST)
@@ -1367,6 +1392,33 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
  GGML_UNUSED(tensor_split);
  }

+ static size_t llama_get_device_count() {
+ #if defined(GGML_USE_CUBLAS)
+ return ggml_backend_cuda_get_device_count();
+ #elif defined(GGML_USE_VULKAN)
+ return ggml_backend_vk_get_device_count();
+ #else
+ return 1;
+ #endif
+ }
+
+ static size_t llama_get_device_memory(int device) {
+ #if defined(GGML_USE_CUBLAS)
+ size_t total;
+ size_t free;
+ ggml_backend_cuda_get_device_memory(device, &total, &free);
+ return free;
+ #elif defined(GGML_USE_VULKAN)
+ size_t total;
+ size_t free;
+ ggml_backend_vk_get_device_memory(device, &total, &free);
+ return free;
+ #else
+ return 1;
+ GGML_UNUSED(device);
+ #endif
+ }
+
  //
  // globals
  //
@@ -1390,6 +1442,7 @@ enum e_model {
  MODEL_UNKNOWN,
  MODEL_0_5B,
  MODEL_1B,
+ MODEL_2B,
  MODEL_3B,
  MODEL_4B,
  MODEL_7B,
@@ -1415,6 +1468,7 @@ static const size_t GiB = 1024*MiB;

  struct llama_hparams {
  bool vocab_only;
+ bool rope_finetuned;
  uint32_t n_vocab;
  uint32_t n_ctx_train; // context size the model was trained on
  uint32_t n_embd;
@@ -1434,8 +1488,7 @@ struct llama_hparams {
  float rope_freq_base_train;
  float rope_freq_scale_train;
  uint32_t n_yarn_orig_ctx;
- int8_t rope_scaling_type_train : 3;
- bool rope_finetuned : 1;
+ int32_t rope_scaling_type_train;

  float f_clamp_kqv;
  float f_max_alibi_bias;
@@ -1737,6 +1790,10 @@ struct llama_context {
  ggml_backend_free(backend);
  }

+ #ifdef GGML_USE_VULKAN
+ ggml_vk_free_cpu_assist();
+ #endif
+
  ggml_backend_buffer_free(buf_input);
  ggml_free(ctx_input);
  }
@@ -2701,7 +2758,7 @@ struct llama_model_loader {
  // load LLaMA models
  //

- static std::string llama_model_arch_name(llm_arch arch) {
+ static const char * llama_model_arch_name(llm_arch arch) {
  auto it = LLM_ARCH_NAMES.find(arch);
  if (it == LLM_ARCH_NAMES.end()) {
  return "unknown";
@@ -2748,6 +2805,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
  static const char * llama_model_type_name(e_model type) {
  switch (type) {
  case MODEL_1B: return "1B";
+ case MODEL_2B: return "2B";
  case MODEL_3B: return "3B";
  case MODEL_7B: return "7B";
  case MODEL_8B: return "8B";
@@ -2887,6 +2945,15 @@ static void llm_load_hparams(
  default: model.type = e_model::MODEL_UNKNOWN;
  }
  } break;
+ case LLM_ARCH_MINICPM:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 40: model.type = e_model::MODEL_2B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
  case LLM_ARCH_FALCON:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3310,11 +3377,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
  const auto & hparams = model.hparams;
  const auto & vocab = model.vocab;

- const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+ const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);

  // hparams
  LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
- LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+ LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch));
  LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, llama_model_vocab_type_name(vocab.type));
  LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
  LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
@@ -3336,7 +3403,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
  LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
  LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
  LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
- LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
+ LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
  LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
  LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
  LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
@@ -3402,22 +3469,18 @@ static bool llm_load_tensors(
  model.buft_layer[i] = llama_default_buffer_type_cpu(true);
  }

- #ifdef GGML_USE_CUBLAS
  if (split_mode == LLAMA_SPLIT_LAYER) {
  // calculate the split points
- int device_count = ggml_backend_cuda_get_device_count();
+ int device_count = llama_get_device_count();
  bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
- float splits[GGML_CUDA_MAX_DEVICES];
+ std::vector<float> splits(device_count);
  if (all_zero) {
  // default split, by free memory
  for (int i = 0; i < device_count; ++i) {
- size_t total;
- size_t free;
- ggml_backend_cuda_get_device_memory(i, &total, &free);
- splits[i] = free;
+ splits[i] = llama_get_device_memory(i);
  }
  } else {
- std::copy(tensor_split, tensor_split + device_count, splits);
+ std::copy(tensor_split, tensor_split + device_count, splits.begin());
  }

  // sum and normalize the splits to get the split points
@@ -3433,19 +3496,17 @@ static bool llm_load_tensors(
  // assign the repeating layers to the devices according to the splits
  int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
  for (int64_t i = i_gpu_start; i < n_layer; ++i) {
- int layer_gpu = std::upper_bound(splits, splits + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits;
+ int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
  model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu);
  }
  // assign the output layer
  if (n_gpu_layers > n_layer) {
- int layer_gpu = std::upper_bound(splits, splits + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits;
+ int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
  model.buft_output = llama_default_buffer_type_offload(layer_gpu);
  } else {
  model.buft_output = llama_default_buffer_type_cpu(true);
  }
- } else
- #endif
- {
+ } else {
  ggml_backend_buffer_type_t split_buft;
  if (split_mode == LLAMA_SPLIT_ROW) {
  split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
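The layer-split logic above turns per-device free memory into cumulative split points and then maps each layer index to a device with std::upper_bound. A standalone sketch of that idea, using made-up device memory sizes and layer counts:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical free memory per device, in GiB
        std::vector<float> splits = {8.0f, 4.0f, 4.0f};
        const int device_count = (int) splits.size();

        // sum and normalize into cumulative split points: {0.5, 0.75, 1.0}
        float total = 0.0f;
        for (int i = 0; i < device_count; ++i) {
            total += splits[i];
            splits[i] = total;
        }
        for (int i = 0; i < device_count; ++i) {
            splits[i] /= total;
        }

        // map each of 32 hypothetical offloaded layers to a device
        const int n_layers = 32;
        for (int i = 0; i < n_layers; ++i) {
            int layer_gpu = (int) (std::upper_bound(splits.begin(), splits.begin() + device_count,
                                                    float(i) / n_layers) - splits.begin());
            std::printf("layer %2d -> device %d\n", i, layer_gpu);
        }
        return 0;
    }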
@@ -3524,13 +3585,16 @@ static bool llm_load_tensors(
  switch (model.arch) {
  case LLM_ARCH_LLAMA:
  case LLM_ARCH_REFACT:
+ case LLM_ARCH_MINICPM:
  {
  model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

  // output
  {
  model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
- model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+ if (model.arch != LLM_ARCH_MINICPM){
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+ }
  }

  for (int i = 0; i < n_layer; ++i) {
@@ -4145,8 +4209,7 @@ static bool llm_load_tensors(
  ctx_bufs.emplace_back(ctx, buf);
  }

- // print memory requirements
- {
+ if (llama_supports_gpu_offload()) {
  const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));

  LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
@@ -4158,10 +4221,11 @@ static bool llm_load_tensors(
  const int max_offloadable_layers = hparams.n_layer + 1;

  LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+ }

- for (ggml_backend_buffer_t buf : model.bufs) {
- LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
- }
+ // print memory requirements
+ for (ggml_backend_buffer_t buf : model.bufs) {
+ LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
  }

  // populate tensors_by_name
@@ -6781,6 +6845,153 @@ struct llm_build_context {
  return gf;
  }

+ // ref: https://arxiv.org/abs/2203.03466
+ // https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
+ // based on the original build_llama() function
+ struct ggml_cgraph * build_minicpm() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ const int64_t n_embd = hparams.n_embd;
+ //TODO: if the model varies, these parameters need to be read from the model
+ const int64_t n_embd_base = 256;
+ const float scale_embd = 12.0f;
+ const float scale_depth = 1.4f;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ // scale the input embeddings
+ inpL = ggml_scale(ctx0, inpL, scale_embd);
+ cb(inpL, "inp_scaled", -1);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+ cb(inp_pos, "inp_pos", -1);
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+ cb(KQ_mask, "KQ_mask", -1);
+
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ cb(cur, "kqv_out", il);
+ }
+
+ // scale_res - scale the hidden states for residual connection
+ const float scale_res = scale_depth/sqrtf(float(n_layer));
+ cur = ggml_scale(ctx0, cur, scale_res);
+ cb(cur, "hidden_scaled", -1);
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ // scale the hidden states for residual connection
+ cur = ggml_scale(ctx0, cur, scale_res);
+ cb(cur, "hidden_scaled_ffn", -1);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head scaling
+ const float scale_lmhead = float(n_embd_base)/float(n_embd);
+ cur = ggml_scale(ctx0, cur, scale_lmhead);
+ cb(cur, "lmhead_scaling", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
  };

  static struct ggml_cgraph * llama_build_graph(
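The MiniCPM graph above differs from the regular LLaMA graph mainly in three fixed scalings: the input embeddings (scale_embd), each residual branch (scale_depth/sqrt(n_layer)), and the lm_head input (n_embd_base/n_embd), with the output projection reusing the token embedding matrix. A quick check of those constants for the 40-layer 2B mapping added in llm_load_hparams; the n_embd value below is hypothetical, since the real value is read from the model file:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float scale_embd  = 12.0f;   // input embedding scale from the graph above
        const float scale_depth = 1.4f;    // depth scale from the graph above
        const int   n_layer     = 40;      // MODEL_2B mapping from llm_load_hparams in this diff

        // each residual branch is scaled by scale_depth / sqrt(n_layer) before the add
        const float scale_res = scale_depth / std::sqrt((float) n_layer);
        std::printf("scale_embd = %.1f, scale_res = %.4f\n", scale_embd, scale_res);   // scale_res ~= 0.2214

        // the lm_head input is scaled by n_embd_base / n_embd; 2048 is a hypothetical n_embd for illustration
        const float scale_lmhead = 256.0f / 2048.0f;
        std::printf("lm_head scale (hypothetical n_embd = 2048) = %.4f\n", scale_lmhead);
        return 0;
    }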
@@ -6943,6 +7154,10 @@ static struct ggml_cgraph * llama_build_graph(
  {
  result = llm.build_internlm2();
  } break;
+ case LLM_ARCH_MINICPM:
+ {
+ result = llm.build_minicpm();
+ } break;
  default:
  GGML_ASSERT(false);
  }
@@ -8373,6 +8588,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can

  const int64_t t_start_sample_us = ggml_time_us();

+ if (k <= 0) {
+ k = candidates->size;
+ }
+
  k = std::max(k, (int) min_keep);
  k = std::min(k, (int) candidates->size);
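With the guard added above, a non-positive k now keeps every candidate instead of being clamped up to min_keep. A small sketch of how a caller might rely on that; the helper below is hypothetical:

    #include "llama.h"

    // Hypothetical helper: run the top-k step without actually truncating the candidate list.
    static void sample_without_top_k(llama_context * ctx, llama_token_data_array * candidates) {
        // After this change, k <= 0 is expanded to candidates->size,
        // so every candidate survives the top-k step (min_keep still applies as a lower bound).
        llama_sample_top_k(ctx, candidates, /*k=*/0, /*min_keep=*/1);
    }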
 
@@ -9456,8 +9675,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
  new_type = GGML_TYPE_Q4_K;
  }
- else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && qs.model.hparams.n_gqa() >= 4) {
- new_type = GGML_TYPE_Q4_K;
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
+ new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS;
  }
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
  new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
@@ -9496,9 +9715,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
  if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
  }
- //else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
- // if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
- //}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
+ new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
+ }
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
  new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
  : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
@@ -10295,6 +10514,8 @@ size_t llama_max_devices(void) {
  return GGML_CUDA_MAX_DEVICES;
  #elif defined(GGML_USE_SYCL)
  return GGML_SYCL_MAX_DEVICES;
+ #elif defined(GGML_USE_VULKAN)
+ return GGML_VK_MAX_DEVICES;
  #else
  return 1;
  #endif
@@ -10502,13 +10723,15 @@ struct llama_context * llama_new_context_with_model(
  }
  #elif defined(GGML_USE_VULKAN)
  if (model->n_gpu_layers > 0) {
- ggml_backend_t backend = ggml_backend_vk_init();
- if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
- llama_free(ctx);
- return nullptr;
+ for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
+ ggml_backend_t backend = ggml_backend_vk_init(device);
+ if (backend == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
+ llama_free(ctx);
+ return nullptr;
+ }
+ ctx->backends.push_back(backend);
  }
- ctx->backends.push_back(backend);
  }
  #elif defined(GGML_USE_SYCL)
  if (model->n_gpu_layers > 0) {
@@ -10735,7 +10958,7 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3

  int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
  return snprintf(buf, buf_size, "%s %s %s",
- llama_model_arch_name(model->arch).c_str(),
+ llama_model_arch_name(model->arch),
  llama_model_type_name(model->type),
  llama_model_ftype_name(model->ftype).c_str());
  }
@@ -213,7 +213,7 @@ extern "C" {
  uint32_t n_batch; // prompt processing maximum batch size
  uint32_t n_threads; // number of threads to use for generation
  uint32_t n_threads_batch; // number of threads to use for batch processing
- int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+ int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`

  // ref: https://github.com/ggerganov/llama.cpp/pull/2054
  float rope_freq_base; // RoPE base frequency, 0 = from model
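For downstream bindings, note that rope_scaling_type in llama_context_params widens from int8_t to int32_t here, matching the int32_t handling in llama.cpp above; code that only assigns the field keeps compiling, but anything that depends on the struct layout likely needs a rebuild. A minimal sketch of setting it, assuming the usual llama_context_default_params() helper:

    #include "llama.h"

    int main() {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN;   // field is now an int32_t
        // ... pass cparams to llama_new_context_with_model() as usual ...
        return 0;
    }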