llama_cpp 0.10.3 → 0.11.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (37) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/llama_cpp/extconf.rb +35 -110
  5. data/ext/llama_cpp/llama_cpp.cpp +52 -28
  6. data/lib/llama_cpp/version.rb +2 -2
  7. data/sig/llama_cpp.rbs +3 -1
  8. data/vendor/include/.gitkeep +0 -0
  9. data/vendor/lib/.gitkeep +0 -0
  10. data/vendor/tmp/llama.cpp/Makefile +758 -0
  11. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
  12. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
  13. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
  14. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
  15. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
  16. data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
  17. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
  18. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
  19. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
  20. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
  21. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
  22. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
  23. metadata +30 -27
  24. data/ext/llama_cpp/src/ggml-opencl.h +0 -25
  25. data/ext/llama_cpp/src/llama-util.h +0 -546
  26. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
  27. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
  28. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
  29. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
  30. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
  31. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
  32. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
  33. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
  34. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
  35. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
  36. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
  37. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e679eaf867f62033f7d586a8ef131f2126cb3efb2fde49af7c0be17492a66edf
4
- data.tar.gz: da1e9828c456677dc877db6b9754e961ceff27ecfc93c48abd7624d9bb8cdd29
3
+ metadata.gz: 58b6e91201c53b1ced4db60f325d3ced3fa486e24a84d53b0e5c62f613e33fc9
4
+ data.tar.gz: 7b1c4594a79c8ac86aef84be3608dbd51e397c8fe4226d65b3ee87aa1fc800b2
5
5
  SHA512:
6
- metadata.gz: b1fd0737acaa229493e2cbacc79f5b0b6b91233d40e26b57ab7005945ddba79ea3f44e2cca8a0d75df3695373f8eaa2fdfd4ff766a166a688c051beb2acfb126
7
- data.tar.gz: '01889a0ff9ebabd400fa374066659686ee84d4afab973cdd55b36ce5588bded1ed424a88296c1a26acc413f1e4f98f9f6e36eebaf7f37874b91a335dd147d3f4'
6
+ metadata.gz: aece2e7a49f08d0799ff6eb24904ef176fc916eeb57380916b2c8397ea3236991b52fd806aa8c76822a7c1beac86348f3ceb7094880c8d79015debc62babaa0c
7
+ data.tar.gz: 2049d26027e8be4e47bbbb12a9a521776c369ca45d05743dec3c96249a09fe67e31a21aa09dcb8d717f39ee29904ee082bcbfa292fd6c1e956d6e319809ca31c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,16 @@
1
+ ## [[0.11.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.11.0)] - 2024-01-07
2
+
3
+ - Add `set_n_seq_id` and `get_n_seq_id` methods to `Batch`.
4
+
5
+ **Breaking Changes**
6
+ - Change to build shared and static libraries of llama.cpp using its Makefile.
7
+ - Change keyword arguments of `Batch` constructor.
8
+ - Remove upper limit check for index value in `Batch` methods.
9
+
10
+ ## [[0.10.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.10.4)] - 2024-01-06
11
+
12
+ - Bump bundled llama.cpp from b1710 to b1768.
13
+
1
14
  ## [[0.10.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.2...v0.10.3)] - 2023-12-29
2
15
 
3
16
  - Bump bundled llama.cpp from b1686 to b1710.
data/LICENSE.txt CHANGED
@@ -1,6 +1,6 @@
1
1
  The MIT License (MIT)
2
2
 
3
- Copyright (c) 2023 Atsushi Tatsuma
3
+ Copyright (c) 2023-2024 Atsushi Tatsuma
4
4
 
5
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
  of this software and associated documentation files (the "Software"), to deal
@@ -2,119 +2,44 @@
2
2
 
3
3
  require 'mkmf'
4
4
  require 'fileutils'
