llama_cpp 0.9.4 → 0.10.0

@@ -46,7 +46,6 @@
46
46
  #endif
47
47
  #include <windows.h>
48
48
  #include <io.h>
49
- #include <stdio.h> // for _fseeki64
50
49
  #endif
51
50
 
52
51
  #include <algorithm>
@@ -75,6 +74,7 @@
75
74
  #include <set>
76
75
  #include <sstream>
77
76
  #include <thread>
77
+ #include <type_traits>
78
78
  #include <unordered_map>
79
79
 
80
80
  #if defined(_MSC_VER)
@@ -193,6 +193,7 @@ enum llm_arch {
193
193
  LLM_ARCH_REFACT,
194
194
  LLM_ARCH_BLOOM,
195
195
  LLM_ARCH_STABLELM,
196
+ LLM_ARCH_QWEN,
196
197
  LLM_ARCH_UNKNOWN,
197
198
  };
198
199
 
@@ -209,6 +210,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
209
210
  { LLM_ARCH_REFACT, "refact" },
210
211
  { LLM_ARCH_BLOOM, "bloom" },
211
212
  { LLM_ARCH_STABLELM, "stablelm" },
213
+ { LLM_ARCH_QWEN, "qwen" },
212
214
  };
213
215
 
214
216
  enum llm_kv {
@@ -519,6 +521,22 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
519
521
  { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
520
522
  },
521
523
  },
524
+ {
525
+ LLM_ARCH_QWEN,
526
+ {
527
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
528
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
529
+ { LLM_TENSOR_OUTPUT, "output" },
530
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
531
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
532
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
533
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
534
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
535
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
536
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
537
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
538
+ },
539
+ },
522
540
 
523
541
  {
524
542
  LLM_ARCH_UNKNOWN,
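
The hunk above registers the GGUF tensor-name templates for the new QWEN architecture. As a minimal sketch of how these templates are consumed elsewhere in this file, the %d placeholder is filled with the block index and a suffix is appended by the existing LLM_TN helper (LLM_TN itself is defined outside this diff; its construction from the architecture id is assumed from surrounding code):

    // illustrative only, inside llama.cpp where LLM_TN and the enums are visible
    const auto tn = LLM_TN(LLM_ARCH_QWEN);
    const std::string qkv_w = tn(LLM_TENSOR_ATTN_QKV,  "weight", 2);  // -> "blk.2.attn_qkv.weight"
    const std::string emb_w = tn(LLM_TENSOR_TOKEN_EMBD, "weight");    // -> "token_embd.weight"
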
@@ -573,21 +591,6 @@ struct LLM_TN {
573
591
  // gguf helpers
574
592
  //
575
593
 
576
- #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
577
- do { \
578
- const std::string skey(key); \
579
- const int kid = gguf_find_key(ctx, skey.c_str()); \
580
- if (kid >= 0) { \
581
- enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
582
- if (ktype != (type)) { \
583
- throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
584
- } \
585
- (dst) = func(ctx, kid); \
586
- } else if (req) { \
587
- throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
588
- } \
589
- } while (0)
590
-
591
594
  static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
592
595
  { LLAMA_ROPE_SCALING_NONE, "none" },
593
596
  { LLAMA_ROPE_SCALING_LINEAR, "linear" },
@@ -621,7 +624,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
621
624
  }
622
625
  }
623
626
 
624
- static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
627
+ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
625
628
  const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
626
629
 
627
630
  switch (type) {
@@ -1113,6 +1116,12 @@ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_
1113
1116
  //
1114
1117
 
1115
1118
  struct llama_state {
1119
+ llama_state() {
1120
+ #ifdef GGML_USE_METAL
1121
+ ggml_metal_log_set_callback(log_callback, log_callback_user_data);
1122
+ #endif
1123
+ }
1124
+
1116
1125
  // We save the log callback globally
1117
1126
  ggml_log_callback log_callback = llama_log_callback_default;
1118
1127
  void * log_callback_user_data = nullptr;
@@ -1217,6 +1226,7 @@ struct llama_cparams {
1217
1226
  float yarn_beta_slow;
1218
1227
 
1219
1228
  bool mul_mat_q;
1229
+ bool offload_kqv;
1220
1230
  };
1221
1231
 
1222
1232
  struct llama_layer {
@@ -1238,6 +1248,9 @@ struct llama_layer {
1238
1248
  struct ggml_tensor * wqkv;
1239
1249
 
1240
1250
  // attention bias
1251
+ struct ggml_tensor * bq;
1252
+ struct ggml_tensor * bk;
1253
+ struct ggml_tensor * bv;
1241
1254
  struct ggml_tensor * bo;
1242
1255
  struct ggml_tensor * bqkv;
1243
1256
 
@@ -1282,8 +1295,8 @@ struct llama_kv_cache {
1282
1295
 
1283
1296
  std::vector<llama_kv_cell> cells;
1284
1297
 
1285
- struct ggml_tensor * k = NULL;
1286
- struct ggml_tensor * v = NULL;
1298
+ std::vector<struct ggml_tensor *> k_l; // per layer
1299
+ std::vector<struct ggml_tensor *> v_l;
1287
1300
 
1288
1301
  struct ggml_context * ctx = NULL;
1289
1302
 
@@ -1296,8 +1309,10 @@ struct llama_kv_cache {
1296
1309
 
1297
1310
  #ifdef GGML_USE_CUBLAS
1298
1311
  if (ggml_cublas_loaded()) {
1299
- ggml_cuda_free_data(k);
1300
- ggml_cuda_free_data(v);
1312
+ for (size_t i = 0; i < k_l.size(); ++i) {
1313
+ ggml_cuda_free_data(k_l[i]);
1314
+ ggml_cuda_free_data(v_l[i]);
1315
+ }
1301
1316
  }
1302
1317
  #endif
1303
1318
  }
@@ -1487,9 +1502,11 @@ struct llama_context {
1487
1502
  static bool llama_kv_cache_init(
1488
1503
  const struct llama_hparams & hparams,
1489
1504
  struct llama_kv_cache & cache,
1490
- ggml_type wtype,
1505
+ ggml_type ktype,
1506
+ ggml_type vtype,
1491
1507
  uint32_t n_ctx,
1492
- int n_gpu_layers) {
1508
+ int n_gpu_layers,
1509
+ bool offload) {
1493
1510
  const uint32_t n_embd = hparams.n_embd_gqa();
1494
1511
  const uint32_t n_layer = hparams.n_layer;
1495
1512
 
@@ -1505,7 +1522,7 @@ static bool llama_kv_cache_init(
1505
1522
  cache.cells.clear();
1506
1523
  cache.cells.resize(n_ctx);
1507
1524
 
1508
- cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
1525
+ cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
1509
1526
  memset(cache.buf.data, 0, cache.buf.size);
1510
1527
 
1511
1528
  struct ggml_init_params params;
@@ -1515,37 +1532,44 @@ static bool llama_kv_cache_init(
1515
1532
 
1516
1533
  cache.ctx = ggml_init(params);
1517
1534
 
1535
+ size_t vram_kv_cache = 0;
1536
+
1518
1537
  if (!cache.ctx) {
1519
1538
  LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
1520
1539
  return false;
1521
1540
  }
1522
1541
 
1523
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
1524
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
1525
- ggml_set_name(cache.k, "cache_k");
1526
- ggml_set_name(cache.v, "cache_v");
1542
+ cache.k_l.reserve(n_layer);
1543
+ cache.v_l.reserve(n_layer);
1527
1544
 
1528
- (void) n_gpu_layers;
1545
+ const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
1529
1546
 
1530
- #ifdef GGML_USE_CUBLAS
1531
- if (ggml_cublas_loaded()) {
1532
- size_t vram_kv_cache = 0;
1547
+ GGML_UNUSED(offload);
1533
1548
 
1534
- if (n_gpu_layers > (int)n_layer + 1) {
1535
- ggml_cuda_assign_buffers_no_scratch(cache.v);
1536
- LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
1537
- vram_kv_cache += ggml_nbytes(cache.v);
1538
- }
1539
- if (n_gpu_layers > (int)n_layer + 2) {
1540
- ggml_cuda_assign_buffers_no_scratch(cache.k);
1541
- LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
1542
- vram_kv_cache += ggml_nbytes(cache.k);
1543
- }
1544
- if (vram_kv_cache > 0) {
1545
- LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
1549
+ for (int i = 0; i < (int) n_layer; i++) {
1550
+ ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
1551
+ ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
1552
+ ggml_format_name(k, "cache_k_l%d", i);
1553
+ ggml_format_name(v, "cache_v_l%d", i);
1554
+ cache.k_l.push_back(k);
1555
+ cache.v_l.push_back(v);
1556
+ #ifdef GGML_USE_CUBLAS
1557
+ if (i >= i_gpu_start) {
1558
+ if (offload) {
1559
+ ggml_cuda_assign_buffers_no_scratch(k);
1560
+ vram_kv_cache += ggml_nbytes(k);
1561
+ ggml_cuda_assign_buffers_no_scratch(v);
1562
+ vram_kv_cache += ggml_nbytes(v);
1563
+ }
1546
1564
  }
1565
+ #endif // GGML_USE_CUBLAS
1547
1566
  }
1548
- #endif
1567
+
1568
+ if (vram_kv_cache > 0) {
1569
+ LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
1570
+ }
1571
+
1572
+ GGML_UNUSED(n_gpu_layers);
1549
1573
 
1550
1574
  return true;
1551
1575
  }
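
The rewritten llama_kv_cache_init above allocates one K and one V tensor per layer instead of two monolithic tensors, which is why the buffer now reserves 2u*n_layer*ggml_tensor_overhead() and why offloading can be decided layer by layer. A back-of-the-envelope sizing check with hypothetical 7B-style numbers (not taken from this diff):

    #include <cstdio>
    #include <cstddef>

    int main() {
        // hypothetical values for illustration: n_embd_gqa = 4096, n_ctx = 4096, 32 layers
        const size_t n_embd_gqa = 4096;
        const size_t n_ctx      = 4096;
        const size_t n_layer    = 32;
        const size_t n_elements = n_embd_gqa * n_ctx * n_layer;
        // F16 K and F16 V -> 2 bytes per element each, mirroring
        // n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype))
        const size_t kv_bytes   = n_elements * (2 + 2);
        std::printf("KV cache data: %.2f GiB split across %zu per-layer tensors\n",
                    kv_bytes / (1024.0 * 1024.0 * 1024.0), 2 * n_layer);
        return 0;
    }
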
@@ -1766,6 +1790,169 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
1766
1790
  return buf;
1767
1791
  }
1768
1792
 
1793
+ namespace GGUFMeta {
1794
+ template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
1795
+ struct GKV_Base_Type {
1796
+ static constexpr gguf_type gt = gt_;
1797
+
1798
+ static T getter(const gguf_context * ctx, const int kid) {
1799
+ return gfun(ctx, kid);
1800
+ }
1801
+ };
1802
+
1803
+ template<typename T> struct GKV_Base;
1804
+
1805
+ template<> struct GKV_Base<bool >: GKV_Base_Type<bool, GGUF_TYPE_BOOL, gguf_get_val_bool> {};
1806
+ template<> struct GKV_Base<uint8_t >: GKV_Base_Type<uint8_t, GGUF_TYPE_UINT8, gguf_get_val_u8 > {};
1807
+ template<> struct GKV_Base<uint16_t >: GKV_Base_Type<uint16_t, GGUF_TYPE_UINT16, gguf_get_val_u16 > {};
1808
+ template<> struct GKV_Base<uint32_t >: GKV_Base_Type<uint32_t, GGUF_TYPE_UINT32, gguf_get_val_u32 > {};
1809
+ template<> struct GKV_Base<uint64_t >: GKV_Base_Type<uint64_t, GGUF_TYPE_UINT64, gguf_get_val_u64 > {};
1810
+ template<> struct GKV_Base<int8_t >: GKV_Base_Type<int8_t, GGUF_TYPE_INT8, gguf_get_val_i8 > {};
1811
+ template<> struct GKV_Base<int16_t >: GKV_Base_Type<int16_t, GGUF_TYPE_INT16, gguf_get_val_i16 > {};
1812
+ template<> struct GKV_Base<int32_t >: GKV_Base_Type<int32_t, GGUF_TYPE_INT32, gguf_get_val_i32 > {};
1813
+ template<> struct GKV_Base<int64_t >: GKV_Base_Type<int64_t, GGUF_TYPE_INT64, gguf_get_val_i64 > {};
1814
+ template<> struct GKV_Base<float >: GKV_Base_Type<float, GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
1815
+ template<> struct GKV_Base<double >: GKV_Base_Type<double, GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
1816
+ template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING, gguf_get_val_str > {};
1817
+
1818
+ template<> struct GKV_Base<std::string> {
1819
+ static constexpr gguf_type gt = GGUF_TYPE_STRING;
1820
+
1821
+ static std::string getter(const gguf_context * ctx, const int kid) {
1822
+ return gguf_get_val_str(ctx, kid);
1823
+ }
1824
+ };
1825
+
1826
+ struct ArrayInfo{
1827
+ const gguf_type gt;
1828
+ const size_t length;
1829
+ const void * data;
1830
+ };
1831
+
1832
+ template<> struct GKV_Base<ArrayInfo> {
1833
+ public:
1834
+ static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
1835
+ static ArrayInfo getter(const gguf_context *ctx, const int k) {
1836
+ return ArrayInfo {
1837
+ gguf_get_arr_type(ctx, k),
1838
+ size_t(gguf_get_arr_n(ctx, k)),
1839
+ gguf_get_arr_data(ctx, k),
1840
+ };
1841
+ }
1842
+ };
1843
+
1844
+ template<typename T>
1845
+ class GKV: public GKV_Base<T> {
1846
+ GKV() = delete;
1847
+
1848
+ public:
1849
+ static T get_kv(const gguf_context * ctx, const int k) {
1850
+ const enum gguf_type kt = gguf_get_kv_type(ctx, k);
1851
+
1852
+ if (kt != GKV::gt) {
1853
+ throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
1854
+ gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
1855
+ }
1856
+ return GKV::getter(ctx, k);
1857
+ }
1858
+
1859
+ static const char * override_type_to_str(const llama_model_kv_override_type ty) {
1860
+ switch (ty) {
1861
+ case LLAMA_KV_OVERRIDE_BOOL: return "bool";
1862
+ case LLAMA_KV_OVERRIDE_INT: return "int";
1863
+ case LLAMA_KV_OVERRIDE_FLOAT: return "float";
1864
+ }
1865
+ return "unknown";
1866
+ }
1867
+
1868
+ static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
1869
+ if (!override) { return false; }
1870
+ if (override->tag == expected_type) {
1871
+ LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
1872
+ __func__, override_type_to_str(override->tag), override->key);
1873
+ switch (override->tag) {
1874
+ case LLAMA_KV_OVERRIDE_BOOL: {
1875
+ printf("%s\n", override->bool_value ? "true" : "false");
1876
+ } break;
1877
+ case LLAMA_KV_OVERRIDE_INT: {
1878
+ printf("%" PRId64 "\n", override->int_value);
1879
+ } break;
1880
+ case LLAMA_KV_OVERRIDE_FLOAT: {
1881
+ printf("%.6f\n", override->float_value);
1882
+ } break;
1883
+ default:
1884
+ // Shouldn't be possible to end up here, but just in case...
1885
+ throw std::runtime_error(
1886
+ format("Unsupported attempt to override %s type for metadata key %s\n",
1887
+ override_type_to_str(override->tag), override->key));
1888
+ }
1889
+ return true;
1890
+ }
1891
+ LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
1892
+ __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
1893
+ return false;
1894
+ }
1895
+
1896
+ template<typename OT>
1897
+ static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
1898
+ try_override(OT & target, const struct llama_model_kv_override *override) {
1899
+ if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
1900
+ target = override->bool_value;
1901
+ return true;
1902
+ }
1903
+ return true;
1904
+ }
1905
+
1906
+ template<typename OT>
1907
+ static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
1908
+ try_override(OT & target, const struct llama_model_kv_override *override) {
1909
+ if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
1910
+ target = override->int_value;
1911
+ return true;
1912
+ }
1913
+ return false;
1914
+ }
1915
+
1916
+ template<typename OT>
1917
+ static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
1918
+ try_override(T & target, const struct llama_model_kv_override *override) {
1919
+ if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
1920
+ target = override->float_value;
1921
+ return true;
1922
+ }
1923
+ return false;
1924
+ }
1925
+
1926
+ template<typename OT>
1927
+ static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
1928
+ try_override(T & target, const struct llama_model_kv_override *override) {
1929
+ (void)target;
1930
+ (void)override;
1931
+ if (!override) { return false; }
1932
+ // Currently, we should never end up here so it would be a bug if we do.
1933
+ throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
1934
+ override ? override->key : "NULL"));
1935
+ }
1936
+
1937
+ static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
1938
+ if (try_override<T>(target, override)) {
1939
+ return true;
1940
+ }
1941
+ if (k < 0) { return false; }
1942
+ target = get_kv(ctx, k);
1943
+ return true;
1944
+ }
1945
+
1946
+ static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
1947
+ return set(ctx, gguf_find_key(ctx, key), target, override);
1948
+ }
1949
+
1950
+ static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
1951
+ return set(ctx, key.c_str(), target, override);
1952
+ }
1953
+ };
1954
+ }
1955
+
1769
1956
  struct llama_model_loader {
1770
1957
  int n_kv = 0;
1771
1958
  int n_tensors = 0;
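
The GGUFMeta namespace added above replaces the removed GGUF_GET_KEY macro with typed getters plus optional metadata overrides. A minimal sketch of direct usage, assuming a loaded gguf_context * ctx; the key names are illustrative:

    // set() returns false if the key is absent and throws only if it exists with the wrong type
    uint32_t n_ctx_train = 0;
    GGUFMeta::GKV<uint32_t>::set(ctx, "llama.context_length", n_ctx_train);

    std::string general_name;
    GGUFMeta::GKV<std::string>::set(ctx, "general.name", general_name, /*override =*/ nullptr);
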
@@ -1781,21 +1968,34 @@ struct llama_model_loader {
1781
1968
  llama_fver fver;
1782
1969
 
1783
1970
  std::unique_ptr<llama_mmap> mapping;
1971
+ std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
1784
1972
 
1785
1973
  struct gguf_context * ctx_gguf = NULL;
1786
1974
  struct ggml_context * ctx_meta = NULL;
1787
1975
 
1788
- llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") {
1976
+ std::string arch_name;
1977
+ LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
1978
+
1979
+ llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) : file(fname.c_str(), "rb") {
1789
1980
  struct gguf_init_params params = {
1790
1981
  /*.no_alloc = */ true,
1791
1982
  /*.ctx = */ &ctx_meta,
1792
1983
  };
1793
1984
 
1985
+ if (param_overrides_p != nullptr) {
1986
+ for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) {
1987
+ kv_overrides.insert({std::string(p->key), *p});
1988
+ }
1989
+ }
1990
+
1794
1991
  ctx_gguf = gguf_init_from_file(fname.c_str(), params);
1795
1992
  if (!ctx_gguf) {
1796
1993
  throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
1797
1994
  }
1798
1995
 
1996
+ get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
1997
+ llm_kv = LLM_KV(llm_arch_from_string(arch_name));
1998
+
1799
1999
  n_kv = gguf_get_n_kv(ctx_gguf);
1800
2000
  n_tensors = gguf_get_n_tensors(ctx_gguf);
1801
2001
 
@@ -1863,6 +2063,7 @@ struct llama_model_loader {
1863
2063
  }
1864
2064
  }
1865
2065
 
2066
+ LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
1866
2067
  for (int i = 0; i < n_kv; i++) {
1867
2068
  const char * name = gguf_get_key(ctx_gguf, i);
1868
2069
  const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
@@ -1908,19 +2109,59 @@ struct llama_model_loader {
1908
2109
  }
1909
2110
  }
1910
2111
 
1911
- std::string get_arch_name() const {
1912
- const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
2112
+ template<typename T>
2113
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
2114
+ get_arr_n(const std::string & key, T & result, const bool required = true) {
2115
+ const int kid = gguf_find_key(ctx_gguf, key.c_str());
2116
+
2117
+ if (kid < 0) {
2118
+ if (required) {
2119
+ throw std::runtime_error(format("key not found in model: %s", key.c_str()));
2120
+ }
2121
+ return false;
2122
+ }
1913
2123
 
1914
- std::string arch_name;
1915
- GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
2124
+ struct GGUFMeta::ArrayInfo arr_info =
2125
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx_gguf, kid);
1916
2126
 
2127
+
2128
+ result = arr_info.length;
2129
+ return true;
2130
+ }
2131
+
2132
+ template<typename T>
2133
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
2134
+ get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
2135
+ return get_arr_n(llm_kv(kid), result, required);
2136
+ }
2137
+
2138
+ template<typename T>
2139
+ bool get_key(const std::string & key, T & result, const bool required = true) {
2140
+ auto it = kv_overrides.find(key);
2141
+
2142
+ const struct llama_model_kv_override * override =
2143
+ it != kv_overrides.end() ? &it->second : nullptr;
2144
+
2145
+ const bool found = GGUFMeta::GKV<T>::set(ctx_gguf, key, result, override);
2146
+
2147
+ if (required && !found) {
2148
+ throw std::runtime_error(format("key not found in model: %s", key.c_str()));
2149
+ }
2150
+
2151
+ return found;
2152
+ }
2153
+
2154
+ template<typename T>
2155
+ bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
2156
+ return get_key(llm_kv(kid), result, required);
2157
+ }
2158
+
2159
+ std::string get_arch_name() const {
1917
2160
  return arch_name;
1918
2161
  }
1919
2162
 
1920
2163
  enum llm_arch get_arch() const {
1921
- const std::string arch_name = get_arch_name();
1922
-
1923
- return llm_arch_from_string(arch_name);
2164
+ return llm_kv.arch;
1924
2165
  }
1925
2166
 
1926
2167
  const char * get_tensor_name(int i) const {
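
The hunks above give the loader kv_overrides plus the get_key/get_arr_n wrappers built on GGUFMeta, with overrides passed in through params.kv_overrides. Below is a hedged sketch of how a caller might feed an override through the public API; the empty-key terminator follows the constructor loop above, while the exact key string and the fixed-size key field of llama_model_kv_override are assumptions about llama.h of this version:

    #include <cstring>
    #include "llama.h"

    // sketch only: force add_bos_token off via a metadata override
    llama_model_kv_override kvo[2];
    std::memset(kvo, 0, sizeof(kvo));
    std::strncpy(kvo[0].key, "tokenizer.ggml.add_bos_token", sizeof(kvo[0].key) - 1); // key name assumed
    kvo[0].tag        = LLAMA_KV_OVERRIDE_BOOL;
    kvo[0].bool_value = false;
    kvo[1].key[0]     = '\0';   // p->key[0] == 0 terminates the list in the loader constructor

    llama_model_params mparams = llama_model_default_params();
    mparams.kv_overrides = kvo;
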
@@ -1960,10 +2201,13 @@ struct llama_model_loader {
1960
2201
  return tensor;
1961
2202
  }
1962
2203
 
1963
- struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, ggml_backend_type backend) {
2204
+ struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, ggml_backend_type backend, bool required = true) {
1964
2205
  struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str());
1965
2206
 
1966
2207
  if (cur == NULL) {
2208
+ if (!required) {
2209
+ return NULL;
2210
+ }
1967
2211
  throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
1968
2212
  }
1969
2213
 
@@ -2167,11 +2411,8 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
2167
2411
  static void llm_load_hparams(
2168
2412
  llama_model_loader & ml,
2169
2413
  llama_model & model) {
2170
- struct gguf_context * ctx = ml.ctx_gguf;
2171
-
2172
- const auto kv = LLM_KV(model.arch);
2173
-
2174
2414
  auto & hparams = model.hparams;
2415
+ const gguf_context * ctx = ml.ctx_gguf;
2175
2416
 
2176
2417
  // get metadata as string
2177
2418
  for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
@@ -2185,42 +2426,41 @@ static void llm_load_hparams(
2185
2426
  }
2186
2427
 
2187
2428
  // get general kv
2188
- GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
2429
+ ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
2189
2430
 
2190
2431
  // get hparams kv
2191
- GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
2192
- GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
2193
- GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
2194
- GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
2195
- GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
2196
- GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
2432
+ ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
2433
+ ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
2434
+ ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
2435
+ ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
2436
+ ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
2437
+ ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
2197
2438
 
2198
2439
  // n_head_kv is optional, default to n_head
2199
2440
  hparams.n_head_kv = hparams.n_head;
2200
- GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
2441
+ ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
2201
2442
 
2202
- hparams.rope_finetuned = false;
2203
- GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
2204
- kv(LLM_KV_ROPE_SCALING_FINETUNED));
2443
+ bool rope_finetuned = false;
2444
+ ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
2445
+ hparams.rope_finetuned = rope_finetuned;
2205
2446
 
2206
2447
  hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
2207
- GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
2208
- kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
2448
+ ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
2209
2449
 
2210
2450
  // rope_freq_base (optional)
2211
2451
  hparams.rope_freq_base_train = 10000.0f;
2212
- GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
2452
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false);
2213
2453
 
2214
2454
  std::string rope_scaling("linear");
2215
- GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
2455
+ ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
2216
2456
  hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
2217
2457
  GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
2218
2458
 
2219
2459
  // rope_freq_scale (inverse of the kv) is optional
2220
2460
  float ropescale = 0.0f;
2221
- GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
2222
- if (ropescale == 0.0f) { // try the old key name
2223
- GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
2461
+ if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
2462
+ // try the old key name
2463
+ ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
2224
2464
  }
2225
2465
  hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
2226
2466
 
@@ -2228,7 +2468,7 @@ static void llm_load_hparams(
2228
2468
  {
2229
2469
  hparams.n_rot = hparams.n_embd / hparams.n_head;
2230
2470
 
2231
- GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
2471
+ ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
2232
2472
 
2233
2473
  if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
2234
2474
  if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
@@ -2243,7 +2483,7 @@ static void llm_load_hparams(
2243
2483
  switch (model.arch) {
2244
2484
  case LLM_ARCH_LLAMA:
2245
2485
  {
2246
- GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
2486
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2247
2487
 
2248
2488
  switch (hparams.n_layer) {
2249
2489
  case 26: model.type = e_model::MODEL_3B; break;
@@ -2257,7 +2497,7 @@ static void llm_load_hparams(
2257
2497
  } break;
2258
2498
  case LLM_ARCH_FALCON:
2259
2499
  {
2260
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2500
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2261
2501
 
2262
2502
  switch (hparams.n_layer) {
2263
2503
  case 32: model.type = e_model::MODEL_7B; break;
@@ -2267,7 +2507,7 @@ static void llm_load_hparams(
2267
2507
  } break;
2268
2508
  case LLM_ARCH_BAICHUAN:
2269
2509
  {
2270
- GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
2510
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2271
2511
  switch (hparams.n_layer) {
2272
2512
  case 32: model.type = e_model::MODEL_7B; break;
2273
2513
  case 40: model.type = e_model::MODEL_13B; break;
@@ -2276,7 +2516,7 @@ static void llm_load_hparams(
2276
2516
  } break;
2277
2517
  case LLM_ARCH_STARCODER:
2278
2518
  {
2279
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2519
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2280
2520
  switch (hparams.n_layer) {
2281
2521
  case 24: model.type = e_model::MODEL_1B; break;
2282
2522
  case 36: model.type = e_model::MODEL_3B; break;
@@ -2287,7 +2527,7 @@ static void llm_load_hparams(
2287
2527
  } break;
2288
2528
  case LLM_ARCH_PERSIMMON:
2289
2529
  {
2290
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2530
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2291
2531
  switch (hparams.n_layer) {
2292
2532
  case 36: model.type = e_model::MODEL_8B; break;
2293
2533
  default: model.type = e_model::MODEL_UNKNOWN;
@@ -2295,7 +2535,7 @@ static void llm_load_hparams(
2295
2535
  } break;
2296
2536
  case LLM_ARCH_REFACT:
2297
2537
  {
2298
- GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
2538
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2299
2539
  switch (hparams.n_layer) {
2300
2540
  case 32: model.type = e_model::MODEL_1B; break;
2301
2541
  default: model.type = e_model::MODEL_UNKNOWN;
@@ -2303,7 +2543,7 @@ static void llm_load_hparams(
2303
2543
  } break;
2304
2544
  case LLM_ARCH_BLOOM:
2305
2545
  {
2306
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2546
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2307
2547
 
2308
2548
  switch (hparams.n_layer) {
2309
2549
  case 24: model.type = e_model::MODEL_1B; break;
@@ -2318,9 +2558,9 @@ static void llm_load_hparams(
2318
2558
  {
2319
2559
  hparams.f_clamp_kqv = 0.0f;
2320
2560
 
2321
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2322
- GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
2323
- GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
2561
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2562
+ ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
2563
+ ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
2324
2564
 
2325
2565
  switch (hparams.n_layer) {
2326
2566
  case 32: model.type = e_model::MODEL_7B; break;
@@ -2330,13 +2570,23 @@ static void llm_load_hparams(
2330
2570
  } break;
2331
2571
  case LLM_ARCH_STABLELM:
2332
2572
  {
2333
- GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
2573
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
2334
2574
 
2335
2575
  switch (hparams.n_layer) {
2336
2576
  case 32: model.type = e_model::MODEL_3B; break;
2337
2577
  default: model.type = e_model::MODEL_UNKNOWN;
2338
2578
  }
2339
2579
  } break;
2580
+ case LLM_ARCH_QWEN:
2581
+ {
2582
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2583
+
2584
+ switch (hparams.n_layer) {
2585
+ case 32: model.type = e_model::MODEL_7B; break;
2586
+ case 40: model.type = e_model::MODEL_13B; break;
2587
+ default: model.type = e_model::MODEL_UNKNOWN;
2588
+ }
2589
+ } break;
2340
2590
 
2341
2591
  default: (void)0;
2342
2592
  }
@@ -2378,7 +2628,7 @@ static void llm_load_vocab(
2378
2628
  {
2379
2629
  std::string tokenizer_name;
2380
2630
 
2381
- GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
2631
+ ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
2382
2632
 
2383
2633
  if (tokenizer_name == "llama") {
2384
2634
  vocab.type = LLAMA_VOCAB_TYPE_SPM;
@@ -2468,34 +2718,31 @@ static void llm_load_vocab(
2468
2718
  };
2469
2719
  for (const auto & it : special_token_types) {
2470
2720
  const std::string & key = kv(std::get<0>(it));
2471
- int32_t & id = std::get<1>(it), old_id = id;
2721
+ int32_t & id = std::get<1>(it);
2472
2722
 
2473
- GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);
2474
- // Must be >= -1 and < vocab size. Since the key is unsigned, -1
2475
- // can only come from the default value, so there's no point in
2476
- // validating that.
2477
- if (size_t(id + 1) > vocab.id_to_token.size()) {
2478
- LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n",
2479
- __func__, key.c_str(), id, old_id);
2480
- id = old_id;
2723
+ uint32_t new_id;
2724
+ if (!ml.get_key(std::get<0>(it), new_id, false)) {
2725
+ continue;
2726
+ }
2727
+ if (new_id >= vocab.id_to_token.size()) {
2728
+ LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
2729
+ __func__, key.c_str(), new_id, id);
2730
+ } else {
2731
+ id = new_id;
2481
2732
  }
2482
2733
 
2483
2734
  }
2484
2735
 
2485
2736
  // Handle add_bos_token and add_eos_token
2486
- std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
2487
- int kid = gguf_find_key(ctx, key.c_str());
2488
- enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
2489
- vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
2490
- if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
2491
- LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
2492
- }
2493
- key = kv(LLM_KV_TOKENIZER_ADD_EOS);
2494
- kid = gguf_find_key(ctx, key.c_str());
2495
- ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
2496
- vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
2497
- if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
2498
- LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
2737
+ {
2738
+ bool temp = true;
2739
+
2740
+ if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
2741
+ vocab.special_add_bos = int(temp);
2742
+ }
2743
+ if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
2744
+ vocab.special_add_eos = int(temp);
2745
+ }
2499
2746
  }
2500
2747
  }
2501
2748
 
@@ -2634,15 +2881,15 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
2634
2881
  }
2635
2882
 
2636
2883
  // general kv
2637
- LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
2884
+ LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
2638
2885
 
2639
2886
  // special tokens
2640
- if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
2641
- if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
2642
- if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
2643
- if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
2644
- if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
2645
- if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
2887
+ if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
2888
+ if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
2889
+ if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
2890
+ if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
2891
+ if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
2892
+ if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
2646
2893
  }
2647
2894
 
2648
2895
  static void llm_load_tensors(
@@ -2728,14 +2975,7 @@ static void llm_load_tensors(
2728
2975
  ggml_backend_type backend_output;
2729
2976
 
2730
2977
  if (n_gpu_layers > int(n_layer)) {
2731
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
2732
- // on Windows however this is detrimental unless everything is on the GPU
2733
- #ifndef _WIN32
2734
- backend_norm = llama_backend_offload;
2735
- #else
2736
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
2737
- #endif // _WIN32
2738
-
2978
+ backend_norm = llama_backend_offload;
2739
2979
  backend_output = llama_backend_offload_split;
2740
2980
  } else {
2741
2981
  backend_norm = GGML_BACKEND_CPU;
@@ -2772,6 +3012,12 @@ static void llm_load_tensors(
2772
3012
  layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
2773
3013
  layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
2774
3014
 
3015
+ // optional bias tensors
3016
+ layer.bq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, backend, false);
3017
+ layer.bk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, backend, false);
3018
+ layer.bv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, backend, false);
3019
+ layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend, false);
3020
+
2775
3021
  layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
2776
3022
 
2777
3023
  layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
@@ -2780,9 +3026,14 @@ static void llm_load_tensors(
2780
3026
 
2781
3027
  if (backend == GGML_BACKEND_GPU) {
2782
3028
  vram_weights +=
2783
- ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
2784
- ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
2785
- ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
3029
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
3030
+ ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) +
3031
+ (layer.bq ? ggml_nbytes(layer.bq) : 0) +
3032
+ (layer.bk ? ggml_nbytes(layer.bk) : 0) +
3033
+ (layer.bv ? ggml_nbytes(layer.bv) : 0) +
3034
+ (layer.bo ? ggml_nbytes(layer.bo) : 0) +
3035
+ ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
3036
+ ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
2786
3037
  }
2787
3038
  }
2788
3039
  } break;
@@ -2794,14 +3045,7 @@ static void llm_load_tensors(
2794
3045
  ggml_backend_type backend_output;
2795
3046
 
2796
3047
  if (n_gpu_layers > int(n_layer)) {
2797
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
2798
- // on Windows however this is detrimental unless everything is on the GPU
2799
- #ifndef _WIN32
2800
- backend_norm = llama_backend_offload;
2801
- #else
2802
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
2803
- #endif // _WIN32
2804
-
3048
+ backend_norm = llama_backend_offload;
2805
3049
  backend_output = llama_backend_offload_split;
2806
3050
  } else {
2807
3051
  backend_norm = GGML_BACKEND_CPU;
@@ -2864,14 +3108,7 @@ static void llm_load_tensors(
2864
3108
  ggml_backend_type backend_output;
2865
3109
 
2866
3110
  if (n_gpu_layers > int(n_layer)) {
2867
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
2868
- // on Windows however this is detrimental unless everything is on the GPU
2869
- #ifndef _WIN32
2870
- backend_norm = llama_backend_offload;
2871
- #else
2872
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
2873
- #endif // _WIN32
2874
-
3111
+ backend_norm = llama_backend_offload;
2875
3112
  backend_output = llama_backend_offload_split;
2876
3113
  } else {
2877
3114
  backend_norm = GGML_BACKEND_CPU;
@@ -2941,14 +3178,7 @@ static void llm_load_tensors(
2941
3178
  ggml_backend_type backend_output;
2942
3179
 
2943
3180
  if (n_gpu_layers > int(n_layer)) {
2944
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
2945
- // on Windows however this is detrimental unless everything is on the GPU
2946
- #ifndef _WIN32
2947
- backend_norm = llama_backend_offload;
2948
- #else
2949
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
2950
- #endif // _WIN32
2951
-
3181
+ backend_norm = llama_backend_offload;
2952
3182
  backend_output = llama_backend_offload_split;
2953
3183
  } else {
2954
3184
  backend_norm = GGML_BACKEND_CPU;
@@ -3018,21 +3248,7 @@ static void llm_load_tensors(
3018
3248
  ggml_backend_type backend_output;
3019
3249
 
3020
3250
  if (n_gpu_layers > int(n_layer)) {
3021
- #ifdef GGML_USE_CUBLAS
3022
- if (n_gpu_layers > int(n_layer + 1)) {
3023
- LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
3024
- __func__, n_layer + 1);
3025
- throw std::runtime_error("Persimmon CUDA offload failed");
3026
- }
3027
- #endif
3028
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
3029
- // on Windows however this is detrimental unless everything is on the GPU
3030
- #ifndef _WIN32
3031
- backend_norm = llama_backend_offload;
3032
- #else
3033
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
3034
- #endif // _WIN32
3035
-
3251
+ backend_norm = llama_backend_offload;
3036
3252
  backend_output = llama_backend_offload_split;
3037
3253
  } else {
3038
3254
  backend_norm = GGML_BACKEND_CPU;
@@ -3091,14 +3307,7 @@ static void llm_load_tensors(
3091
3307
  ggml_backend_type backend_output;
3092
3308
 
3093
3309
  if (n_gpu_layers > int(n_layer)) {
3094
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
3095
- // on Windows however this is detrimental unless everything is on the GPU
3096
- #ifndef _WIN32
3097
- backend_norm = llama_backend_offload;
3098
- #else
3099
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
3100
- #endif // _WIN32
3101
-
3310
+ backend_norm = llama_backend_offload;
3102
3311
  backend_output = llama_backend_offload_split;
3103
3312
  } else {
3104
3313
  backend_norm = GGML_BACKEND_CPU;
@@ -3169,14 +3378,7 @@ static void llm_load_tensors(
3169
3378
  ggml_backend_type backend_output;
3170
3379
 
3171
3380
  if (n_gpu_layers > int(n_layer)) {
3172
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
3173
- // on Windows however this is detrimental unless everything is on the GPU
3174
- #ifndef _WIN32
3175
- backend_norm = llama_backend_offload;
3176
- #else
3177
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
3178
- #endif // _WIN32
3179
-
3381
+ backend_norm = llama_backend_offload;
3180
3382
  backend_output = llama_backend_offload_split;
3181
3383
  } else {
3182
3384
  backend_norm = GGML_BACKEND_CPU;
@@ -3236,14 +3438,7 @@ static void llm_load_tensors(
3236
3438
  ggml_backend_type backend_output;
3237
3439
 
3238
3440
  if (n_gpu_layers > int(n_layer)) {
3239
- // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
3240
- // on Windows however this is detrimental unless everything is on the GPU
3241
- #ifndef _WIN32
3242
- backend_norm = llama_backend_offload;
3243
- #else
3244
- backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
3245
- #endif // _WIN32
3246
-
3441
+ backend_norm = llama_backend_offload;
3247
3442
  backend_output = llama_backend_offload_split;
3248
3443
  } else {
3249
3444
  backend_norm = GGML_BACKEND_CPU;
@@ -3300,6 +3495,64 @@ static void llm_load_tensors(
3300
3495
  }
3301
3496
  }
3302
3497
  } break;
3498
+ case LLM_ARCH_QWEN:
3499
+ {
3500
+ model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
3501
+ {
3502
+ ggml_backend_type backend_norm;
3503
+ ggml_backend_type backend_output;
3504
+
3505
+ if (n_gpu_layers > int(n_layer)) {
3506
+ backend_norm = llama_backend_offload;
3507
+ backend_output = llama_backend_offload_split;
3508
+ } else {
3509
+ backend_norm = GGML_BACKEND_CPU;
3510
+ backend_output = GGML_BACKEND_CPU;
3511
+ }
3512
+
3513
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
3514
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
3515
+
3516
+ if (backend_norm == GGML_BACKEND_GPU) {
3517
+ vram_weights += ggml_nbytes(model.output_norm);
3518
+ }
3519
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
3520
+ vram_weights += ggml_nbytes(model.output);
3521
+ }
3522
+ }
3523
+
3524
+ const uint32_t n_ff = hparams.n_ff / 2;
3525
+
3526
+ const int i_gpu_start = n_layer - n_gpu_layers;
3527
+
3528
+ model.layers.resize(n_layer);
3529
+
3530
+ for (uint32_t i = 0; i < n_layer; ++i) {
3531
+ const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
3532
+ const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
3533
+
3534
+ auto & layer = model.layers[i];
3535
+
3536
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
3537
+
3538
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd * 3}, backend_split);
3539
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd * 3}, backend);
3540
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
3541
+
3542
+ layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
3543
+
3544
+ layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
3545
+ layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
3546
+ layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
3547
+
3548
+ if (backend == GGML_BACKEND_GPU) {
3549
+ vram_weights +=
3550
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
3551
+ ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
3552
+ ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
3553
+ }
3554
+ }
3555
+ } break;
3303
3556
 
3304
3557
  default:
3305
3558
  throw std::runtime_error("unknown architecture");
@@ -3326,8 +3579,8 @@ static void llm_load_tensors(
3326
3579
  }
3327
3580
 
3328
3581
  #ifdef GGML_USE_CUBLAS
3329
- const int max_backend_supported_layers = hparams.n_layer + 3;
3330
- const int max_offloadable_layers = hparams.n_layer + 3;
3582
+ const int max_backend_supported_layers = hparams.n_layer + 1;
3583
+ const int max_offloadable_layers = hparams.n_layer + 1;
3331
3584
  #elif GGML_USE_CLBLAST
3332
3585
  const int max_backend_supported_layers = hparams.n_layer + 1;
3333
3586
  const int max_offloadable_layers = hparams.n_layer + 1;
@@ -3368,7 +3621,7 @@ static void llm_load_tensors(
3368
3621
 
3369
3622
  static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
3370
3623
  try {
3371
- llama_model_loader ml(fname, params.use_mmap);
3624
+ llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
3372
3625
 
3373
3626
  model.hparams.vocab_only = params.vocab_only;
3374
3627
 
@@ -3464,7 +3717,7 @@ static void llm_build_k_shift(
3464
3717
  struct ggml_cgraph * graph,
3465
3718
  llm_rope_type type,
3466
3719
  int64_t n_ctx,
3467
- int64_t n_rot,
3720
+ int n_rot,
3468
3721
  float freq_base,
3469
3722
  float freq_scale,
3470
3723
  const llm_build_cb & cb) {
@@ -3495,11 +3748,11 @@ static void llm_build_k_shift(
3495
3748
  struct ggml_tensor * tmp =
3496
3749
  // we rotate only the first n_rot dimensions
3497
3750
  ggml_rope_custom_inplace(ctx,
3498
- ggml_view_3d(ctx, kv.k,
3499
- n_rot, n_head_kv, n_ctx,
3500
- ggml_element_size(kv.k)*n_embd_head,
3501
- ggml_element_size(kv.k)*n_embd_gqa,
3502
- ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
3751
+ ggml_view_3d(ctx, kv.k_l[il],
3752
+ n_embd_head, n_head_kv, n_ctx,
3753
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
3754
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
3755
+ 0),
3503
3756
  K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
3504
3757
  ext_factor, attn_factor, beta_fast, beta_slow);
3505
3758
  cb(tmp, "K_shifted", il);
@@ -3526,13 +3779,13 @@ static void llm_build_kv_store(
3526
3779
  //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
3527
3780
  cb(v_cur_t, "v_cur_t", il);
3528
3781
 
3529
- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
3530
- (ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
3782
+ struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
3783
+ (ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
3531
3784
  cb(k_cache_view, "k_cache_view", il);
3532
3785
 
3533
- struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
3534
- ( n_ctx)*ggml_element_size(kv.v),
3535
- (il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
3786
+ struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
3787
+ ( n_ctx)*ggml_element_size(kv.v_l[il]),
3788
+ (kv_head)*ggml_element_size(kv.v_l[il]));
3536
3789
  cb(v_cache_view, "v_cache_view", il);
3537
3790
 
3538
3791
  // important: storing RoPE-ed version of K in the KV cache!
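
Because the cache is now one tensor per layer, the views above drop the il-dependent offset term: what used to be type_size*n_embd_gqa*(il*n_ctx + kv_head) for K becomes type_size*n_embd_gqa*kv_head, and the V view keeps only the kv_head*element_size column offset. A small arithmetic illustration with made-up numbers:

    #include <cstdio>
    #include <cstddef>

    int main() {
        // hypothetical: F16 cache (2 bytes/elem), n_embd_gqa = 4096, n_ctx = 4096, il = 5, kv_head = 17
        const size_t type_size  = 2;
        const size_t n_embd_gqa = 4096, n_ctx = 4096, il = 5, kv_head = 17;
        const size_t old_k_off  = type_size*n_embd_gqa*(il*n_ctx + kv_head); // offset into the old single K tensor
        const size_t new_k_off  = type_size*n_embd_gqa*kv_head;              // offset into kv.k_l[il]
        std::printf("old K offset: %zu bytes, new K offset: %zu bytes\n", old_k_off, new_k_off);
        return 0;
    }
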
@@ -3684,40 +3937,46 @@ static struct ggml_tensor * llm_build_kqv(
3684
3937
  cb(q, "q", il);
3685
3938
 
3686
3939
  struct ggml_tensor * k =
3687
- ggml_view_3d(ctx, kv.k,
3940
+ ggml_view_3d(ctx, kv.k_l[il],
3688
3941
  n_embd_head, n_kv, n_head_kv,
3689
- ggml_element_size(kv.k)*n_embd_gqa,
3690
- ggml_element_size(kv.k)*n_embd_head,
3691
- ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
3942
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
3943
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
3944
+ 0);
3692
3945
  cb(k, "k", il);
3693
3946
 
3694
3947
  struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
3695
3948
  cb(kq, "kq", il);
3696
3949
 
3697
- kq = ggml_scale(ctx, kq, kq_scale);
3698
- cb(kq, "kq_scaled", il);
3699
-
3700
3950
  if (max_alibi_bias > 0.0f) {
3701
- // TODO: n_head or n_head_kv
3702
- // TODO: K-shift is likely not working
3703
- // TODO: change to ggml_add
3704
- kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
3705
- cb(kq, "kq_scaled_alibi", il);
3706
- }
3951
+ // temporary branch until we figure out how to handle ggml_alibi through ggml_add
3952
+ kq = ggml_scale(ctx, kq, kq_scale);
3953
+ cb(kq, "kq_scaled", il);
3707
3954
 
3708
- kq = ggml_add(ctx, kq, kq_mask);
3709
- cb(kq, "kq_masked", il);
3955
+ if (max_alibi_bias > 0.0f) {
3956
+ // TODO: n_head or n_head_kv
3957
+ // TODO: K-shift is likely not working
3958
+ // TODO: change to ggml_add
3959
+ kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
3960
+ cb(kq, "kq_scaled_alibi", il);
3961
+ }
3710
3962
 
3711
- kq = ggml_soft_max(ctx, kq);
3712
- cb(kq, "kq_soft_max", il);
3963
+ kq = ggml_add(ctx, kq, kq_mask);
3964
+ cb(kq, "kq_masked", il);
3965
+
3966
+ kq = ggml_soft_max(ctx, kq);
3967
+ cb(kq, "kq_soft_max", il);
3968
+ } else {
3969
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
3970
+ cb(kq, "kq_soft_max_ext", il);
3971
+ }
3713
3972
 
3714
3973
  // split cached v into n_head heads
3715
3974
  struct ggml_tensor * v =
3716
- ggml_view_3d(ctx, kv.v,
3975
+ ggml_view_3d(ctx, kv.v_l[il],
3717
3976
  n_kv, n_embd_head, n_head_kv,
3718
- ggml_element_size(kv.v)*n_ctx,
3719
- ggml_element_size(kv.v)*n_ctx*n_embd_head,
3720
- ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
3977
+ ggml_element_size(kv.v_l[il])*n_ctx,
3978
+ ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
3979
+ 0);
3721
3980
  cb(v, "v", il);
3722
3981
 
3723
3982
  struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
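
In the non-ALiBi path above, the separate scale, mask-add and softmax nodes are fused into one ggml_soft_max_ext call, which computes softmax(scale*kq + mask) row by row and takes 1.0f/sqrtf(n_embd_head) directly instead of the KQ_scale tensor. A scalar sketch of that semantics, illustrative only:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // one row of the fused op: softmax(s*x + m)
    static std::vector<float> soft_max_ext_row(const std::vector<float> & x,
                                               const std::vector<float> & m, float s) {
        std::vector<float> y(x.size());
        float max_v = -INFINITY;
        for (size_t i = 0; i < x.size(); ++i) {
            y[i]  = s*x[i] + m[i];
            max_v = std::max(max_v, y[i]);
        }
        float sum = 0.0f;
        for (float & v : y) { v = std::exp(v - max_v); sum += v; }
        for (float & v : y) { v /= sum; }
        return y;
    }

    int main() {
        // the mask holds 0 for visible positions and -INFINITY for masked ones, as in the KQ_mask tensor
        const auto p = soft_max_ext_row({1.0f, 2.0f, 3.0f}, {0.0f, 0.0f, -INFINITY}, 0.125f);
        std::printf("%.3f %.3f %.3f\n", p[0], p[1], p[2]);
        return 0;
    }
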
@@ -3875,12 +4134,24 @@ struct llm_build_context {
3875
4134
  // compute Q and K and RoPE them
3876
4135
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
3877
4136
  cb(Qcur, "Qcur", il);
4137
+ if (model.layers[il].bq) {
4138
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
4139
+ cb(Qcur, "Qcur", il);
4140
+ }
3878
4141
 
3879
4142
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
3880
4143
  cb(Kcur, "Kcur", il);
4144
+ if (model.layers[il].bk) {
4145
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
4146
+ cb(Kcur, "Kcur", il);
4147
+ }
3881
4148
 
3882
4149
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
3883
4150
  cb(Vcur, "Vcur", il);
4151
+ if (model.layers[il].bv) {
4152
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
4153
+ cb(Vcur, "Vcur", il);
4154
+ }
3884
4155
 
3885
4156
  Qcur = ggml_rope_custom(
3886
4157
  ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
@@ -3899,7 +4170,7 @@ struct llm_build_context {
3899
4170
  llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
3900
4171
 
3901
4172
  cur = llm_build_kqv(ctx0, hparams, kv_self,
3902
- model.layers[il].wo, NULL,
4173
+ model.layers[il].wo, model.layers[il].bo,
3903
4174
  Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
3904
4175
  cb(cur, "kqv_out", il);
3905
4176
  }
@@ -4297,6 +4568,7 @@ struct llm_build_context {
4297
4568
  inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
4298
4569
  cb(inpL, "imp_embd", -1);
4299
4570
 
4571
+ // inp_pos - contains the positions
4300
4572
  struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
4301
4573
  cb(inp_pos, "inp_pos", -1);
4302
4574
 
@@ -4304,6 +4576,7 @@ struct llm_build_context {
4304
4576
  struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
4305
4577
  cb(KQ_scale, "KQ_scale", -1);
4306
4578
 
4579
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
4307
4580
  struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
4308
4581
  cb(KQ_mask, "KQ_mask", -1);
4309
4582
 
@@ -4892,6 +5165,121 @@ struct llm_build_context {
4892
5165
 
4893
5166
  return gf;
4894
5167
  }
5168
+
5169
+ struct ggml_cgraph * build_qwen() {
5170
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
5171
+
5172
+ struct ggml_tensor * cur;
5173
+ struct ggml_tensor * inpL;
5174
+
5175
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
5176
+ cb(inpL, "inp_embd", -1);
5177
+
5178
+ // inp_pos - contains the positions
5179
+ struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
5180
+ cb(inp_pos, "inp_pos", -1);
5181
+
5182
+ // KQ_scale
5183
+ struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
5184
+ cb(KQ_scale, "KQ_scale", -1);
5185
+
5186
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
5187
+ struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
5188
+ cb(KQ_mask, "KQ_mask", -1);
5189
+
5190
+ // shift the entire K-cache if needed
5191
+ if (do_rope_shift) {
5192
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5193
+ }
5194
+
5195
+ for (int il = 0; il < n_layer; ++il) {
5196
+ struct ggml_tensor * inpSA = inpL;
5197
+
5198
+ cur = llm_build_norm(ctx0, inpL, hparams,
5199
+ model.layers[il].attn_norm, NULL,
5200
+ LLM_NORM_RMS, cb, il);
5201
+ cb(cur, "attn_norm", il);
5202
+
5203
+ // self-attention
5204
+ {
5205
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
5206
+ cb(cur, "wqkv", il);
5207
+
5208
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5209
+ cb(cur, "bqkv", il);
5210
+
5211
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5212
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5213
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
5214
+
5215
+ cb(Qcur, "Qcur", il);
5216
+ cb(Kcur, "Kcur", il);
5217
+ cb(Vcur, "Vcur", il);
5218
+
5219
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5220
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5221
+
5222
+ // using mode = 2 for neox mode
5223
+ Qcur = ggml_rope_custom(
5224
+ ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
5225
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5226
+ );
5227
+ cb(Qcur, "Qcur", il);
5228
+
5229
+ Kcur = ggml_rope_custom(
5230
+ ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
5231
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5232
+ );
5233
+ cb(Kcur, "Kcur", il);
5234
+
5235
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5236
+
5237
+ cur = llm_build_kqv(ctx0, hparams, kv_self,
5238
+ model.layers[il].wo, NULL,
5239
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
5240
+ cb(cur, "kqv_out", il);
5241
+ }
5242
+
5243
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
5244
+ cb(ffn_inp, "ffn_inp", il);
5245
+
5246
+ // feed-forward network
5247
+ {
5248
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
5249
+ model.layers[il].ffn_norm, NULL,
5250
+ LLM_NORM_RMS, cb, il);
5251
+ cb(cur, "ffn_norm", il);
5252
+
5253
+ cur = llm_build_ffn(ctx0, cur,
5254
+ model.layers[il].ffn_up, NULL,
5255
+ model.layers[il].ffn_gate, NULL,
5256
+ model.layers[il].ffn_down, NULL,
5257
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
5258
+ cb(cur, "ffn_out", il);
5259
+ }
5260
+
5261
+ cur = ggml_add(ctx0, cur, ffn_inp);
5262
+ cb(cur, "l_out", il);
5263
+
5264
+ // input for next layer
5265
+ inpL = cur;
5266
+ }
5267
+
5268
+ cur = inpL;
5269
+
5270
+ cur = llm_build_norm(ctx0, cur, hparams,
5271
+ model.output_norm, NULL,
5272
+ LLM_NORM_RMS, cb, -1);
5273
+ cb(cur, "result_norm", -1);
5274
+
5275
+ // lm_head
5276
+ cur = ggml_mul_mat(ctx0, model.output, cur);
5277
+ cb(cur, "result_output", -1);
5278
+
5279
+ ggml_build_forward_expand(gf, cur);
5280
+
5281
+ return gf;
5282
+ }
4895
5283
  };
4896
5284
 
4897
5285
  //
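For orientation, the fused-QKV handling in build_qwen above boils down to taking 2D views of the projection output at byte offsets 0, n_embd and 2*n_embd within each row, then making them contiguous before the per-head reshape. The following is a minimal, self-contained sketch of that pattern with illustrative dimensions; it is not part of the diff and assumes only ggml.h from this tree.

    // Minimal sketch of the fused-QKV split used in build_qwen above.
    // Dimensions are illustrative only.
    #include "ggml.h"

    static void split_qkv_example(void) {
        const int n_embd   = 8;   // toy embedding width
        const int n_tokens = 4;   // toy batch length

        struct ggml_init_params ip = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(ip);

        // fused projection output: one row of 3*n_embd floats per token
        struct ggml_tensor * qkv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3*n_embd, n_tokens);

        // Q, K, V are views at byte offsets 0, n_embd and 2*n_embd within each row
        struct ggml_tensor * q = ggml_cont(ctx, ggml_view_2d(ctx, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*n_embd));
        struct ggml_tensor * k = ggml_cont(ctx, ggml_view_2d(ctx, qkv, n_embd, n_tokens, qkv->nb[1], 1*sizeof(float)*n_embd));
        struct ggml_tensor * v = ggml_cont(ctx, ggml_view_2d(ctx, qkv, n_embd, n_tokens, qkv->nb[1], 2*sizeof(float)*n_embd));

        (void) q; (void) k; (void) v;
        ggml_free(ctx);
    }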
@@ -4902,8 +5290,8 @@ struct llm_build_context {
4902
5290
  enum llm_offload_func_e {
4903
5291
  OFFLOAD_FUNC_NOP,
4904
5292
  OFFLOAD_FUNC,
4905
- OFFLOAD_FUNC_KQ,
4906
- OFFLOAD_FUNC_V,
5293
+ OFFLOAD_FUNC_FRC, // force offload
5294
+ OFFLOAD_FUNC_KQV,
4907
5295
  OFFLOAD_FUNC_NR,
4908
5296
  OFFLOAD_FUNC_EMB,
4909
5297
  OFFLOAD_FUNC_OUT,
@@ -4989,11 +5377,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
4989
5377
  //{ "inp_embd", OFFLOAD_FUNC_NR }, // TODO: missing K-quants get_rows kernel
4990
5378
  { "pos_embd", OFFLOAD_FUNC_NR },
4991
5379
 
4992
- { "inp_pos", OFFLOAD_FUNC_KQ }, // this is often used for KQ ops (e.g. rope)
4993
- { "KQ_scale", OFFLOAD_FUNC_KQ },
4994
- { "KQ_mask", OFFLOAD_FUNC_KQ },
4995
- { "K_shift", OFFLOAD_FUNC_KQ },
4996
- { "K_shifted", OFFLOAD_FUNC_KQ },
5380
+ { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
5381
+ { "KQ_scale", OFFLOAD_FUNC_FRC },
5382
+ { "KQ_mask", OFFLOAD_FUNC_FRC },
5383
+ { "K_shift", OFFLOAD_FUNC_FRC },
5384
+
5385
+ { "K_shifted", OFFLOAD_FUNC },
4997
5386
 
4998
5387
  { "inp_norm", OFFLOAD_FUNC_NR },
4999
5388
  { "inp_norm_w", OFFLOAD_FUNC_NR },
@@ -5006,37 +5395,38 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
5006
5395
  { "attn_norm", OFFLOAD_FUNC },
5007
5396
  { "attn_norm_2", OFFLOAD_FUNC },
5008
5397
 
5009
- { "wqkv", OFFLOAD_FUNC_KQ },
5010
- { "bqkv", OFFLOAD_FUNC_KQ },
5011
- { "wqkv_clamped", OFFLOAD_FUNC_KQ },
5012
-
5013
- { "tmpk", OFFLOAD_FUNC_KQ },
5014
- { "tmpq", OFFLOAD_FUNC_KQ },
5015
- { "tmpv", OFFLOAD_FUNC_V },
5016
- { "Kcur", OFFLOAD_FUNC_KQ },
5017
- { "Qcur", OFFLOAD_FUNC_KQ },
5018
- { "Vcur", OFFLOAD_FUNC_V },
5019
-
5020
- { "krot", OFFLOAD_FUNC_KQ },
5021
- { "qrot", OFFLOAD_FUNC_KQ },
5022
- { "kpass", OFFLOAD_FUNC_KQ },
5023
- { "qpass", OFFLOAD_FUNC_KQ },
5024
- { "krotated", OFFLOAD_FUNC_KQ },
5025
- { "qrotated", OFFLOAD_FUNC_KQ },
5026
-
5027
- { "q", OFFLOAD_FUNC_KQ },
5028
- { "k", OFFLOAD_FUNC_KQ },
5029
- { "kq", OFFLOAD_FUNC_KQ },
5030
- { "kq_scaled", OFFLOAD_FUNC_KQ },
5031
- { "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
5032
- { "kq_masked", OFFLOAD_FUNC_KQ },
5033
- { "kq_soft_max", OFFLOAD_FUNC_V },
5034
- { "v", OFFLOAD_FUNC_V },
5035
- { "kqv", OFFLOAD_FUNC_V },
5036
- { "kqv_merged", OFFLOAD_FUNC_V },
5037
- { "kqv_merged_cont", OFFLOAD_FUNC_V },
5038
- { "kqv_wo", OFFLOAD_FUNC_V },
5039
- { "kqv_out", OFFLOAD_FUNC_V },
5398
+ { "wqkv", OFFLOAD_FUNC_KQV },
5399
+ { "bqkv", OFFLOAD_FUNC_KQV },
5400
+ { "wqkv_clamped", OFFLOAD_FUNC_KQV },
5401
+
5402
+ { "tmpk", OFFLOAD_FUNC_KQV },
5403
+ { "tmpq", OFFLOAD_FUNC_KQV },
5404
+ { "tmpv", OFFLOAD_FUNC_KQV },
5405
+ { "Kcur", OFFLOAD_FUNC_KQV },
5406
+ { "Qcur", OFFLOAD_FUNC_KQV },
5407
+ { "Vcur", OFFLOAD_FUNC_KQV },
5408
+
5409
+ { "krot", OFFLOAD_FUNC_KQV },
5410
+ { "qrot", OFFLOAD_FUNC_KQV },
5411
+ { "kpass", OFFLOAD_FUNC_KQV },
5412
+ { "qpass", OFFLOAD_FUNC_KQV },
5413
+ { "krotated", OFFLOAD_FUNC_KQV },
5414
+ { "qrotated", OFFLOAD_FUNC_KQV },
5415
+
5416
+ { "q", OFFLOAD_FUNC_KQV },
5417
+ { "k", OFFLOAD_FUNC_KQV },
5418
+ { "kq", OFFLOAD_FUNC_KQV },
5419
+ { "kq_scaled", OFFLOAD_FUNC_KQV },
5420
+ { "kq_scaled_alibi", OFFLOAD_FUNC_KQV },
5421
+ { "kq_masked", OFFLOAD_FUNC_KQV },
5422
+ { "kq_soft_max", OFFLOAD_FUNC_KQV },
5423
+ { "kq_soft_max_ext", OFFLOAD_FUNC_KQV },
5424
+ { "v", OFFLOAD_FUNC_KQV },
5425
+ { "kqv", OFFLOAD_FUNC_KQV },
5426
+ { "kqv_merged", OFFLOAD_FUNC_KQV },
5427
+ { "kqv_merged_cont", OFFLOAD_FUNC_KQV },
5428
+ { "kqv_wo", OFFLOAD_FUNC_KQV },
5429
+ { "kqv_out", OFFLOAD_FUNC_KQV },
5040
5430
 
5041
5431
  { "ffn_inp", OFFLOAD_FUNC },
5042
5432
  { "ffn_norm", OFFLOAD_FUNC },
@@ -5228,15 +5618,15 @@ static struct ggml_cgraph * llama_build_graph(
5228
5618
  { OFFLOAD_FUNC_NOP, "CPU" },
5229
5619
  { OFFLOAD_FUNC_OUT, "CPU" },
5230
5620
  #ifdef GGML_USE_CUBLAS
5231
- { OFFLOAD_FUNC, "GPU (CUDA)" },
5232
- { OFFLOAD_FUNC_KQ, "GPU (CUDA) KQ" },
5233
- { OFFLOAD_FUNC_V, "GPU (CUDA) V" },
5234
- { OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
5621
+ { OFFLOAD_FUNC, "GPU (CUDA)" },
5622
+ { OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" },
5623
+ { OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" },
5624
+ { OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
5235
5625
  { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
5236
5626
  #else
5237
5627
  { OFFLOAD_FUNC, "CPU" },
5238
- { OFFLOAD_FUNC_KQ, "CPU" },
5239
- { OFFLOAD_FUNC_V, "CPU" },
5628
+ { OFFLOAD_FUNC_FRC, "CPU" },
5629
+ { OFFLOAD_FUNC_KQV, "CPU" },
5240
5630
  { OFFLOAD_FUNC_NR, "CPU" },
5241
5631
  { OFFLOAD_FUNC_EMB, "CPU" },
5242
5632
  #endif // GGML_USE_CUBLAS
@@ -5269,18 +5659,23 @@ static struct ggml_cgraph * llama_build_graph(
5269
5659
  }
5270
5660
  }
5271
5661
  break;
5272
- case OFFLOAD_FUNC_NR:
5273
- if (n_gpu_layers <= n_layer + 0) {
5662
+ case OFFLOAD_FUNC_FRC:
5663
+ if (!lctx.cparams.offload_kqv) {
5274
5664
  func_e = OFFLOAD_FUNC_NOP;
5275
- }
5276
- break;
5277
- case OFFLOAD_FUNC_V:
5278
- if (n_gpu_layers <= n_layer + 1) {
5665
+ } break;
5666
+ case OFFLOAD_FUNC_KQV:
5667
+ if (!lctx.cparams.offload_kqv) {
5279
5668
  func_e = OFFLOAD_FUNC_NOP;
5669
+ } else {
5670
+ if (n_gpu_layers < n_layer) {
5671
+ if (il < i_gpu_start) {
5672
+ func_e = OFFLOAD_FUNC_NOP;
5673
+ }
5674
+ }
5280
5675
  }
5281
5676
  break;
5282
- case OFFLOAD_FUNC_KQ:
5283
- if (n_gpu_layers <= n_layer + 2) {
5677
+ case OFFLOAD_FUNC_NR:
5678
+ if (n_gpu_layers <= n_layer + 0) {
5284
5679
  func_e = OFFLOAD_FUNC_NOP;
5285
5680
  }
5286
5681
  break;
@@ -5305,8 +5700,8 @@ static struct ggml_cgraph * llama_build_graph(
5305
5700
  case OFFLOAD_FUNC_NOP:
5306
5701
  case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
5307
5702
  case OFFLOAD_FUNC:
5308
- case OFFLOAD_FUNC_KQ:
5309
- case OFFLOAD_FUNC_V:
5703
+ case OFFLOAD_FUNC_KQV:
5704
+ case OFFLOAD_FUNC_FRC:
5310
5705
  case OFFLOAD_FUNC_NR:
5311
5706
  case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break;
5312
5707
  default: GGML_ASSERT(false);
@@ -5365,6 +5760,10 @@ static struct ggml_cgraph * llama_build_graph(
5365
5760
  {
5366
5761
  result = llm.build_stablelm();
5367
5762
  } break;
5763
+ case LLM_ARCH_QWEN:
5764
+ {
5765
+ result = llm.build_qwen();
5766
+ } break;
5368
5767
  default:
5369
5768
  GGML_ASSERT(false);
5370
5769
  }
@@ -5487,8 +5886,8 @@ static int llama_decode_internal(
5487
5886
  // a heuristic, to avoid attending the full cache if it is not yet utilized
5488
5887
  // after enough generations, the benefit from this heuristic disappears
5489
5888
  // if we start defragmenting the cache, the benefit from this will be more important
5490
- //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
5491
- kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
5889
+ kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
5890
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
5492
5891
 
5493
5892
  //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
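The updated heuristic pads the number of attended cache cells up to a multiple of 32 before clamping to n_ctx, which keeps tensor shapes friendlier for CUDA kernels. A standalone sketch of the arithmetic, assuming GGML_PAD rounds its argument up to the nearest multiple of the second argument (the helper below is a stand-in, not the real macro):

    // Illustration of the kv_self.n heuristic above.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static int32_t pad_to(int32_t x, int32_t n) { return ((x + n - 1) / n) * n; } // stand-in for GGML_PAD

    int main() {
        const int32_t n_ctx    = 512;  // illustrative context size
        const int32_t cell_max = 100;  // illustrative llama_kv_cache_cell_max() result

        const int32_t n = std::min(n_ctx, std::max(32, pad_to(cell_max, 32)));
        std::printf("kv_self.n = %d\n", n); // -> 128: 100 padded up to 128, clamped to 512
        return 0;
    }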
5494
5893
 
@@ -5539,18 +5938,8 @@ static int llama_decode_internal(
5539
5938
  n_threads = std::min(4, n_threads);
5540
5939
  }
5541
5940
 
5542
- // If all tensors can be run on the GPU then using more than 1 thread is detrimental.
5543
- const bool full_offload_supported =
5544
- model.arch == LLM_ARCH_LLAMA ||
5545
- model.arch == LLM_ARCH_BAICHUAN ||
5546
- model.arch == LLM_ARCH_FALCON ||
5547
- model.arch == LLM_ARCH_REFACT ||
5548
- model.arch == LLM_ARCH_MPT ||
5549
- model.arch == LLM_ARCH_STARCODER ||
5550
- model.arch == LLM_ARCH_STABLELM;
5551
-
5552
- const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
5553
- if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
5941
+ const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
5942
+ if (ggml_cpu_has_cublas() && fully_offloaded) {
5554
5943
  n_threads = 1;
5555
5944
  }
5556
5945
 
@@ -6408,11 +6797,13 @@ struct llama_grammar_candidate {
6408
6797
  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
6409
6798
  // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
6410
6799
  static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
6411
- const char * src,
6800
+ const std::string & src,
6412
6801
  llama_partial_utf8 partial_start) {
6413
6802
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
6414
- const char * pos = src;
6803
+ const char * pos = src.c_str();
6415
6804
  std::vector<uint32_t> code_points;
6805
+ // common English strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
6806
+ code_points.reserve(src.size() + 1);
6416
6807
  uint32_t value = partial_start.value;
6417
6808
  int n_remain = partial_start.n_remain;
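The reserve call added above is a capacity hint: for mostly-ASCII text the decoded code-point count equals the byte count, so a single allocation covers the whole string plus the terminating 0. A standalone illustration of the same idiom (ASCII-only on purpose, not the real decoder):

    // Standalone illustration of the reserve() idiom used in decode_utf8 above.
    #include <cstdint>
    #include <string>
    #include <vector>

    static std::vector<uint32_t> ascii_codepoints(const std::string & src) {
        std::vector<uint32_t> code_points;
        // one code point per byte for ASCII, plus the terminating 0 appended at the end
        code_points.reserve(src.size() + 1);
        for (unsigned char c : src) {
            code_points.push_back(c); // the real decoder handles multi-byte sequences here
        }
        code_points.push_back(0);
        return code_points;
    }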
6418
6809
 
@@ -7016,6 +7407,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
7016
7407
  // Replace the data in candidates with the new_candidates data
7017
7408
  std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
7018
7409
  candidates->size = new_candidates.size();
7410
+ candidates->sorted = false;
7019
7411
 
7020
7412
  if (ctx) {
7021
7413
  ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
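Clearing `sorted` is needed because `llama_token_data_array` advertises whether `data` is still ordered by logit; after the typical-sampling filter rewrites the array, later samplers must not skip their own sort. A small sketch of that invariant, assuming the public structs from llama.h:

    // Sketch of why the sorted flag is cleared after rewriting the candidate array.
    #include <algorithm>
    #include <vector>
    #include "llama.h"

    static void replace_candidates(llama_token_data_array * candidates,
                                   const std::vector<llama_token_data> & filtered) {
        // assumes filtered.size() <= candidates->size, as in the sampler above
        std::copy(filtered.begin(), filtered.end(), candidates->data);
        candidates->size   = filtered.size();
        candidates->sorted = false; // ordering by logit is no longer guaranteed
    }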
@@ -7100,11 +7492,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
7100
7492
  const llama_token eos = llama_token_eos(&ctx->model);
7101
7493
 
7102
7494
  std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
7495
+ candidates_decoded.reserve(candidates->size);
7103
7496
  std::vector<llama_grammar_candidate> candidates_grammar;
7497
+ candidates_grammar.reserve(candidates->size);
7104
7498
 
7105
7499
  for (size_t i = 0; i < candidates->size; ++i) {
7106
7500
  const llama_token id = candidates->data[i].id;
7107
- const std::string piece = llama_token_to_piece(ctx, id);
7501
+ const std::string & piece = ctx->model.vocab.id_to_token[id].text;
7108
7502
  if (id == eos) {
7109
7503
  if (!allow_eos) {
7110
7504
  candidates->data[i].logit = -INFINITY;
@@ -7112,7 +7506,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
7112
7506
  } else if (piece.empty() || piece[0] == 0) {
7113
7507
  candidates->data[i].logit = -INFINITY;
7114
7508
  } else {
7115
- candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
7509
+ candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
7116
7510
  candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
7117
7511
  }
7118
7512
  }
@@ -7316,10 +7710,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
7316
7710
  GGML_ASSERT(false);
7317
7711
  }
7318
7712
 
7319
- const std::string piece = llama_token_to_piece(ctx, token);
7713
+ const std::string & piece = ctx->model.vocab.id_to_token[token].text;
7320
7714
 
7321
7715
  // Note terminating 0 in decoded string
7322
- const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
7716
+ const auto decoded = decode_utf8(piece, grammar->partial_utf8);
7323
7717
  const auto & code_points = decoded.first;
7324
7718
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
7325
7719
  grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
@@ -7637,18 +8031,21 @@ static void llama_convert_tensor_internal(
7637
8031
  return;
7638
8032
  }
7639
8033
 
7640
- auto block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
7641
- auto block_size_bytes = ggml_type_size(tensor->type);
8034
+ size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
8035
+ size_t block_size_bytes = ggml_type_size(tensor->type);
7642
8036
 
7643
8037
  GGML_ASSERT(nelements % block_size == 0);
7644
- auto nblocks = nelements / block_size;
7645
- auto blocks_per_thread = nblocks / nthread;
7646
- auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
8038
+ size_t nblocks = nelements / block_size;
8039
+ size_t blocks_per_thread = nblocks / nthread;
8040
+ size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
8041
+
8042
+ size_t in_buff_offs = 0;
8043
+ size_t out_buff_offs = 0;
7647
8044
 
7648
- for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) {
7649
- auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
7650
- auto thr_elems = thr_blocks * block_size; // number of elements for this thread
7651
- auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
8045
+ for (int tnum = 0; tnum < nthread; tnum++) {
8046
+ size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
8047
+ size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
8048
+ size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
7652
8049
 
7653
8050
  auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
7654
8051
  if (typ == GGML_TYPE_F16) {
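The move from `auto` to explicit `size_t` is more than style: in the old multi-declarator loop header every variable deduced to `int`, so buffer offsets could overflow once a tensor exceeded INT_MAX bytes. A tiny standalone illustration of the deduction pitfall:

    // Why the explicit size_t declarations matter in llama_convert_tensor_internal.
    #include <type_traits>

    int main() {
        auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; // all three deduce to int, not size_t
        static_assert(std::is_same<decltype(in_buff_offs), int>::value, "deduced as int");
        // An int offset wraps once a tensor exceeds INT_MAX bytes, which is why the
        // new code above declares tnum as int and the block counts/offsets as size_t.
        (void) tnum; (void) in_buff_offs; (void) out_buff_offs;
        return 0;
    }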
@@ -7818,7 +8215,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
7818
8215
  constexpr bool use_mmap = false;
7819
8216
  #endif
7820
8217
 
7821
- llama_model_loader ml(fname_inp, use_mmap);
8218
+ llama_model_loader ml(fname_inp, use_mmap, NULL);
7822
8219
  if (ml.use_mmap) {
7823
8220
  ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa()));
7824
8221
  }
@@ -8114,7 +8511,7 @@ static int llama_apply_lora_from_file_internal(
8114
8511
  std::vector<uint8_t> base_buf;
8115
8512
  if (path_base_model) {
8116
8513
  LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
8117
- ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
8514
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL));
8118
8515
 
8119
8516
  size_t ctx_size;
8120
8517
  size_t mmapped_size;
@@ -8342,6 +8739,7 @@ struct llama_model_params llama_model_default_params() {
8342
8739
  /*.tensor_split =*/ nullptr,
8343
8740
  /*.progress_callback =*/ nullptr,
8344
8741
  /*.progress_callback_user_data =*/ nullptr,
8742
+ /*.kv_overrides =*/ nullptr,
8345
8743
  /*.vocab_only =*/ false,
8346
8744
  /*.use_mmap =*/ true,
8347
8745
  /*.use_mlock =*/ false,
@@ -8369,10 +8767,12 @@ struct llama_context_params llama_context_default_params() {
8369
8767
  /*.yarn_beta_fast =*/ 32.0f,
8370
8768
  /*.yarn_beta_slow =*/ 1.0f,
8371
8769
  /*.yarn_orig_ctx =*/ 0,
8770
+ /*.type_k =*/ GGML_TYPE_F16,
8771
+ /*.type_v =*/ GGML_TYPE_F16,
8372
8772
  /*.mul_mat_q =*/ true,
8373
- /*.f16_kv =*/ true,
8374
8773
  /*.logits_all =*/ false,
8375
8774
  /*.embedding =*/ false,
8775
+ /*.offload_kqv =*/ true,
8376
8776
  };
8377
8777
 
8378
8778
  return result;
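With `f16_kv` removed, callers now choose the K and V cache types directly and can keep the KV cache and attention ops on the CPU via `offload_kqv`. A hedged usage sketch against the public llama.h API (cache types other than F16 are assumed to depend on backend support in this release):

    // Hedged usage sketch for the new context parameters.
    #include "llama.h"

    llama_context * make_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();

        cparams.type_k      = GGML_TYPE_F16;  // replaces the removed f16_kv flag
        cparams.type_v      = GGML_TYPE_F16;  // V cache type is configured separately
        cparams.offload_kqv = true;           // set to false to keep KV cache + KQV ops on the CPU

        return llama_new_context_with_model(model, cparams);
    }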
@@ -8489,6 +8889,7 @@ struct llama_context * llama_new_context_with_model(
8489
8889
  cparams.yarn_beta_fast = params.yarn_beta_fast;
8490
8890
  cparams.yarn_beta_slow = params.yarn_beta_slow;
8491
8891
  cparams.mul_mat_q = params.mul_mat_q;
8892
+ cparams.offload_kqv = params.offload_kqv;
8492
8893
 
8493
8894
  cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
8494
8895
  cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -8522,19 +8923,36 @@ struct llama_context * llama_new_context_with_model(
8522
8923
  ctx->rng = std::mt19937(params.seed);
8523
8924
  ctx->logits_all = params.logits_all;
8524
8925
 
8525
- ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
8926
+ const ggml_type type_k = params.type_k;
8927
+ const ggml_type type_v = params.type_v;
8928
+
8929
+ GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
8930
+ GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
8526
8931
 
8527
8932
  // reserve memory for context buffers
8528
8933
  if (!hparams.vocab_only) {
8529
- if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
8934
+ if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
8530
8935
  LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
8531
8936
  llama_free(ctx);
8532
8937
  return nullptr;
8533
8938
  }
8534
8939
 
8535
8940
  {
8536
- const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
8537
- LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
8941
+ size_t memory_size_k = 0;
8942
+ size_t memory_size_v = 0;
8943
+
8944
+ for (auto & k : ctx->kv_self.k_l) {
8945
+ memory_size_k += ggml_nbytes(k);
8946
+ }
8947
+
8948
+ for (auto & v : ctx->kv_self.v_l) {
8949
+ memory_size_v += ggml_nbytes(v);
8950
+ }
8951
+
8952
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
8953
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
8954
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
8955
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
8538
8956
  }
8539
8957
 
8540
8958
  // resized during inference
@@ -8564,8 +8982,6 @@ struct llama_context * llama_new_context_with_model(
8564
8982
 
8565
8983
  #ifdef GGML_USE_METAL
8566
8984
  if (model->n_gpu_layers > 0) {
8567
- ggml_metal_log_set_callback(llama_log_callback_default, NULL);
8568
-
8569
8985
  ctx->ctx_metal = ggml_metal_init(1);
8570
8986
  if (!ctx->ctx_metal) {
8571
8987
  LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
@@ -8607,8 +9023,12 @@ struct llama_context * llama_new_context_with_model(
8607
9023
  }
8608
9024
 
8609
9025
  size_t kv_vram_size = 0;
8610
- add_tensor(ctx->kv_self.k, kv_vram_size);
8611
- add_tensor(ctx->kv_self.v, kv_vram_size);
9026
+ for (auto & k : ctx->kv_self.k_l) {
9027
+ add_tensor(k, kv_vram_size);
9028
+ }
9029
+ for (auto & v : ctx->kv_self.v_l) {
9030
+ add_tensor(v, kv_vram_size);
9031
+ }
8612
9032
 
8613
9033
  size_t ctx_vram_size = alloc_size + kv_vram_size;
8614
9034
  size_t total_vram_size = model_vram_size + ctx_vram_size;
@@ -9078,37 +9498,45 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
9078
9498
  data_ctx->write(&kv_used, sizeof(kv_used));
9079
9499
 
9080
9500
  if (kv_buf_size) {
9081
- const size_t elt_size = ggml_element_size(kv_self.k);
9501
+ const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
9082
9502
 
9083
- ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
9503
+ ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
9084
9504
  ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
9085
9505
 
9086
- ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
9087
- std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
9088
- kout3d->data = kout3d_data.data();
9506
+ std::vector<std::vector<uint8_t>> kout2d_data(n_layer);
9507
+ std::vector<std::vector<uint8_t>> vout2d_data(n_layer);
9089
9508
 
9090
- ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
9091
- std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
9092
- vout3d->data = vout3d_data.data();
9509
+ for (int il = 0; il < (int) n_layer; ++il) {
9510
+ ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
9511
+ kout2d_data[il].resize(ggml_nbytes(kout2d));
9512
+ kout2d->data = kout2d_data[il].data();
9093
9513
 
9094
- ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
9095
- n_embd, kv_head, n_layer,
9096
- elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
9514
+ ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
9515
+ vout2d_data[il].resize(ggml_nbytes(vout2d));
9516
+ vout2d->data = vout2d_data[il].data();
9097
9517
 
9098
- ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
9099
- kv_head, n_embd, n_layer,
9100
- elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
9518
+ ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
9519
+ n_embd, kv_head,
9520
+ elt_size*n_embd, 0);
9521
+
9522
+ ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
9523
+ kv_head, n_embd,
9524
+ elt_size*n_ctx, 0);
9525
+
9526
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d));
9527
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d));
9528
+ }
9101
9529
 
9102
- ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
9103
- ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
9104
9530
  ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
9105
9531
 
9106
9532
  ggml_free(cpy_ctx);
9107
9533
 
9108
- // our data is now in the kout3d_data and vout3d_data buffers
9534
+ // our data is now in the kout2d_data and vout2d_data buffers
9109
9535
  // write them to file
9110
- data_ctx->write(kout3d_data.data(), kout3d_data.size());
9111
- data_ctx->write(vout3d_data.data(), vout3d_data.size());
9536
+ for (uint32_t il = 0; il < n_layer; ++il) {
9537
+ data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size());
9538
+ data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size());
9539
+ }
9112
9540
  }
9113
9541
 
9114
9542
  for (uint32_t i = 0; i < kv_size; ++i) {
@@ -9208,29 +9636,32 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
9208
9636
  if (kv_buf_size) {
9209
9637
  GGML_ASSERT(kv_self.buf.size == kv_buf_size);
9210
9638
 
9211
- const size_t elt_size = ggml_element_size(kv_self.k);
9639
+ const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
9212
9640
 
9213
- ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
9641
+ ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
9214
9642
  ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
9215
9643
 
9216
- ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
9217
- kin3d->data = (void *) inp;
9218
- inp += ggml_nbytes(kin3d);
9644
+ for (int il = 0; il < n_layer; ++il) {
9645
+ ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
9646
+ kin2d->data = (void *) inp;
9647
+ inp += ggml_nbytes(kin2d);
9648
+
9649
+ ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
9650
+ vin2d->data = (void *) inp;
9651
+ inp += ggml_nbytes(vin2d);
9219
9652
 
9220
- ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
9221
- vin3d->data = (void *) inp;
9222
- inp += ggml_nbytes(vin3d);
9653
+ ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
9654
+ n_embd, kv_head,
9655
+ elt_size*n_embd, 0);
9223
9656
 
9224
- ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
9225
- n_embd, kv_head, n_layer,
9226
- elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
9657
+ ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
9658
+ kv_head, n_embd,
9659
+ elt_size*n_ctx, 0);
9227
9660
 
9228
- ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
9229
- kv_head, n_embd, n_layer,
9230
- elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
9661
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d));
9662
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d));
9663
+ }
9231
9664
 
9232
- ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
9233
- ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
9234
9665
  ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
9235
9666
 
9236
9667
  ggml_free(cpy_ctx);
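The per-layer 2D copies above are internal; callers still round-trip session state through the same public entry points. A usage sketch with minimal error handling (the buffer size comes from llama_get_state_size):

    // Usage sketch: round-tripping context state with the public API.
    #include <cstdint>
    #include <vector>
    #include "llama.h"

    static std::vector<uint8_t> save_state(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_get_state_size(ctx)); // upper bound on the state size
        const size_t written = llama_copy_state_data(ctx, buf.data());
        buf.resize(written);
        return buf;
    }

    static void restore_state(llama_context * ctx, std::vector<uint8_t> & buf) {
        llama_set_state_data(ctx, buf.data()); // expects data produced by llama_copy_state_data
    }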
@@ -9701,6 +10132,9 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
9701
10132
  void llama_log_set(ggml_log_callback log_callback, void * user_data) {
9702
10133
  g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
9703
10134
  g_state.log_callback_user_data = user_data;
10135
+ #ifdef GGML_USE_METAL
10136
+ ggml_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
10137
+ #endif
9704
10138
  }
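Because llama_log_set now forwards the callback to ggml-metal as well, one call is enough to capture both llama.cpp and Metal backend logs. A hedged usage sketch:

    // Usage sketch: installing one log callback for llama.cpp (and, on Metal builds, ggml-metal).
    #include <cstdio>
    #include "llama.h"

    static void my_log(ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        if (level == GGML_LOG_LEVEL_ERROR) {
            std::fputs(text, stderr);
        } else {
            std::fputs(text, stdout);
        }
    }

    void setup_logging() {
        llama_log_set(my_log, /*user_data*/ nullptr);
    }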
9705
10139
 
9706
10140
  static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {