rllama 1.0.0 → 1.0.1

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 108458213c9b89f4ecdc5a8874d0ac5c18a6c4a32b57439ee6839e50bd584e39
-  data.tar.gz: 9abb5f179ca54740adb0321211283089ddcf46b7eefc156501419d174f46bc57
+  metadata.gz: e3ddb5865414a3b5393e2aab0eeaf8853ddce4a21df46038faf04dfe7fd1b4cd
+  data.tar.gz: cac2dae8787473817d3ddb8bcfbb1e97829a89e1ad7556d250a1fd3ce8f65819
 SHA512:
-  metadata.gz: 624a9689bc2ddc6b0ac12c17ea94b3e1a7169907ea775d2f9f2f6706475d025dafda35e5dbfd79e7143a36687591c62d488819ea9170908843944b2dd1cff6f6
-  data.tar.gz: 7ec05d5870c58f58f160045a272a4e07a8eba96101803846cf4e4689fde3c2fb2e64f61c2b36dfbcc0e18614d08aa6bedc39174a0b22d3ff1be9bce04ea33d12
+  metadata.gz: a111915a6be3bb92f1319f7abab57b21c6efcbef86dbd4e3c235b1d9ac7a5ce548b8a8a2d9e205e9847ab966f7509076bb1e46075333036bba29f0854e4d9e88
+  data.tar.gz: 6b6dc6ab6b76908e9f479d463ec1d952c9b78913d18c5bbb3f94e05a4d054e5f50228b762fb11f07faf806d6e185f7d8d5c065f6b3048b3f3594735bcdff7583
@@ -12,12 +12,21 @@ module Rllama
 
       @ctx_params = Cpp.llama_context_default_params
 
-      @ctx_params[:n_ctx] = @n_ctx
-      @ctx_params[:n_batch] = @n_batch
+      @ctx_params[:n_ctx] = @n_ctx if @n_ctx
+      @ctx_params[:n_batch] = @n_batch if @n_batch
 
       if @embeddings
-        @ctx_params[:n_seq_max] = [@n_batch, @model.n_seq_max].min
+        seq_cap = @model.n_seq_max
+
+        if @n_batch&.positive? && seq_cap&.positive?
+          @ctx_params[:n_seq_max] = [@n_batch, seq_cap].min
+        elsif seq_cap&.positive?
+          @ctx_params[:n_seq_max] = seq_cap
+        end
+
         @ctx_params[:embeddings] = true
+        @ctx_params[:kv_unified] = true
+        @ctx_params[:n_ubatch] = @n_batch if @n_batch&.positive?
       end
 
       @pointer = Cpp.llama_init_from_model(model.pointer, @ctx_params)
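For embedding contexts, n_seq_max is now only written when the requested batch size and the model-reported sequence cap are present and positive, otherwise it falls back to the cap alone or leaves the llama.cpp default in place. A standalone sketch of how the guard resolves (the helper name is illustrative, not part of the gem):

# Illustration only; mirrors the guarded branches in the hunk above.
def resolve_n_seq_max(n_batch, seq_cap)
  if n_batch&.positive? && seq_cap&.positive?
    [n_batch, seq_cap].min
  elsif seq_cap&.positive?
    seq_cap
  end
end

resolve_n_seq_max(512, 64)  # => 64
resolve_n_seq_max(nil, 64)  # => 64
resolve_n_seq_max(512, nil) # => nil, so the llama.cpp default is kept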
@@ -141,19 +150,31 @@ module Rllama
     end
     alias message generate
 
-    def embed(strings, normalize: true, batch_size: 512)
-      is_array = strings.is_a?(Array)
+    def embed(strings_or_tokens, normalize: true, batch_size: 512)
+      is_tokens = strings_or_tokens.is_a?(Array) &&
+                  (strings_or_tokens[0].is_a?(Integer) ||
+                   (strings_or_tokens[0].is_a?(Array) && strings_or_tokens[0][0].is_a?(Integer)))
 
-      strings = Array(strings) unless is_array
+      input_is_array = is_tokens ? strings_or_tokens[0].is_a?(Array) : strings_or_tokens.is_a?(Array)
 
-      tokenized_strings = strings.map do |text|
-        max_tokens = text.bytesize + 2
-        tokens_ptr = FFI::MemoryPointer.new(:int32, max_tokens)
-        count = Cpp.llama_tokenize(@model.vocab, text, text.bytesize, tokens_ptr, max_tokens, true, false)
+      normalized_inputs = input_is_array ? strings_or_tokens : [strings_or_tokens]
+
+      tokenized_strings =
+        if is_tokens
+          input_is_array ? strings_or_tokens : [strings_or_tokens]
+        else
+          normalized_inputs.map { |text| @model.tokenize(text) }
+        end
 
-        raise Error, "Failed to tokenize text: '#{text}'" if count.negative?
+      max_tokens_in_prompt = tokenized_strings.map(&:length).max || 0
 
-        tokens_ptr.read_array_of_int32(count)
+      if max_tokens_in_prompt > batch_size
+        raise Error, "batch_size (#{batch_size}) is smaller than the longest prompt (#{max_tokens_in_prompt} tokens)."
+      end
+
+      if max_tokens_in_prompt > @n_batch
+        raise Error, "Context n_batch (#{@n_batch}) is smaller than the longest " \
+                     "prompt (#{max_tokens_in_prompt} tokens). Increase batch_size when calling embed."
       end
 
       all_embeddings = []
@@ -166,6 +187,9 @@ module Rllama
 
         batch[:n_tokens] = current_batch_token_count
 
+        memory_ptr = Cpp.llama_get_memory(@pointer)
+        Cpp.llama_memory_clear(memory_ptr, true) unless memory_ptr.null?
+
         raise Error, 'llama_decode failed' unless Cpp.llama_decode(@pointer, batch).zero?
 
         prompts_in_batch.each do |seq_id_in_batch|
@@ -179,7 +203,8 @@ module Rllama
         end
 
         current_batch_token_count = 0
-        prompts_in_batch = []
+
+        prompts_in_batch.clear
       end
 
       tokenized_strings.each do |tokens|
@@ -207,7 +232,7 @@ module Rllama
 
       Cpp.llama_batch_free(batch)
 
-      is_array ? all_embeddings : all_embeddings[0]
+      input_is_array ? all_embeddings : all_embeddings[0]
     end
 
     def embeddings?
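The reworked embed accepts either plain strings or pre-tokenized input (a flat array of Integer token ids, or an array of such arrays), delegates string tokenization to Model#tokenize, and validates the longest prompt against batch_size and the context's n_batch before decoding. A rough usage sketch; the load_model call is a placeholder, since the loading entry point is not part of this diff, and the return shapes assume Model#embed passes through the value produced here:

model = Rllama.load_model('path/to/model.gguf') # placeholder loader call

# Strings are tokenized internally via the new Model#tokenize.
single = model.embed('hello world')                  # single input -> single embedding
many   = model.embed(['first text', 'second text'])  # array input  -> array of embeddings

# Pre-tokenized input is passed through as-is.
tokens = model.tokenize('hello world')               # => [101, 2023, ...] (example ids)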
data/lib/rllama/loader.rb CHANGED
@@ -62,6 +62,8 @@ module Rllama
 
       local_path = File.join(dir, org, repo, file_path)
 
+      return local_path if File.exist?(local_path)
+
       puts "Destination: #{local_path}"
 
       download_file(url, local_path, "HuggingFace model: #{hf_path}")
@@ -74,6 +76,8 @@ module Rllama
 
      local_path = File.join(dir, filename)
 
+      return local_path if File.exist?(local_path)
+
       puts "Destination: #{local_path}"
 
       download_file(url, local_path, "URL: #{url}")
@@ -82,8 +86,6 @@ module Rllama
     def download_file(url, local_path, description)
       FileUtils.mkdir_p(File.dirname(local_path))
 
-      return local_path if File.exist?(local_path)
-
       temp_path = File.join(File.dirname(local_path), "~#{File.basename(local_path)}")
 
       existing_size = File.exist?(temp_path) ? File.size(temp_path) : 0
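Both resolve paths now check for an already-downloaded file before logging the destination or touching the network, rather than short-circuiting inside download_file after the directory has been created and the destination printed. The resulting pattern, sketched with an illustrative method name:

# Sketch of the pattern applied in both paths; `resolve` is not a method in the gem.
def resolve(url, local_path)
  return local_path if File.exist?(local_path) # cached: no output, no network I/O

  puts "Destination: #{local_path}"
  download_file(url, local_path, "URL: #{url}")
end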
data/lib/rllama/model.rb CHANGED
@@ -47,11 +47,32 @@ module Rllama
     alias message generate
 
     def embed(prompt, normalize: true, batch_size: 512, &block)
-      init_embedding_context do |ctx|
-        ctx.embed(prompt, normalize:, batch_size:, &block)
+      inputs = prompt.is_a?(Array) ? prompt : [prompt]
+
+      tokenized_inputs = inputs.map { |text| tokenize(text, max_tokens: n_ctx_train) }
+      max_token_length = tokenized_inputs.map(&:length).max || 0
+
+      effective_batch_size = [batch_size, max_token_length].max
+      effective_ctx = [n_ctx_train, max_token_length].min
+
+      init_embedding_context(n_ctx: effective_ctx, n_batch: effective_batch_size) do |ctx|
+        inputs = prompt.is_a?(Array) ? tokenized_inputs : tokenized_inputs[0]
+
+        ctx.embed(inputs, normalize:, batch_size: effective_batch_size, &block)
       end
     end
 
+    def tokenize(text, max_tokens: nil)
+      size = text.bytesize + 2
+
+      tokens_ptr = FFI::MemoryPointer.new(:int32, size)
+      count = Cpp.llama_tokenize(vocab, text, text.bytesize, tokens_ptr, size, true, false)
+
+      raise Error, "Failed to tokenize text: '#{text}'" if count.negative?
+
+      tokens_ptr.read_array_of_int32([count, max_tokens].compact.min)
+    end
+
     def close
       Cpp.llama_model_free(@pointer)
     end
@@ -70,7 +91,7 @@ module Rllama
       context
     end
 
-    def init_embedding_context(n_ctx: 2048, n_batch: 512, &)
+    def init_embedding_context(n_ctx: n_ctx_train, n_batch: 512, &)
       init_context(embeddings: true, n_ctx:, n_batch:, &)
     end
 
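Model#embed now tokenizes every input up front (capped at n_ctx_train tokens) and sizes the embedding context from the longest prompt: the batch grows to fit the prompt, and the context is no larger than either the training context or the longest prompt. A worked example with illustrative numbers:

n_ctx_train      = 8192  # training context reported by the model (example value)
batch_size       = 512   # caller-supplied default
max_token_length = 900   # longest tokenized prompt in this call

effective_batch_size = [batch_size, max_token_length].max   # => 900
effective_ctx        = [n_ctx_train, max_token_length].min  # => 900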
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module Rllama
-  VERSION = '1.0.0'
+  VERSION = '1.0.1'
 end
metadata CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: rllama
 version: !ruby/object:Gem::Version
-  version: 1.0.0
+  version: 1.0.1
 platform: ruby
 authors:
 - Pete Matsyburka