5
-
6
- abort 'libstdc++ is not found.' unless have_library('stdc++')
7
-
8
- $srcs = %w[ggml.c ggml-backend.c ggml-alloc.c ggml-quants.c llama.cpp llama_cpp.cpp]
9
- $srcs << 'ggml-opencl.cpp' if with_config('clblast')
10
- $srcs << 'ggml-mpi.c' if with_config('mpi')
11
- $CFLAGS << ' -w -DNDEBUG'
12
- $CXXFLAGS << ' -std=c++11 -DNDEBUG'
13
- $INCFLAGS << ' -I$(srcdir)/src'
14
- $VPATH << '$(srcdir)/src'
15
-
16
- if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>', '-pthread')
17
- $CFLAGS << ' -pthread'
18
- $CXXFLAGS << ' -pthread'
19
- end
20
-
21
- if with_config('qkk_64')
22
- $CFLAGS << ' -DGGML_QKK_64'
23
- $CXXFLAGS << ' -DGGML_QKK_64'
24
- end
25
-
26
- if with_config('openblas')
27
- abort 'libopenblas is not found.' unless have_library('openblas')
28
- abort 'cblas.h is not found.' unless have_header('cblas.h')
29
-
30
- $CFLAGS << ' -DGGML_USE_OPENBLAS'
31
- end
32
-
33
- if with_config('blis')
34
- abort 'libblis is not found.' unless have_library('blis')
35
- abort 'cblas.h is not found.' unless have_header('cblas.h')
36
-
37
- $CFLAGS << ' -DGGML_USE_OPENBLAS'
38
- end
39
-
40
- if with_config('accelerate')
41
- abort 'Accelerate framework is not found.' unless have_framework('Accelerate')
42
-
43
- $CFLAGS << ' -DGGML_USE_ACCELERATE'
44
- end
45
-
46
- if with_config('metal')
47
- $CFLAGS << ' -DGGML_USE_METAL'
48
- $CXXFLAGS << ' -DGGML_USE_METAL'
49
- $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
50
- $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-metal.o llama.o llama_cpp.o]
51
- end
52
-
53
- if with_config('cublas')
54
- $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
55
- $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
56
- $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
57
- $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-cuda.o llama.o llama_cpp.o]
58
- end
59
-
60
- if with_config('clblast')
61
- abort 'libclblast is not found.' unless have_library('clblast')
62
-
63
- $CFLAGS << ' -DGGML_USE_CLBLAST'
64
- $CXXFLAGS << ' -DGGML_USE_CLBLAST'
65
- if RUBY_PLATFORM.match?(/darwin/)
66
- $LDFLAGS << ' -framework OpenCL'
67
- else
68
- abort 'libOpenCL is not found.' unless have_library('OpenCL')
5
+ require 'open3'
6
+
7
+ VENDOR_DIR = File.expand_path("#{__dir__}/../../vendor")
8
+ VENDOR_LIB_DIR = "#{VENDOR_DIR}/lib"
9
+ VENDOR_INC_DIR = "#{VENDOR_DIR}/include"
10
+ LLAMA_CPP_DIR = "#{VENDOR_DIR}/tmp/llama.cpp"
11
+
12
+ make_envs = +''
13
+ make_envs << ' LLAMA_DEBUG=1' if with_config('debug')
14
+ make_envs << ' LLAMA_QKK_64=1' if with_config('qkk-64')
15
+ make_envs << ' LLAMA_NO_ACCELERATE=1' if with_config('no-accelerate')
16
+ make_envs << ' LLAMA_OPENBLAS=1' if with_config('openblas')
17
+ make_envs << ' LLAMA_BLIS=1' if with_config('blis')
18
+ make_envs << ' LLAMA_CUBLAS=1' if with_config('cublas')
19
+ make_envs << ' LLAMA_CLBLAST=1' if with_config('clblast')
20
+ make_envs << ' LLAMA_HIPBLAS=1' if with_config('hipblas')
21
+ make_envs << ' LLAMA_MPI=1' if with_config('mpi')
22
+
23
+ Dir.chdir(LLAMA_CPP_DIR) do
24
+ _mkstdout, _mkstderr, mkstatus = Open3.capture3("make lib #{make_envs}".strip)
25
+ abort('Failed to build llama.cpp.') unless mkstatus.success?
26
+
27
+ FileUtils.cp(Dir.glob('libllama.*'), VENDOR_LIB_DIR)
28
+ FileUtils.cp(Dir.glob('*.h'), "#{VENDOR_DIR}/include/")
29
+ end
30
+
31
+ if RUBY_PLATFORM.match?(/darwin/)
32
+ Dir.chdir(VENDOR_LIB_DIR) do
33
+ _mkstdout, _mkstderr, mkstatus = Open3.capture3("install_name_tool -id #{VENDOR_LIB_DIR}/libllama.dylib libllama.dylib")
34
+ abort('Failed to set installation path for libllama.dylib.') unless mkstatus.success?
35
+ FileUtils.cp("#{LLAMA_CPP_DIR}/ggml-metal.metal", VENDOR_LIB_DIR)
69
36
  end
