llama_cpp 0.10.4 → 0.11.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/llama_cpp/extconf.rb +35 -110
- data/ext/llama_cpp/llama_cpp.cpp +52 -28
- data/lib/llama_cpp/version.rb +1 -1
- data/sig/llama_cpp.rbs +3 -1
- data/vendor/include/.gitkeep +0 -0
- data/vendor/lib/.gitkeep +0 -0
- data/vendor/tmp/llama.cpp/Makefile +758 -0
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
- metadata +29 -26
- data/ext/llama_cpp/src/llama-util.h +0 -546
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 58b6e91201c53b1ced4db60f325d3ced3fa486e24a84d53b0e5c62f613e33fc9
+  data.tar.gz: 7b1c4594a79c8ac86aef84be3608dbd51e397c8fe4226d65b3ee87aa1fc800b2
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: aece2e7a49f08d0799ff6eb24904ef176fc916eeb57380916b2c8397ea3236991b52fd806aa8c76822a7c1beac86348f3ceb7094880c8d79015debc62babaa0c
+  data.tar.gz: 2049d26027e8be4e47bbbb12a9a521776c369ca45d05743dec3c96249a09fe67e31a21aa09dcb8d717f39ee29904ee082bcbfa292fd6c1e956d6e319809ca31c
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
+## [[0.11.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.11.0)] - 2024-01-07
+
+- Add `set_n_seq_id` and `get_n_seq_id` methods to `Batch`.
+
+**Breaking Changes**
+- Change to build shared and static libraries of llama.cpp using its Makefile.
+- Change keyword arguments of `Batch` constructor.
+- Remove upper limit check for index value in `Batch` methods.
+
 ## [[0.10.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.10.3...v0.10.4)] - 2024-01-06
 
 - Bump bundled llama.cpp from b1710 to b1768.
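To make the `Batch` changes above concrete, here is a minimal usage sketch of the 0.11.0 API. It is not taken from the gem's documentation, and the buffer sizes and token values are illustrative only:

```ruby
require 'llama_cpp'

# 0.11.0 keyword arguments (breaking change): max_n_token, n_embd, max_n_seq.
batch = LLaMACpp::Batch.new(max_n_token: 512, n_embd: 0, max_n_seq: 8)

# Fill slot 0 of the batch.
batch.set_token(0, 1)     # token id for slot 0
batch.set_pos(0, 0)       # position of the token within its sequence
batch.set_n_seq_id(0, 1)  # new in 0.11.0: number of sequence ids for slot 0
batch.set_seq_id(0, 0, 0) # sequence id j = 0 for slot i = 0
batch.get_n_seq_id(0)     # => 1
```

Passing `n_embd: 0` follows llama.cpp's `llama_batch_init` convention: a zero embedding size allocates a token batch, while a positive value allocates an embeddings batch instead.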
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -2,119 +2,44 @@
 
 require 'mkmf'
 require 'fileutils'
- …
-if with_config('
- …
-  $CFLAGS << ' -DGGML_USE_OPENBLAS'
-end
-
-if with_config('accelerate')
-  abort 'Accelerate framework is not found.' unless have_framework('Accelerate')
-
-  $CFLAGS << ' -DGGML_USE_ACCELERATE'
-end
-
-if with_config('metal')
-  $CFLAGS << ' -DGGML_USE_METAL'
-  $CXXFLAGS << ' -DGGML_USE_METAL'
-  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-metal.o llama.o llama_cpp.o]
-end
-
-if with_config('cublas')
-  $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
-  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
-  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
-  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-cuda.o llama.o llama_cpp.o]
-end
-
-if with_config('clblast')
-  abort 'libclblast is not found.' unless have_library('clblast')
-
-  $CFLAGS << ' -DGGML_USE_CLBLAST'
-  $CXXFLAGS << ' -DGGML_USE_CLBLAST'
-  if RUBY_PLATFORM.match?(/darwin/)
-    $LDFLAGS << ' -framework OpenCL'
-  else
-    abort 'libOpenCL is not found.' unless have_library('OpenCL')
+require 'open3'
+
+VENDOR_DIR = File.expand_path("#{__dir__}/../../vendor")
+VENDOR_LIB_DIR = "#{VENDOR_DIR}/lib"
+VENDOR_INC_DIR = "#{VENDOR_DIR}/include"
+LLAMA_CPP_DIR = "#{VENDOR_DIR}/tmp/llama.cpp"
+
+make_envs = +''
+make_envs << ' LLAMA_DEBUG=1' if with_config('debug')
+make_envs << ' LLAMA_QKK_64=1' if with_config('qkk-64')
+make_envs << ' LLAMA_NO_ACCELERATE=1' if with_config('no-accelerate')
+make_envs << ' LLAMA_OPENBLAS=1' if with_config('openblas')
+make_envs << ' LLAMA_BLIS=1' if with_config('blis')
+make_envs << ' LLAMA_CUBLAS=1' if with_config('cublas')
+make_envs << ' LLAMA_CLBLAST=1' if with_config('clblast')
+make_envs << ' LLAMA_HIPBLAS=1' if with_config('hipblas')
+make_envs << ' LLAMA_MPI=1' if with_config('mpi')
+
+Dir.chdir(LLAMA_CPP_DIR) do
+  _mkstdout, _mkstderr, mkstatus = Open3.capture3("make lib #{make_envs}".strip)
+  abort('Failed to build llama.cpp.') unless mkstatus.success?
+
+  FileUtils.cp(Dir.glob('libllama.*'), VENDOR_LIB_DIR)
+  FileUtils.cp(Dir.glob('*.h'), "#{VENDOR_DIR}/include/")
+end
+
+if RUBY_PLATFORM.match?(/darwin/)
+  Dir.chdir(VENDOR_LIB_DIR) do
+    _mkstdout, _mkstderr, mkstatus = Open3.capture3("install_name_tool -id #{VENDOR_LIB_DIR}/libllama.dylib libllama.dylib")
+    abort('Failed to set installation path for libllama.dylib.') unless mkstatus.success?
+    FileUtils.cp("#{LLAMA_CPP_DIR}/ggml-metal.metal", VENDOR_LIB_DIR)
   end
 end
 
- …
-  $CFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
-  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
-end
-
-# @!visibility private
-UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
+abort('libstdc++ is not found.') unless have_library('stdc++')
+abort('libllama is not found.') unless find_library('llama', nil, VENDOR_LIB_DIR)
+abort('llama.h is not found.') unless find_header('llama.h', nil, VENDOR_INC_DIR)
 
-# rubocop:disable Layout/LineLength
-if UNAME_M.match?(/x86_64|i686/) && try_compile('#include <stdio.h>', '-march=native -mtune=native')
-  $CFLAGS << ' -march=native -mtune=native'
-  $CXXFLAGS << ' -march=native -mtune=native'
-elsif UNAME_M.match?(/aarch64/) && try_compile('#include <stdio.h>', '-mcpu=native')
-  $CFLAGS << ' -mcpu=native'
-  $CXXFLAGS << ' -mcpu=native'
-elsif UNAME_M.match?(/armv6/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access')
-  $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
-  $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access'
-elsif UNAME_M.match?(/armv7/) && try_compile('#include <stdio.h>', '-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations')
-  $CFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
-  $CXXFLAGS << ' -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations'
-elsif UNAME_M.match?(/armv8/) && try_compile('#include <stdio.h>', '-mfp16-format=ieee -mno-unaligned-access')
-  $CFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
-  $CXXFLAGS << ' -mfp16-format=ieee -mno-unaligned-access'
-end
-# rubocop:enable Layout/LineLength
+$CXXFLAGS << ' -std=c++11'
 
 create_makefile('llama_cpp/llama_cpp')
-
-if with_config('cublas')
-  File.open('Makefile', 'a') do |f|
-    f.puts 'ggml-cuda.o: ggml-cuda.cu ggml-cuda.h'
-    f.puts "\tnvcc -shared -Xcompiler -fPIC -arch=native -c -o $@ $<"
-  end
-end
-
-if with_config('metal')
-  File.open('Makefile', 'a') do |f|
-    f.puts 'ggml-metal.o: ggml-metal.m ggml-metal.h'
-    f.puts "\t$(CC) $(CFLAGS) -c $< -o $@"
-  end
-
-  metal_path = File.expand_path("#{__dir__}/src/ggml-metal.metal")
-  dest_path = File.expand_path("#{__dir__}/../../lib/llama_cpp/")
-  FileUtils.cp(metal_path, dest_path)
-end
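Because extconf.rb now delegates the native build to llama.cpp's own Makefile, backend selection happens through the `with_config` flags listed above. As a sketch of how those flags are driven from the command line (the `gem install ... -- --with-<flag>` form is the standard RubyGems/mkmf convention; it is not shown in this diff):

```ruby
# Hypothetical illustration mirroring the extconf.rb logic above. Running
#   gem install llama_cpp -- --with-cublas
# makes mkmf's with_config('cublas') return true, so the build command
# assembled by extconf.rb becomes "make lib LLAMA_CUBLAS=1".
require 'mkmf'

make_envs = +''
make_envs << ' LLAMA_CUBLAS=1' if with_config('cublas')

puts "make lib #{make_envs}".strip
```

The darwin-only `install_name_tool -id` step rewrites the dylib's install name to its absolute path under `vendor/lib`, so the compiled extension can locate `libllama.dylib` at runtime without extra `DYLD_LIBRARY_PATH` configuration.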
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -64,6 +64,8 @@ public:
     rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
     rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
     rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
+    rb_define_method(rb_cLLaMABatch, "set_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_n_seq_id), 2);
+    rb_define_method(rb_cLLaMABatch, "get_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_n_seq_id), 1);
     rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
     rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
     rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
@@ -75,30 +77,30 @@ private:
 
   static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("
+    ID kw_table[3] = { rb_intern("max_n_token"), rb_intern("n_embd"), rb_intern("max_n_seq") };
     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
     rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "
+      rb_raise(rb_eArgError, "max_n_token must be an integer");
       return Qnil;
     }
     if (!RB_INTEGER_TYPE_P(kw_values[1])) {
-      rb_raise(rb_eArgError, "
+      rb_raise(rb_eArgError, "n_embd must be an integer");
       return Qnil;
     }
     if (!RB_INTEGER_TYPE_P(kw_values[2])) {
-      rb_raise(rb_eArgError, "
+      rb_raise(rb_eArgError, "max_n_seq must be an integer");
       return Qnil;
     }
 
-    const int32_t
-    const int32_t
-    const int32_t
+    const int32_t max_n_token = NUM2INT(kw_values[0]);
+    const int32_t n_embd = NUM2INT(kw_values[1]);
+    const int32_t max_n_seq = NUM2INT(kw_values[2]);
 
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    ptr->batch = llama_batch_init(
+    ptr->batch = llama_batch_init(max_n_token, n_embd, max_n_seq);
 
     return Qnil;
   }
@@ -155,8 +157,8 @@ private:
   static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.token[id] = NUM2INT(value);
@@ -166,8 +168,8 @@ private:
   static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "id must be
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     return INT2NUM(ptr->batch.token[id]);
@@ -177,8 +179,8 @@ private:
   static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "id must be
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.pos[id] = NUM2INT(value);
@@ -188,24 +190,46 @@ private:
   static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "id must be
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     return INT2NUM(ptr->batch.pos[id]);
   }
 
+  // n_seq_id
+  static VALUE _llama_batch_set_n_seq_id(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
+      return Qnil;
+    }
+    ptr->batch.n_seq_id[id] = NUM2INT(value);
+    return INT2NUM(ptr->batch.n_seq_id[id]);
+  }
+
+  static VALUE _llama_batch_get_n_seq_id(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.n_seq_id[id]);
+  }
+
   // seq_id
   static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t i = NUM2INT(i_);
-    if (i < 0
-      rb_raise(rb_eArgError, "i must be
+    if (i < 0) {
+      rb_raise(rb_eArgError, "i must be greater or equal to 0");
       return Qnil;
     }
     const int32_t j = NUM2INT(j_);
-    if (j < 0
-      rb_raise(rb_eArgError, "j must be
+    if (j < 0) {
+      rb_raise(rb_eArgError, "j must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.seq_id[i][j] = NUM2INT(value);
@@ -215,13 +239,13 @@ private:
   static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t i = NUM2INT(i_);
-    if (i < 0
-      rb_raise(rb_eArgError, "i must be
+    if (i < 0) {
+      rb_raise(rb_eArgError, "i must be greater or equal to 0");
       return Qnil;
     }
     const int32_t j = NUM2INT(j_);
-    if (j < 0
-      rb_raise(rb_eArgError, "j must be
+    if (j < 0) {
+      rb_raise(rb_eArgError, "j must be greater or equal to 0");
       return Qnil;
     }
     return INT2NUM(ptr->batch.seq_id[i][j]);
@@ -231,8 +255,8 @@ private:
   static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
    const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "id must be
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     ptr->batch.logits[id] = RTEST(value) ? true : false;
@@ -242,8 +266,8 @@ private:
   static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
     const int32_t id = NUM2INT(idx);
-    if (id < 0
-      rb_raise(rb_eArgError, "id must be
+    if (id < 0) {
+      rb_raise(rb_eArgError, "id must be greater or equal to 0");
       return Qnil;
     }
     return ptr->batch.logits[id] ? Qtrue : Qfalse;
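The net effect of the relaxed validation is visible from Ruby: a negative index still raises `ArgumentError`, while the former upper-limit check is gone. A small illustration (behavior inferred from the wrapper code above; the sizes are arbitrary):

```ruby
require 'llama_cpp'

batch = LLaMACpp::Batch.new(max_n_token: 4, n_embd: 0, max_n_seq: 1)

begin
  batch.set_pos(-1, 0) # the lower-bound check is kept
rescue ArgumentError => e
  puts e.message # => "id must be greater or equal to 0"
end

# The upper-limit check was removed in 0.11.0, so an index like 100 is no
# longer rejected here; staying below max_n_token is now the caller's
# responsibility, and exceeding it writes outside the allocated buffers.
```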
data/lib/llama_cpp/version.rb
CHANGED
data/sig/llama_cpp.rbs
CHANGED
@@ -149,7 +149,7 @@ module LLaMACpp
   class Batch
     public
 
-    def initialize: (
+    def initialize: (max_n_token: Integer, n_embd: Integer, max_n_seq: Integer) -> void
     def n_tokens=: (Integer) -> Integer
     def n_tokens: () -> Integer
     def all_pos_zero=: (Integer) -> Integer
@@ -162,6 +162,8 @@ module LLaMACpp
     def get_token: (Integer) -> Integer
     def set_pos: (Integer, Integer) -> Integer
     def get_pos: (Integer) -> Integer
+    def set_n_seq_id: (Integer, Integer) -> Integer
+    def get_n_seq_id: (Integer) -> Integer
     def set_seq_id: (Integer, Integer, Integer) -> Integer
     def get_seq_id: (Integer, Integer) -> Integer
     def set_logit: (Integer, bool) -> bool
data/vendor/include/.gitkeep
ADDED
File without changes

data/vendor/lib/.gitkeep
ADDED
File without changes