llama_cpp 0.9.0 → 0.9.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
-  data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
+  metadata.gz: 66c53ea31dd93cc684d6bbc5331bb7e9f12abe2a23e6e16b8f8a3407e62961a0
+  data.tar.gz: 723d4f1d879c314d1733c84411e39d470f619a22be6a17d589406e831d8ea97b
 SHA512:
-  metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
-  data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
+  metadata.gz: bee0ffe56796ec8bf6240178246c7c95c38ec7cec2bd29f61c1cd85e1230291751c13da850c330fca644089ee2ff524a767b132b5bc6658e95205114e7399ba4
+  data.tar.gz: 382d05658c0a0d8df1c03dcaf93c8861bff3326e1d1e0c0cb3b0638f38cc3de5d36990b1f4df6d0bf3ce19337e9507cd5a2d196d893d8baf56d9b38a49738bc2
data/CHANGELOG.md CHANGED
@@ -1,3 +1,16 @@
+ ## [[0.9.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.1...v0.9.2)] - 2023-11-11
+
+ - Bump bundled llama.cpp from b1472 to b1500.
+
+ ## [[0.9.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.0...v0.9.1)] - 2023-11-03
+
+ - Bump bundled llama.cpp from b1429 to b1472.
+ - Rename `kv_cache_tokens_rm` method to `kv_cache_clear` in Context.
+ - Add `sample_min_p` method to Context.
+ - Add `rope_scaling_type`, `rope_freq_base`, `rope_freq_scale`, `yarn_ext_factor`, `yarn_attn_factor`, `yarn_beta_fast`, `yarn_beta_slow`, and `yarn_orig_ctx` to ContextParams.
+ - Add `pure` to ModelQuantizeParams.
+ - Add constants for RoPE scaling type.
+
  ## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
 
  - Fix missing object file for ggml-backend when building with metal and cublas options.
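The 0.9.1 entry above renames the KV cache reset method, and the matching binding change appears later in this diff: `kv_cache_tokens_rm` (two positional arguments) becomes `kv_cache_clear` (no arguments). A minimal migration sketch, assuming `context` is an already-initialized `LLaMACpp::Context`; the arguments passed to the removed method are illustrative only:

```ruby
# llama_cpp <= 0.9.0: removed KV cache cells by range (illustrative arguments)
context.kv_cache_tokens_rm(0, -1)

# llama_cpp >= 0.9.1: takes no arguments and clears the whole KV cache
context.kv_cache_clear
```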
data/ext/llama_cpp/extconf.rb CHANGED
@@ -5,7 +5,7 @@ require 'fileutils'
 
 abort 'libstdc++ is not found.' unless have_library('stdc++')
 
-$srcs = %w[ggml.c ggml-backend.c ggml-alloc.c llama.cpp llama_cpp.cpp]
+$srcs = %w[ggml.c ggml-backend.c ggml-alloc.c ggml-quants.c llama.cpp llama_cpp.cpp]
 $srcs << 'ggml-opencl.cpp' if with_config('clblast')
 $srcs << 'ggml-mpi.c' if with_config('mpi')
 $CFLAGS << ' -w -DNDEBUG'
@@ -18,12 +18,6 @@ if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>',
   $CXXFLAGS << ' -pthread'
 end
 
-unless with_config('no_k_quants')
-  $CFLAGS << ' -DGGML_USE_K_QUANTS'
-  $CXXFLAGS << ' -DGGML_USE_K_QUANTS'
-  $srcs << 'k_quants.c'
-end
-
 if with_config('qkk_64')
   $CFLAGS << ' -DGGML_QKK_64'
   $CXXFLAGS << ' -DGGML_QKK_64'
@@ -53,16 +47,14 @@ if with_config('metal')
   $CFLAGS << ' -DGGML_USE_METAL'
   $CXXFLAGS << ' -DGGML_USE_METAL'
   $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
-  $objs << 'k_quants.o' unless with_config('no_k_quants')
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-metal.o llama.o llama_cpp.o]
 end
 
 if with_config('cublas')
   $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
-  $objs << 'k_quants.o' unless with_config('no_k_quants')
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-cuda.o llama.o llama_cpp.o]
 end
 
 if with_config('clblast')
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -796,10 +796,22 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
+ rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
+ rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_ext_factor), 1);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_ext_factor), 0);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_attn_factor), 1);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_attn_factor), 0);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_fast), 1);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_fast), 0);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_slow), 1);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
+ rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -883,6 +895,18 @@ private:
  return INT2NUM(ptr->params.n_threads_batch);
  }
 
+ // rope_scaling_type
+ static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.rope_scaling_type = NUM2INT(scaling_type);
+ return INT2NUM(ptr->params.rope_scaling_type);
+ }
+
+ static VALUE _llama_context_params_get_rope_scaling_type(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return INT2NUM(ptr->params.rope_scaling_type);
+ }
+
  // rope_freq_base
  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -907,6 +931,66 @@ private:
  return DBL2NUM(ptr->params.rope_freq_scale);
  }
 
+ // yarn_ext_factor
+ static VALUE _llama_context_params_set_yarn_ext_factor(VALUE self, VALUE yarn_ext_factor) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.yarn_ext_factor = NUM2DBL(yarn_ext_factor);
+ return DBL2NUM(ptr->params.yarn_ext_factor);
+ }
+
+ static VALUE _llama_context_params_get_yarn_ext_factor(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.yarn_ext_factor);
+ }
+
+ // yarn_attn_factor
+ static VALUE _llama_context_params_set_yarn_attn_factor(VALUE self, VALUE yarn_attn_factor) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.yarn_attn_factor = NUM2DBL(yarn_attn_factor);
+ return DBL2NUM(ptr->params.yarn_attn_factor);
+ }
+
+ static VALUE _llama_context_params_get_yarn_attn_factor(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.yarn_attn_factor);
+ }
+
+ // yarn_beta_fast
+ static VALUE _llama_context_params_set_yarn_beta_fast(VALUE self, VALUE yarn_beta_fast) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.yarn_beta_fast = NUM2DBL(yarn_beta_fast);
+ return DBL2NUM(ptr->params.yarn_beta_fast);
+ }
+
+ static VALUE _llama_context_params_get_yarn_beta_fast(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.yarn_beta_fast);
+ }
+
+ // yarn_beta_slow
+ static VALUE _llama_context_params_set_yarn_beta_slow(VALUE self, VALUE yarn_beta_slow) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.yarn_beta_slow = NUM2DBL(yarn_beta_slow);
+ return DBL2NUM(ptr->params.yarn_beta_slow);
+ }
+
+ static VALUE _llama_context_params_get_yarn_beta_slow(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.yarn_beta_slow);
+ }
+
+ // yarn_orig_ctx
+ static VALUE _llama_context_params_set_yarn_orig_ctx(VALUE self, VALUE yarn_orig_ctx) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.yarn_orig_ctx = NUM2UINT(yarn_orig_ctx);
+ return UINT2NUM(ptr->params.yarn_orig_ctx);
+ }
+
+ static VALUE _llama_context_params_get_yarn_orig_ctx(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return UINT2NUM(ptr->params.yarn_orig_ctx);
+ }
+
  // mul_mat_q
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
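The accessors above are thin pass-throughs (NUM2DBL/DBL2NUM, and NUM2UINT for `yarn_orig_ctx`) to the corresponding `llama_context_params` fields, so the YaRN settings can be adjusted from Ruby before a context is created. A minimal sketch of the new ContextParams surface; the numeric values are placeholders, not recommended settings:

```ruby
params = LLaMACpp::ContextParams.new

# YaRN RoPE-scaling parameters added in 0.9.1 (values shown are placeholders)
params.yarn_ext_factor  = 1.0
params.yarn_attn_factor = 1.0
params.yarn_beta_fast   = 32.0
params.yarn_beta_slow   = 1.0
params.yarn_orig_ctx    = 4096   # original training context length, as an unsigned integer
```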
@@ -1011,6 +1095,8 @@ public:
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
  rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_only_copy), 1);
  rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_only_copy), 0);
+ rb_define_method(rb_cLLaMAModelQuantizeParams, "pure=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_pure), 1);
+ rb_define_method(rb_cLLaMAModelQuantizeParams, "pure", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_pure), 0);
  }
 
  private:
@@ -1083,6 +1169,18 @@ private:
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
  return ptr->params.only_copy ? Qtrue : Qfalse;
  }
+
+ // pure
+ static VALUE _llama_model_quantize_params_set_pure(VALUE self, VALUE pure) {
+ LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
+ ptr->params.pure = RTEST(pure) ? true : false;
+ return ptr->params.pure ? Qtrue : Qfalse;
+ }
+
+ static VALUE _llama_model_quantize_params_get_pure(VALUE self) {
+ LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
+ return ptr->params.pure ? Qtrue : Qfalse;
+ }
  };
 
  const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
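The new `pure` flag is a boolean pass-through to `llama_model_quantize_params.pure`; per upstream llama.cpp it disables k-quant mixtures so every tensor is quantized to the same type. A minimal sketch of the Ruby side only; how the params object is then handed to the gem's quantization entry point is assumed to follow your existing code:

```ruby
qparams = LLaMACpp::ModelQuantizeParams.new
qparams.pure = true    # quantize all tensors to the same type (no k-quant mixtures)
qparams.pure           # => true
```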
@@ -1741,7 +1839,7 @@ public:
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
- rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
+ rb_define_method(rb_cLLaMAContext, "kv_cache_clear", RUBY_METHOD_FUNC(_llama_context_kv_cache_clear), 0);
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
@@ -1754,6 +1852,7 @@ public:
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
+ rb_define_method(rb_cLLaMAContext, "sample_min_p", RUBY_METHOD_FUNC(_llama_context_sample_min_p), -1);
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
  rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
@@ -2032,13 +2131,13 @@ private:
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
  }
 
- static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
+ static VALUE _llama_context_kv_cache_clear(VALUE self) {
  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  return Qnil;
  }
- llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
+ llama_kv_cache_clear(ptr->ctx);
  return Qnil;
  }
 
@@ -2386,6 +2485,45 @@ private:
  return Qnil;
  }
 
+ static VALUE _llama_context_sample_min_p(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ VALUE candidates = Qnil;
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
+
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+ return Qnil;
+ }
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+ rb_raise(rb_eArgError, "prob must be a float");
+ return Qnil;
+ }
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "min_keep must be an integer");
+ return Qnil;
+ }
+
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+ if (ctx_ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+ if (cnd_ptr->array.data == nullptr) {
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+ return Qnil;
+ }
+ const float prob = NUM2DBL(kw_values[0]);
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
+
+ llama_sample_min_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
+
+ return Qnil;
+ }
+
  static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
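From the binding above, `sample_min_p` takes a `TokenDataArray` as its positional argument plus a required `prob:` float keyword and an optional `min_keep:` integer defaulting to 1; the candidates array is filtered in place. A short sketch, assuming `context` is an initialized `LLaMACpp::Context`, `candidates` is a `LLaMACpp::TokenDataArray` built from the current logits, and a follow-up `sample_token` call as in the gem's other sampling examples:

```ruby
# min-p sampling: drop tokens whose probability falls below prob * p(best token)
context.sample_min_p(candidates, prob: 0.05, min_keep: 1)
token = context.sample_token(candidates)
```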
@@ -2881,6 +3019,12 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
 
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_UNSPECIFIED", INT2NUM(LLAMA_ROPE_SCALING_UNSPECIFIED));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_NONE", INT2NUM(LLAMA_ROPE_SCALING_NONE));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_LINEAR", INT2NUM(LLAMA_ROPE_SCALING_LINEAR));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_YARN", INT2NUM(LLAMA_ROPE_SCALING_YARN));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_MAX_VALUE));
+
  std::stringstream ss_magic;
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSN;
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSN", rb_str_new2(ss_magic.str().c_str()));
data/ext/llama_cpp/src/ggml-alloc.c CHANGED
@@ -378,9 +378,13 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
  }
  }
 
- static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) {
+ static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view, bool update_backend) {
  assert(view->view_src != NULL && view->view_src->data != NULL);
- view->backend = view->view_src->backend;
+
+ if (update_backend) {
+ view->backend = view->view_src->backend;
+ }
+
  view->buffer = view->view_src->buffer;
  view->data = (char *)view->view_src->data + view->view_offs;
 
@@ -394,7 +398,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
  struct hash_node * ht = alloc->hash_table;
  if (node->data == NULL) {
  if (ggml_is_view(node)) {
- init_view(alloc, node);
+ init_view(alloc, node, true);
  } else {
  // see if we can reuse a parent's buffer (inplace)
  if (ggml_op_can_inplace(node->op)) {
@@ -424,15 +428,14 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
  AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
  node->view_src = view_src;
  view_src_hn->n_views += 1;
- init_view(alloc, node);
+ init_view(alloc, node, false);
  return;
  }
- }
- else {
+ } else {
  AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
  node->view_src = parent;
  p_hn->n_views += 1;
- init_view(alloc, node);
+ init_view(alloc, node, false);
  return;
  }
  }
@@ -463,7 +466,7 @@ size_t ggml_allocr_alloc_graph_n(
  hash_get(ht, view_src)->n_views += 1;
  if (node->buffer == NULL && node->data != NULL) {
  // view of a pre-allocated tensor, didn't call init_view() yet
- init_view(alloc, node);
+ init_view(alloc, node, true);
  }
  }
 
@@ -474,7 +477,7 @@ size_t ggml_allocr_alloc_graph_n(
  }
  hash_get(ht, parent)->n_children += 1;
  if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
- init_view(alloc, parent);
+ init_view(alloc, parent, true);
  }
  }
  }