70
37
  end
71
38
 
72
- if with_config('mpi')
73
- abort 'libmpi is not found.' unless have_library('mpi')
74
- abort 'mpi.h is not found.' unless have_header('mpi.h')
75
-
76
- $CFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
77
- $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
78
- end
79
-
80
- # @!visibility private
81
- UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
39
+ abort('libstdc++ is not found.') unless have_library('stdc++')
40
+ abort('libllama is not found.') unless find_library('llama', nil, VENDOR_LIB_DIR)
41
+ abort('llama.h is not found.') unless find_header('llama.h', nil, VENDOR_INC_DIR)
82
42
 
83
- # rubocop:disable Layout/LineLength
84
- if UNAME_M.match?(/x86_64|i686/) && try_compile('#include <stdio.h>', '-march=native -mtune=native')
85
- $CFLAGS << ' -march=native -mtune=native'
86
- $CXXFLAGS << ' -march=native -mtune=native'
87
- elsif UNAME_M.match?(/aarch64/) && try_compile('#include <stdio.h>', '-mcpu=native')
88
- $CFLAGS << ' -mcpu=native'
89
- $CXXFLAGS << ' -mcpu=native'
90
- elsif UNAME_M.match?(/armv6/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access')
91
- $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
92
- $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
93
- elsif UNAME_M.match?(/armv7/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations')
94
- $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
95
- $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
96
- elsif UNAME_M.match?(/armv8/) && try_compile('#include <stdio.h>', '-mfp16-format=ieee -mno-unaligned-access')
97
- $CFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
98
- $CXXFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
99
- end
100
- # rubocop:enable Layout/LineLength
43
+ $CXXFLAGS << ' -std=c++11'
101
44
 
102
45
  create_makefile('llama_cpp/llama_cpp')
103
-
104
- if with_config('cublas')
105
- File.open('Makefile', 'a') do |f|
106
- f.puts 'ggml-cuda.o: ggml-cuda.cu ggml-cuda.h'
107
- f.puts "\tnvcc -shared -Xcompiler -fPIC -arch=native -c -o $@ $<"
108
- end
109
- end
110
-
111
- if with_config('metal')
112
- File.open('Makefile', 'a') do |f|
113
- f.puts 'ggml-metal.o: ggml-metal.m ggml-metal.h'
114
- f.puts "\t$(CC) $(CFLAGS) -c $< -o $@"
115
- end
116
-
117
- metal_path = File.expand_path("#{__dir__}/src/ggml-metal.metal")
118
- dest_path = File.expand_path("#{__dir__}/../../lib/llama_cpp/")
119
- FileUtils.cp(metal_path, dest_path)
120
- end
@@ -64,6 +64,8 @@ public:
64
64
  rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
65
65
  rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
66
66
  rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
67
+ rb_define_method(rb_cLLaMABatch, "set_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_n_seq_id), 2);
68
+ rb_define_method(rb_cLLaMABatch, "get_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_n_seq_id), 1);
67
69
  rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
68
70
  rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
69
71
  rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
@@ -75,30 +77,30 @@ private:
75
77
 
76
78
  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
77
79
  VALUE kw_args = Qnil;
78
- ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
80
+ ID kw_table[3] = { rb_intern("max_n_token"), rb_intern("n_embd"), rb_intern("max_n_seq") };
79
81
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
80
82
  rb_scan_args(argc, argv, ":", &kw_args);
81
83
  rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
82
84
 
83
85
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
84
- rb_raise(rb_eArgError, "n_tokens must be an integer");
86
+ rb_raise(rb_eArgError, "max_n_token must be an integer");
85
87
  return Qnil;
86
88
  }
87
89
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
88
- rb_raise(rb_eArgError, "embd must be an integer");
90
+ rb_raise(rb_eArgError, "n_embd must be an integer");
89
91
  return Qnil;
90
92
  }
91
93
  if (!RB_INTEGER_TYPE_P(kw_values[2])) {
92
- rb_raise(rb_eArgError, "n_seq_max must be an integer");
94
+ rb_raise(rb_eArgError, "max_n_seq must be an integer");
93
95
  return Qnil;
94
96
  }
95
97
 
96
- const int32_t n_tokens = NUM2INT(kw_values[0]);
97
- const int32_t embd = NUM2INT(kw_values[1]);
98
- const int32_t n_seq_max = NUM2INT(kw_values[2]);
98
+ const int32_t max_n_token = NUM2INT(kw_values[0]);
99
+ const int32_t n_embd = NUM2INT(kw_values[1]);
100
+ const int32_t max_n_seq = NUM2INT(kw_values[2]);
99
101
 
100
102
  LLaMABatchWrapper* ptr = get_llama_batch(self);
101
- ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
103
+ ptr->batch = llama_batch_init(max_n_token, n_embd, max_n_seq);
102
104
 
103
105
  return Qnil;
104
106
  }
@@ -155,8 +157,8 @@ private:
155
157
  static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
156
158
  LLaMABatchWrapper* ptr = get_llama_batch(self);
157
159
  const int32_t id = NUM2INT(idx);
158
- if (id < 0 || id >= ptr->batch.n_tokens) {
159
- rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
160
+ if (id < 0) {
161
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
160
162
  return Qnil;
161
163
  }
162
164
  ptr->batch.token[id] = NUM2INT(value);
@@ -166,8 +168,8 @@ private:
166
168
  static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
167
169
  LLaMABatchWrapper* ptr = get_llama_batch(self);
168
170
  const int32_t id = NUM2INT(idx);
169
- if (id < 0 || id >= ptr->batch.n_tokens) {
170
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
171
+ if (id < 0) {
172
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
171
173
  return Qnil;
172
174
  }
173
175
  return INT2NUM(ptr->batch.token[id]);
@@ -177,8 +179,8 @@ private:
177
179
  static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
178
180
  LLaMABatchWrapper* ptr = get_llama_batch(self);
179
181
  const int32_t id = NUM2INT(idx);
180
- if (id < 0 || id >= ptr->batch.n_tokens) {
181
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
182
+ if (id < 0) {
183
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
182
184
  return Qnil;
183
185
  }
184
186
  ptr->batch.pos[id] = NUM2INT(value);
@@ -188,24 +190,46 @@ private:
188
190
  static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
189
191
  LLaMABatchWrapper* ptr = get_llama_batch(self);
190
192
  const int32_t id = NUM2INT(idx);
191
- if (id < 0 || id >= ptr->batch.n_tokens) {
192
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
193
+ if (id < 0) {
194
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
193
195
  return Qnil;
194
196
  }
195
197
  return INT2NUM(ptr->batch.pos[id]);
196
198
  }
197
199
 
200
+ // n_seq_id
201
+ static VALUE _llama_batch_set_n_seq_id(VALUE self, VALUE idx, VALUE value) {
202
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
203
+ const int32_t id = NUM2INT(idx);
204
+ if (id < 0) {
205
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
206
+ return Qnil;
207
+ }
208
+ ptr->batch.n_seq_id[id] = NUM2INT(value);
209
+ return INT2NUM(ptr->batch.n_seq_id[id]);
210
+ }
211
+
212
+ static VALUE _llama_batch_get_n_seq_id(VALUE self, VALUE idx) {
213
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
214
+ const int32_t id = NUM2INT(idx);
215
+ if (id < 0) {
216
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
217
+ return Qnil;
218
+ }
219
+ return INT2NUM(ptr->batch.n_seq_id[id]);
220
+ }
221
+
198
222
  // seq_id
199
223
  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
200
224
  LLaMABatchWrapper* ptr = get_llama_batch(self);
201
225
  const int32_t i = NUM2INT(i_);
202
- if (i < 0 || i >= ptr->batch.n_tokens) {
203
- rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
226
+ if (i < 0) {
227
+ rb_raise(rb_eArgError, "i must be greater or equal to 0");
204
228
  return Qnil;
205
229
  }
206
230
  const int32_t j = NUM2INT(j_);
207
- if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
208
- rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
231
+ if (j < 0) {
232
+ rb_raise(rb_eArgError, "j must be greater or equal to 0");
209
233
  return Qnil;
210
234
  }
211
235
  ptr->batch.seq_id[i][j] = NUM2INT(value);
@@ -215,13 +239,13 @@ private:
215
239
  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
216
240
  LLaMABatchWrapper* ptr = get_llama_batch(self);
217
241
  const int32_t i = NUM2INT(i_);
218
- if (i < 0 || i >= ptr->batch.n_tokens) {
219
- rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
242
+ if (i < 0) {
243
+ rb_raise(rb_eArgError, "i must be greater or equal to 0");
220
244
  return Qnil;
221
245
  }
222
246
  const int32_t j = NUM2INT(j_);
223
- if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
224
- rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
247
+ if (j < 0) {
248
+ rb_raise(rb_eArgError, "j must be greater or equal to 0");
225
249
  return Qnil;
226
250
  }
227
251
  return INT2NUM(ptr->batch.seq_id[i][j]);
@@ -231,8 +255,8 @@ private:
231
255
  static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
232
256
  LLaMABatchWrapper* ptr = get_llama_batch(self);
233
257
  const int32_t id = NUM2INT(idx);
234
- if (id < 0 || id >= ptr->batch.n_tokens) {
235
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
258
+ if (id < 0) {
259
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
236
260
  return Qnil;
237
261
  }
238
262
  ptr->batch.logits[id] = RTEST(value) ? true : false;
@@ -242,8 +266,8 @@ private:
242
266
  static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
243
267
  LLaMABatchWrapper* ptr = get_llama_batch(self);
244
268
  const int32_t id = NUM2INT(idx);
245
- if (id < 0 || id >= ptr->batch.n_tokens) {
246
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
269
+ if (id < 0) {
270
+ rb_raise(rb_eArgError, "id must be greater or equal to 0");
247
271
  return Qnil;
248
272
  }
249
273
  return ptr->batch.logits[id] ? Qtrue : Qfalse;
@@ -3,8 +3,8 @@
3
3
  # llama_cpp.rb provides Ruby bindings for the llama.cpp.
4
4
  module LLaMACpp
5
5
  # The version of llama_cpp.rb you install.
6
- VERSION = '0.10.3'
6
+ VERSION = '0.11.0'
7
7
 
8
8
  # The version of llama.cpp bundled with llama_cpp.rb.
9
- LLAMA_CPP_VERSION = 'b1710'
9
+ LLAMA_CPP_VERSION = 'b1768'
10
10
  end
data/sig/llama_cpp.rbs CHANGED
@@ -149,7 +149,7 @@ module LLaMACpp
149
149
  class Batch
150
150
  public
151
151
 
152
- def initialize: (n_tokens: Integer, embd: Integer, n_seq_max: Integer) -> void
152
+ def initialize: (max_n_token: Integer, n_embd: Integer, max_n_seq: Integer) -> void
153
153
  def n_tokens=: (Integer) -> Integer
154
154
  def n_tokens: () -> Integer
155
155
  def all_pos_zero=: (Integer) -> Integer
@@ -162,6 +162,8 @@ module LLaMACpp
162
162
  def get_token: (Integer) -> Integer
163
163
  def set_pos: (Integer, Integer) -> Integer
164
164
  def get_pos: (Integer) -> Integer
165
+ def set_n_seq_id: (Integer, Integer) -> Integer
166
+ def get_n_seq_id: (Integer) -> Integer
165
167
  def set_seq_id: (Integer, Integer, Integer) -> Integer
166
168
  def get_seq_id: (Integer, Integer) -> Integer
167
169
  def set_logit: (Integer, bool) -> bool
File without changes
File without changes