llama_cpp 0.10.3 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/llama_cpp/extconf.rb +35 -110
  5. data/ext/llama_cpp/llama_cpp.cpp +52 -28
  6. data/lib/llama_cpp/version.rb +2 -2
  7. data/sig/llama_cpp.rbs +3 -1
  8. data/vendor/include/.gitkeep +0 -0
  9. data/vendor/lib/.gitkeep +0 -0
  10. data/vendor/tmp/llama.cpp/Makefile +758 -0
  11. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
  12. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
  13. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
  14. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
  15. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
  16. data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
  17. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
  18. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
  19. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
  20. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
  21. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
  22. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
  23. metadata +30 -27
  24. data/ext/llama_cpp/src/ggml-opencl.h +0 -25
  25. data/ext/llama_cpp/src/llama-util.h +0 -546
  26. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
  27. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
  28. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
  29. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
  30. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
  31. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
  32. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
  33. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
  34. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
  35. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
  36. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
  37. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: e679eaf867f62033f7d586a8ef131f2126cb3efb2fde49af7c0be17492a66edf
-  data.tar.gz: da1e9828c456677dc877db6b9754e961ceff27ecfc93c48abd7624d9bb8cdd29
+  metadata.gz: 58b6e91201c53b1ced4db60f325d3ced3fa486e24a84d53b0e5c62f613e33fc9
+  data.tar.gz: 7b1c4594a79c8ac86aef84be3608dbd51e397c8fe4226d65b3ee87aa1fc800b2
 SHA512:
-  metadata.gz: b1fd0737acaa229493e2cbacc79f5b0b6b91233d40e26b57ab7005945ddba79ea3f44e2cca8a0d75df3695373f8eaa2fdfd4ff766a166a688c051beb2acfb126
-  data.tar.gz: '01889a0ff9ebabd400fa374066659686ee84d4afab973cdd55b36ce5588bded1ed424a88296c1a26acc413f1e4f98f9f6e36eebaf7f37874b91a335dd147d3f4'
+  metadata.gz: aece2e7a49f08d0799ff6eb24904ef176fc916eeb57380916b2c8397ea3236991b52fd806aa8c76822a7c1beac86348f3ceb7094880c8d79015debc62babaa0c
+  data.tar.gz: 2049d26027e8be4e47bbbb12a9a521776c369ca45d05743dec3c96249a09fe67e31a21aa09dcb8d717f39ee29904ee082bcbfa292fd6c1e956d6e319809ca31c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,16 @@
+## [[0.11.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.11.0)] - 2024-01-07
+
+- Add `set_n_seq_id` and `get_n_seq_id` methods to `Batch`.
+
+**Breaking Changes**
+- Change to build shared and static libraries of llama.cpp using its Makefile.
+- Change keyword arguments of `Batch` constructor.
+- Remove upper limit check for index value in `Batch` methods.
+
+## [[0.10.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.10.4)] - 2024-01-06
+
+- Bump bundled llama.cpp from b1710 to b1768.
+
 ## [[0.10.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.2...v0.10.3)] - 2023-12-29
 
 - Bump bundled llama.cpp from b1686 to b1710.
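The constructor change in the breaking list is a keyword rename only; a minimal before/after sketch in application code (the sizes are illustrative placeholders, not taken from the diff):

    # llama_cpp 0.10.x
    # batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0, n_seq_max: 1)
    # llama_cpp 0.11.0
    batch = LLaMACpp::Batch.new(max_n_token: 512, n_embd: 0, max_n_seq: 1)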
data/LICENSE.txt CHANGED
@@ -1,6 +1,6 @@
 The MIT License (MIT)
 
-Copyright (c) 2023 Atsushi Tatsuma
+Copyright (c) 2023-2024 Atsushi Tatsuma
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
data/ext/llama_cpp/extconf.rb CHANGED
@@ -2,119 +2,44 @@
 
 require 'mkmf'
 require 'fileutils'
-
-abort 'libstdc++ is not found.' unless have_library('stdc++')
-
-$srcs = %w[ggml.c ggml-backend.c ggml-alloc.c ggml-quants.c llama.cpp llama_cpp.cpp]
-$srcs << 'ggml-opencl.cpp' if with_config('clblast')
-$srcs << 'ggml-mpi.c' if with_config('mpi')
-$CFLAGS << ' -w -DNDEBUG'
-$CXXFLAGS << ' -std=c++11 -DNDEBUG'
-$INCFLAGS << ' -I$(srcdir)/src'
-$VPATH << '$(srcdir)/src'
-
-if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>', '-pthread')
-  $CFLAGS << ' -pthread'
-  $CXXFLAGS << ' -pthread'
-end
-
-if with_config('qkk_64')
-  $CFLAGS << ' -DGGML_QKK_64'
-  $CXXFLAGS << ' -DGGML_QKK_64'
-end
-
-if with_config('openblas')
-  abort 'libopenblas is not found.' unless have_library('openblas')
-  abort 'cblas.h is not found.' unless have_header('cblas.h')
-
-  $CFLAGS << ' -DGGML_USE_OPENBLAS'
-end
-
-if with_config('blis')
-  abort 'libblis is not found.' unless have_library('blis')
-  abort 'cblas.h is not found.' unless have_header('cblas.h')
-
-  $CFLAGS << ' -DGGML_USE_OPENBLAS'
-end
-
-if with_config('accelerate')
-  abort 'Accelerate framework is not found.' unless have_framework('Accelerate')
-
-  $CFLAGS << ' -DGGML_USE_ACCELERATE'
-end
-
-if with_config('metal')
-  $CFLAGS << ' -DGGML_USE_METAL'
-  $CXXFLAGS << ' -DGGML_USE_METAL'
-  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-metal.o llama.o llama_cpp.o]
-end
-
-if with_config('cublas')
-  $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
-  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
-  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-cuda.o llama.o llama_cpp.o]
-end
-
-if with_config('clblast')
-  abort 'libclblast is not found.' unless have_library('clblast')
-
-  $CFLAGS << ' -DGGML_USE_CLBLAST'
-  $CXXFLAGS << ' -DGGML_USE_CLBLAST'
-  if RUBY_PLATFORM.match?(/darwin/)
-    $LDFLAGS << ' -framework OpenCL'
-  else
-    abort 'libOpenCL is not found.' unless have_library('OpenCL')
+require 'open3'
+
+VENDOR_DIR = File.expand_path("#{__dir__}/../../vendor")
+VENDOR_LIB_DIR = "#{VENDOR_DIR}/lib"
+VENDOR_INC_DIR = "#{VENDOR_DIR}/include"
+LLAMA_CPP_DIR = "#{VENDOR_DIR}/tmp/llama.cpp"
+
+make_envs = +''
+make_envs << ' LLAMA_DEBUG=1' if with_config('debug')
+make_envs << ' LLAMA_QKK_64=1' if with_config('qkk-64')
+make_envs << ' LLAMA_NO_ACCELERATE=1' if with_config('no-accelerate')
+make_envs << ' LLAMA_OPENBLAS=1' if with_config('openblas')
+make_envs << ' LLAMA_BLIS=1' if with_config('blis')
+make_envs << ' LLAMA_CUBLAS=1' if with_config('cublas')
+make_envs << ' LLAMA_CLBLAST=1' if with_config('clblast')
+make_envs << ' LLAMA_HIPBLAS=1' if with_config('hipblas')
+make_envs << ' LLAMA_MPI=1' if with_config('mpi')
+
+Dir.chdir(LLAMA_CPP_DIR) do
+  _mkstdout, _mkstderr, mkstatus = Open3.capture3("make lib #{make_envs}".strip)
+  abort('Failed to build llama.cpp.') unless mkstatus.success?
+
+  FileUtils.cp(Dir.glob('libllama.*'), VENDOR_LIB_DIR)
+  FileUtils.cp(Dir.glob('*.h'), "#{VENDOR_DIR}/include/")
+end
+
+if RUBY_PLATFORM.match?(/darwin/)
+  Dir.chdir(VENDOR_LIB_DIR) do
+    _mkstdout, _mkstderr, mkstatus = Open3.capture3("install_name_tool -id #{VENDOR_LIB_DIR}/libllama.dylib libllama.dylib")
+    abort('Failed to set installation path for libllama.dylib.') unless mkstatus.success?
+    FileUtils.cp("#{LLAMA_CPP_DIR}/ggml-metal.metal", VENDOR_LIB_DIR)
   end
 end
 
-if with_config('mpi')
-  abort 'libmpi is not found.' unless have_library('mpi')
-  abort 'mpi.h is not found.' unless have_header('mpi.h')
-
-  $CFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
-  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
-end
-
-# @!visibility private
-UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
+abort('libstdc++ is not found.') unless have_library('stdc++')
+abort('libllama is not found.') unless find_library('llama', nil, VENDOR_LIB_DIR)
+abort('llama.h is not found.') unless find_header('llama.h', nil, VENDOR_INC_DIR)
 
-# rubocop:disable Layout/LineLength
-if UNAME_M.match?(/x86_64|i686/) && try_compile('#include <stdio.h>', '-march=native -mtune=native')
-  $CFLAGS << ' -march=native -mtune=native'
-  $CXXFLAGS << ' -march=native -mtune=native'
-elsif UNAME_M.match?(/aarch64/) && try_compile('#include <stdio.h>', '-mcpu=native')
-  $CFLAGS << ' -mcpu=native'
-  $CXXFLAGS << ' -mcpu=native'
-elsif UNAME_M.match?(/armv6/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access')
-  $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
-  $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
-elsif UNAME_M.match?(/armv7/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations')
-  $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
-  $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
-elsif UNAME_M.match?(/armv8/) && try_compile('#include <stdio.h>', '-mfp16-format=ieee -mno-unaligned-access')
-  $CFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
-  $CXXFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
-end
-# rubocop:enable Layout/LineLength
+$CXXFLAGS << ' -std=c++11'
 
 create_makefile('llama_cpp/llama_cpp')
-
-if with_config('cublas')
-  File.open('Makefile', 'a') do |f|
-    f.puts 'ggml-cuda.o: ggml-cuda.cu ggml-cuda.h'
-    f.puts "\tnvcc -shared -Xcompiler -fPIC -arch=native -c -o $@ $<"
-  end
-end
-
-if with_config('metal')
-  File.open('Makefile', 'a') do |f|
-    f.puts 'ggml-metal.o: ggml-metal.m ggml-metal.h'
-    f.puts "\t$(CC) $(CFLAGS) -c $< -o $@"
-  end
-
-  metal_path = File.expand_path("#{__dir__}/src/ggml-metal.metal")
-  dest_path = File.expand_path("#{__dir__}/../../lib/llama_cpp/")
-  FileUtils.cp(metal_path, dest_path)
-end
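In short, extconf.rb no longer compiles the llama.cpp sources itself; it runs `make lib` in the bundled llama.cpp tree and then links the extension against the resulting libllama under data/vendor. Each `with_config` option above maps one-to-one onto a llama.cpp Makefile variable, so (assuming mkmf's usual handling of `--with-*` flags passed after `--`) a backend is selected at install time, for example `gem install llama_cpp -- --with-cublas` to build with `LLAMA_CUBLAS=1`, or `gem install llama_cpp -- --with-openblas` for `LLAMA_OPENBLAS=1`.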
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -64,6 +64,8 @@ public:
     rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
     rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
     rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
+    rb_define_method(rb_cLLaMABatch, "set_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_n_seq_id), 2);
+    rb_define_method(rb_cLLaMABatch, "get_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_n_seq_id), 1);
     rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
     rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
     rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
@@ -75,30 +77,30 @@ private:
 
   static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+    ID kw_table[3] = { rb_intern("max_n_token"), rb_intern("n_embd"), rb_intern("max_n_seq") };
     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
     rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "n_tokens must be an integer");
+      rb_raise(rb_eArgError, "max_n_token must be an integer");
       return Qnil;
     }
     if (!RB_INTEGER_TYPE_P(kw_values[1])) {
-      rb_raise(rb_eArgError, "embd must be an integer");
+      rb_raise(rb_eArgError, "n_embd must be an integer");
       return Qnil;
     }
     if (!RB_INTEGER_TYPE_P(kw_values[2])) {
-      rb_raise(rb_eArgError, "n_seq_max must be an integer");
+      rb_raise(rb_eArgError, "max_n_seq must be an integer");
       return Qnil;
     }
 
-    const int32_t n_tokens = NUM2INT(kw_values[0]);
-    const int32_t embd = NUM2INT(kw_values[1]);
-    const int32_t n_seq_max = NUM2INT(kw_values[2]);
+    const int32_t max_n_token = NUM2INT(kw_values[0]);
+    const int32_t n_embd = NUM2INT(kw_values[1]);
+    const int32_t max_n_seq = NUM2INT(kw_values[2]);
 
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
+    ptr->batch = llama_batch_init(max_n_token, n_embd, max_n_seq);
 
     return Qnil;
   }
@@ -155,8 +157,8 @@ private:
   static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.token[id] = NUM2INT(value);
@@ -166,8 +168,8 @@ private:
   static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     return INT2NUM(ptr->batch.token[id]);
@@ -177,8 +179,8 @@ private:
   static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.pos[id] = NUM2INT(value);
@@ -188,24 +190,46 @@ private:
   static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
      return Qnil;
    }
    return INT2NUM(ptr->batch.pos[id]);
  }
 
+  // n_seq_id
+  static VALUE _llama_batch_set_n_seq_id(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
+      return Qnil;
+    }
+    ptr->batch.n_seq_id[id] = NUM2INT(value);
+    return INT2NUM(ptr->batch.n_seq_id[id]);
+  }
+
+  static VALUE _llama_batch_get_n_seq_id(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.n_seq_id[id]);
+  }
+
   // seq_id
   static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t i = NUM2INT(i_);
-    if (i < 0 || i >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+    if (i < 0) {
+      rb_raise(rb_eArgError, "i must be greater or equal to 0");
       return Qnil;
     }
     const int32_t j = NUM2INT(j_);
-    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
-      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
+    if (j < 0) {
+      rb_raise(rb_eArgError, "j must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.seq_id[i][j] = NUM2INT(value);
@@ -215,13 +239,13 @@ private:
   static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t i = NUM2INT(i_);
-    if (i < 0 || i >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+    if (i < 0) {
+      rb_raise(rb_eArgError, "i must be greater or equal to 0");
       return Qnil;
     }
     const int32_t j = NUM2INT(j_);
-    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
-      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
+    if (j < 0) {
+      rb_raise(rb_eArgError, "j must be greater or equal to 0");
       return Qnil;
     }
     return INT2NUM(ptr->batch.seq_id[i][j]);
@@ -231,8 +255,8 @@ private:
   static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.logits[id] = RTEST(value) ? true : false;
@@ -242,8 +266,8 @@ private:
   static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
    const int32_t id = NUM2INT(idx);
-    if (id < 0 || id >= ptr->batch.n_tokens) {
-      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     return ptr->batch.logits[id] ? Qtrue : Qfalse;
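Note that the Batch index checks above now reject only negative indices; the former upper-bound checks against n_tokens and n_seq_id[i] are gone, so staying within the sizes passed to the constructor is the caller's responsibility. A minimal usage sketch of the new accessors (values are illustrative, not taken from the diff):

    batch = LLaMACpp::Batch.new(max_n_token: 8, n_embd: 0, max_n_seq: 1)
    batch.set_token(0, 1)        # token id stored at index 0
    batch.set_pos(0, 0)          # its position in the sequence
    batch.set_n_seq_id(0, 1)     # new in 0.11.0: number of sequence ids for index 0
    batch.set_seq_id(0, 0, 0)    # sequence id 0 for index 0
    batch.set_logits(0, true)    # request logits for this entry
    batch.get_n_seq_id(0)        # => 1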
data/lib/llama_cpp/version.rb CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.10.3'
+  VERSION = '0.11.0'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'b1710'
+  LLAMA_CPP_VERSION = 'b1768'
 end
data/sig/llama_cpp.rbs CHANGED
@@ -149,7 +149,7 @@ module LLaMACpp
   class Batch
     public
 
-    def initialize: (n_tokens: Integer, embd: Integer, n_seq_max: Integer) -> void
+    def initialize: (max_n_token: Integer, n_embd: Integer, max_n_seq: Integer) -> void
     def n_tokens=: (Integer) -> Integer
     def n_tokens: () -> Integer
     def all_pos_zero=: (Integer) -> Integer
@@ -162,6 +162,8 @@ module LLaMACpp
     def get_token: (Integer) -> Integer
     def set_pos: (Integer, Integer) -> Integer
     def get_pos: (Integer) -> Integer
+    def set_n_seq_id: (Integer, Integer) -> Integer
+    def get_n_seq_id: (Integer) -> Integer
     def set_seq_id: (Integer, Integer, Integer) -> Integer
     def get_seq_id: (Integer, Integer) -> Integer
     def set_logit: (Integer, bool) -> bool