ignis-dl 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-dl.rb +48 -0
- data/lib/nnw/ai/gpt2_loader.rb +144 -0
- data/lib/nnw/ai/inference.rb +224 -0
- data/lib/nnw/ai/kv_cache.rb +79 -0
- data/lib/nnw/ai/llama_loader.rb +100 -0
- data/lib/nnw/ai/loss.rb +170 -0
- data/lib/nnw/ai/nn/dropout.rb +68 -0
- data/lib/nnw/ai/nn/embedding.rb +86 -0
- data/lib/nnw/ai/nn/layer_norm.rb +54 -0
- data/lib/nnw/ai/nn/linear.rb +80 -0
- data/lib/nnw/ai/nn/module.rb +178 -0
- data/lib/nnw/ai/nn/rms_norm.rb +43 -0
- data/lib/nnw/ai/nn/sequential.rb +52 -0
- data/lib/nnw/ai/optim/adam.rb +63 -0
- data/lib/nnw/ai/optim/adamw.rb +63 -0
- data/lib/nnw/ai/optim/base.rb +90 -0
- data/lib/nnw/ai/optim/lr_scheduler.rb +118 -0
- data/lib/nnw/ai/optim/sgd.rb +49 -0
- data/lib/nnw/ai/safetensors.rb +220 -0
- data/lib/nnw/ai/server.rb +268 -0
- data/lib/nnw/ai/tokenizer.rb +413 -0
- data/lib/nnw/ai/trainer.rb +245 -0
- data/lib/nnw/ai/transformer/attention.rb +89 -0
- data/lib/nnw/ai/transformer/block.rb +90 -0
- data/lib/nnw/ai/transformer/feed_forward.rb +53 -0
- data/lib/nnw/ai/transformer/model.rb +189 -0
- data/lib/nnw/ai/transformer/modern.rb +191 -0
- data/lib/nnw/ai/transformer/swiglu.rb +39 -0
- data/lib/nnw/ai/weight_map.rb +139 -0
- metadata +91 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "securerandom" # used for completion IDs; was referenced but never required
|
|
5
|
+
|
|
6
|
+
module Ignis
  module AI
    # OpenAI-compatible inference server.
    #
    # Routes:
    #   POST /v1/completions      — text completion
    #   POST /v1/chat/completions — chat completion
    #   GET  /v1/models           — list loaded models
    #   POST /v1/embeddings       — compute embeddings
    #
    # Every response is a single JSON document. Streaming (SSE) is not implemented; a
    # "stream" field in the request body is ignored.
    #
    # Integrates with the WNAIS HTTP/2 server if it is loaded, otherwise uses a simple
    # one-request-per-connection HTTP/1.1 TCP server for standalone operation.
    class Server
      # Raised while parsing a request the client got wrong; mapped to HTTP 400.
      class BadRequestError < StandardError; end

      # Largest request body accepted by the TCP transport (bytes). Larger bodies get a 413
      # instead of being buffered into memory.
      MAX_BODY_BYTES = 16 * 1024 * 1024

      # @return [TextGenerator]
      attr_reader :generator

      # @return [String] model name for API responses
      attr_reader :model_name

      # @param model [Transformer::Model]
      # @param tokenizer [Tokenizer]
      # @param model_name [String]
      # @param host [String]
      # @param port [Integer]
      def initialize(model, tokenizer, model_name: "nnw-model", host: "0.0.0.0", port: 8080)
        @generator = TextGenerator.new(model, tokenizer)
        @model_name = model_name
        @host = host
        @port = port
        @batch_processor = BatchProcessor.new(model, tokenizer)
        # Captured once so GET /v1/models reports a stable "created" timestamp.
        @created_at = Time.now.to_i
        @tcp_server = nil
        @server_thread = nil
      end

      # Start the server.
      #
      # With the TCP transport the listening socket is bound before this returns (so bind
      # errors raise here) and connections are served on a background thread. If the
      # transport fails to start, the batch processor is stopped again before re-raising.
      # @return [void]
      def start!
        Ignis.logger.info("Ignis AI Server starting on #{@host}:#{@port}")
        Ignis.logger.info("Model: #{@model_name} (#{@generator.model.num_parameters} params)")

        @batch_processor.start!

        begin
          if defined?(WNAIS::Server)
            start_wnais_server!
          else
            start_tcp_server!
          end
        rescue StandardError
          @batch_processor.stop!
          raise
        end
      end

      # Stop the server: stops the batch processor, closes the listening socket (releasing
      # the port) and terminates the accept thread.
      # @return [void]
      def stop!
        @batch_processor.stop!
        # Closing the listener makes a blocked #accept raise, which ends the accept loop.
        @tcp_server.close if @tcp_server && !@tcp_server.closed?
        @tcp_server = nil
        @server_thread&.kill
        @server_thread = nil
        Ignis.logger.info("Server stopped")
      end

      private

      # Start using WNAIS HTTP/2 infrastructure by subscribing to its request events.
      # @return [void]
      def start_wnais_server!
        Ignis.logger.info("Using WNAIS HTTP/2 transport")
        unless defined?(Ignis::Shared::EventBus)
          Ignis.logger.warn("Ignis::Shared::EventBus is not loaded; no AI routes were registered")
          return
        end

        # WNAIS integration point — register routes
        Ignis::Shared::EventBus.subscribe(:http_request) do |request|
          handle_request(request)
        end
      end

      # Adapt a WNAIS request event to {#route_request}.
      #
      # NOTE(review): the WNAIS request shape is not visible from this file. This accepts a
      # Hash (symbol or string keys) or an object with zero-argument method/path/body/headers
      # readers — confirm against the WNAIS event payload.
      # @param request [Hash, Object]
      # @return [Hash] :status, :body
      def handle_request(request)
        method = request_field(request, :method)
        path = request_field(request, :path)
        body = request_field(request, :body)
        headers = request_field(request, :headers) || {}
        route_request(method.to_s.upcase, path.to_s, body, headers)
      end

      # Read one field from a request event.
      # @param request [Hash, Object]
      # @param name [Symbol]
      # @return [Object, nil] nil when the field is absent
      def request_field(request, name)
        if request.is_a?(Hash)
          request.fetch(name) { request[name.to_s] }
        elsif request.respond_to?(name) && request.public_method(name).arity.zero?
          # The arity check skips Object#method (arity 1) unless the object overrides it.
          request.public_send(name)
        end
      end

      # Fallback TCP server for standalone use.
      # @return [void]
      # @raise [SystemCallError] when the host/port cannot be bound
      def start_tcp_server!
        require "socket"

        # Bind on the caller's thread so an unusable host/port raises from #start!; keep the
        # listener so #stop! can close it.
        listener = TCPServer.new(@host, @port)
        @tcp_server = listener
        Ignis.logger.info("TCP server listening on #{@host}:#{@port}")
        @server_thread = Thread.new { accept_loop(listener) }
      end

      # Accept connections until the listener is closed, serving each on its own thread.
      # @param listener [TCPServer]
      # @return [void]
      def accept_loop(listener)
        loop do
          client = listener.accept
          Thread.new(client) { |sock| serve_client(sock) }
        end
      rescue IOError, Errno::EBADF
        # The listener was closed by #stop!; end the loop quietly.
        nil
      end

      # Serve one connection, logging failures and always closing the socket.
      # @param sock [TCPSocket]
      # @return [void]
      def serve_client(sock)
        handle_tcp_client(sock)
      rescue => e
        Ignis.logger.error("Client error: #{e.message}")
      ensure
        sock.close unless sock.closed?
      end

      # Handle a TCP client connection: parse one HTTP/1.1 request and write one response.
      # @param sock [TCPSocket, #gets, #read, #print]
      # @return [void]
      def handle_tcp_client(sock)
        request_line = sock.gets
        return unless request_line

        method, path, _version = request_line.strip.split(" ")
        headers = read_headers(sock)

        # A missing, malformed or non-positive Content-Length means "no body".
        length = headers["content-length"].to_i
        if length > MAX_BODY_BYTES
          write_response(sock, error_response("413 Payload Too Large", "Request body too large"))
          return
        end
        body = length.positive? ? sock.read(length) : nil

        write_response(sock, route_request(method, path, body, headers))
      end

      # Read header lines up to the blank separator line.
      # @param sock [#gets]
      # @return [Hash{String => String}] lower-cased header names to trimmed values
      def read_headers(sock)
        headers = {}
        while (line = sock.gets)
          line = line.strip
          break if line.empty?

          # HTTP allows optional whitespace after the colon, so split on ":" and trim.
          key, value = line.split(":", 2)
          headers[key.downcase] = value.to_s.strip
        end
        headers
      end

      # Write a complete HTTP/1.1 response.
      # @param sock [#print]
      # @param response [Hash] :status, :body
      # @return [void]
      def write_response(sock, response)
        body = response[:body]
        sock.print "HTTP/1.1 #{response[:status]}\r\n"
        sock.print "Content-Type: application/json\r\n"
        sock.print "Access-Control-Allow-Origin: *\r\n"
        # Content-Length counts bytes; generated text may contain multibyte UTF-8.
        sock.print "Content-Length: #{body.bytesize}\r\n"
        # Each connection serves exactly one request; tell HTTP/1.1 keep-alive clients.
        sock.print "Connection: close\r\n"
        sock.print "\r\n"
        sock.print body
      end

      # Route request to handler.
      # @param method [String]
      # @param path [String] may include a query string, which is ignored for routing
      # @param body [String, nil]
      # @param headers [Hash] currently unused by any route
      # @return [Hash] :status, :body
      def route_request(method, path, body, headers)
        route_path = path.to_s.split("?", 2).first
        case [method, route_path]
        when ["POST", "/v1/completions"] then handle_completions(body)
        when ["POST", "/v1/chat/completions"] then handle_chat_completions(body)
        when ["GET", "/v1/models"] then handle_models
        when ["POST", "/v1/embeddings"] then handle_embeddings(body)
        else error_response("404 Not Found", "Not found")
        end
      end

      # POST /v1/completions
      # @param body [String, nil] JSON object with "prompt", optional "max_tokens",
      #   "temperature", "top_p"
      # @return [Hash] :status, :body
      def handle_completions(body)
        respond_json do
          params = parse_json_body(body)
          prompt = params["prompt"] || ""
          max_tokens = params["max_tokens"] || 128
          temperature = params["temperature"] || 0.7
          top_p = params["top_p"] || 0.9

          text = @generator.generate(prompt,
                                     max_tokens: max_tokens,
                                     temperature: temperature,
                                     top_p: top_p)
          completion = strip_prompt(text, prompt)

          {
            id: "cmpl-#{SecureRandom.hex(12)}",
            object: "text_completion",
            created: Time.now.to_i,
            model: @model_name,
            choices: [{ text: completion, index: 0, finish_reason: "stop" }],
            usage: usage_for(prompt, completion)
          }
        end
      end

      # POST /v1/chat/completions
      # @param body [String, nil] JSON object with "messages" (array of {role, content}),
      #   optional "max_tokens", "temperature", "top_p"
      # @return [Hash] :status, :body
      def handle_chat_completions(body)
        respond_json do
          params = parse_json_body(body)
          messages = params["messages"] || []
          raise BadRequestError, "messages must be an array" unless messages.is_a?(Array)

          prompt = build_chat_prompt(messages)
          options = { max_tokens: params["max_tokens"] || 128, temperature: params["temperature"] || 0.7 }
          # Forward top_p only when supplied, so the generator's own default applies otherwise.
          top_p = params["top_p"]
          options[:top_p] = top_p unless top_p.nil?

          text = @generator.generate(prompt, **options)
          reply = strip_prompt(text, prompt)

          {
            id: "chatcmpl-#{SecureRandom.hex(12)}",
            object: "chat.completion",
            created: Time.now.to_i,
            model: @model_name,
            choices: [{
              index: 0,
              message: { role: "assistant", content: reply },
              finish_reason: "stop"
            }],
            usage: usage_for(prompt, reply)
          }
        end
      end

      # GET /v1/models
      # @return [Hash] :status, :body
      def handle_models
        respond_json do
          {
            object: "list",
            data: [{
              id: @model_name,
              object: "model",
              created: @created_at,
              owned_by: "nnw",
              permission: [],
              root: @model_name,
              parent: nil
            }]
          }
        end
      end

      # POST /v1/embeddings
      # @param body [String, nil] JSON object with "input" (string or array of strings)
      # @return [Hash] :status, :body
      def handle_embeddings(body)
        respond_json do
          params = parse_json_body(body)
          input = params["input"]
          raise BadRequestError, "input is required" if input.nil?

          inputs = input.is_a?(Array) ? input : [input]
          data = inputs.each_with_index.map do |text, i|
            { object: "embedding", embedding: @generator.embed(text), index: i }
          end
          prompt_tokens = inputs.sum { |t| count_tokens(t) }

          {
            object: "list",
            data: data,
            model: @model_name,
            # Embeddings produce no completion tokens, so total equals prompt tokens.
            usage: { prompt_tokens: prompt_tokens, total_tokens: prompt_tokens }
          }
        end
      end

      # Turn chat messages into the plain-text prompt fed to the generator.
      # @param messages [Array<Hash>]
      # @return [String]
      # @raise [BadRequestError] when a message is not a JSON object
      def build_chat_prompt(messages)
        lines = messages.map do |m|
          raise BadRequestError, "each message must be an object" unless m.is_a?(Hash)

          "#{m['role']}: #{m['content']}"
        end
        "#{lines.join("\n")}\nassistant: "
      end

      # Parse a request body that must be a JSON object.
      # @param body [String, nil]
      # @return [Hash]
      # @raise [BadRequestError] for a missing, empty, invalid or non-object body
      def parse_json_body(body)
        raise BadRequestError, "Request body is required" if body.nil? || body.strip.empty?

        params = JSON.parse(body)
        raise BadRequestError, "Request body must be a JSON object" unless params.is_a?(Hash)

        params
      rescue JSON::ParserError => e
        raise BadRequestError, "Invalid JSON: #{e.message}"
      end

      # Drop the echoed prompt from generated text. Assumes the generator returns
      # prompt + completion, as the original slicing did; never returns nil.
      # @param text [String]
      # @param prompt [String]
      # @return [String]
      def strip_prompt(text, prompt)
        text.to_s[prompt.length..] || ""
      end

      # @param text [String]
      # @return [Integer] number of tokens the generator's tokenizer produces for text
      def count_tokens(text)
        @generator.tokenizer.encode(text).length
      end

      # OpenAI-style usage block; total is prompt + completion.
      # @return [Hash]
      def usage_for(prompt, completion)
        prompt_tokens = count_tokens(prompt)
        completion_tokens = count_tokens(completion)
        {
          prompt_tokens: prompt_tokens,
          completion_tokens: completion_tokens,
          total_tokens: prompt_tokens + completion_tokens
        }
      end

      # Run a handler body and serialize its Hash result as a 200 response.
      # Client errors map to 400 and anything else maps to 500.
      # @yieldreturn [Hash] response payload
      # @return [Hash] :status, :body
      def respond_json
        json_response("200 OK", yield)
      rescue BadRequestError => e
        error_response("400 Bad Request", e.message)
      rescue => e
        Ignis.logger.error("Request failed: #{e.class}: #{e.message}")
        error_response("500 Internal Server Error", e.message)
      end

      # @return [Hash] :status, :body
      def json_response(status, payload)
        { status: status, body: JSON.generate(payload) }
      end

      # @return [Hash] :status, :body with body {"error": message}
      def error_response(status, message)
        json_response(status, { error: message })
      end
    end
  end
end
|
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "set"
|
|
5
|
+
|
|
6
|
+
module Ignis
  module AI
    # Hybrid tokenizer: tries the native HuggingFace tokenizers library first and falls back
    # to a pure Ruby BPE implementation.
    #
    # Native library: tokenizers_ruby.dll / libtokenizers.so / libtokenizers.dylib
    # Search paths: model dir, the gem's ext dir, vcpkg (Windows), /usr/local/lib and /usr/lib
    # (Unix), and PATH (Windows) / LD_LIBRARY_PATH (Unix).
    #
    # The Ruby fallback supports byte-level BPE (GPT-2 style pre-tokenization) and plain
    # whitespace-split BPE. It does not apply tokenizer.json normalizers, post-processors or
    # decoders, so add_special_tokens has no effect on that path.
    class Tokenizer
      # GPT-2 pre-tokenization pattern.
      # \p{L}/\p{N} are used instead of \w/\d, which are ASCII-only in Ruby, so that
      # non-ASCII letters stay inside words. `\s+(?!\S)` leaves the last space of a run
      # attached to the following word, as in the reference implementation.
      # NOTE(review): Llama-3-style configs declare their own Split regex; this pattern only
      # approximates it.
      GPT2_SPLIT_PATTERN = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/

      # Extra output slots reserved for tokens the native post-processor may add (BOS/EOS...).
      NATIVE_SPECIAL_TOKEN_SLACK = 16

      # @return [Boolean] whether using native backend
      attr_reader :native_backend

      # @return [Integer] vocabulary size of model["vocab"] (added tokens not in it are excluded)
      attr_reader :vocab_size

      # @return [Hash{String => Integer}] token to id (includes added tokens)
      attr_reader :token_to_id

      # @return [Hash{Integer => String}] id to token (includes added tokens)
      attr_reader :id_to_token

      # @return [Hash{String => Integer}] special tokens
      attr_reader :special_tokens

      # @return [Set<Integer>] ids of special tokens (used by inference EOS checks)
      attr_reader :special_token_ids

      # @param config_path [String] path to tokenizer.json
      # @raise [ArgumentError] if the file does not exist
      def initialize(config_path)
        raise ArgumentError, "tokenizer.json not found: #{config_path}" unless File.exist?(config_path)

        @config_path = config_path
        @native_backend = false
        @native_handle = nil

        # Try native DLL first
        native_dll = find_native_dll(File.dirname(config_path))
        if native_dll
          begin
            load_native_backend(native_dll, config_path)
            @native_backend = true
            Ignis.logger.info("Tokenizer: using native backend (#{File.basename(native_dll)})")
          rescue => e
            Ignis.logger.warn("Tokenizer: native backend failed (#{e.message}), falling back to Ruby BPE")
            @native_backend = false
          end
        end

        # Always load Ruby config for metadata (special tokens, vocab size)
        # even when native backend is in use
        config = JSON.parse(File.read(config_path))
        load_from_config(config)
      end

      # Load from HuggingFace model directory.
      # @param dir [String] directory containing tokenizer.json
      # @return [Tokenizer]
      def self.from_pretrained(dir)
        new(File.join(dir, "tokenizer.json"))
      end

      # Encode text to token ids.
      # Uses native backend if available, else pure Ruby BPE.
      # @param text [String]
      # @param add_special_tokens [Boolean] honoured by the native backend only
      # @return [Array<Integer>]
      def encode(text, add_special_tokens: true)
        if @native_backend && @native_encode
          native_encode(text, add_special_tokens)
        else
          ruby_encode(text, add_special_tokens)
        end
      end

      # Decode token ids to text (always in Ruby).
      #
      # Added tokens are emitted verbatim; runs of regular tokens go through the byte-level
      # decoder when the tokenizer is byte-level. Unknown ids are skipped.
      # @param ids [Array<Integer>]
      # @param skip_special_tokens [Boolean]
      # @return [String] UTF-8 (may be invalid UTF-8 if ids end mid-character)
      def decode(ids, skip_special_tokens: true)
        out = String.new(encoding: Encoding::UTF_8)
        pending = []
        ids.each do |id|
          next if skip_special_tokens && @special_token_ids.include?(id)

          token = @id_to_token[id]
          next unless token

          if @added_token_ids.include?(id)
            out << decode_symbols(pending) unless pending.empty?
            pending = []
            out << token
          else
            pending << token
          end
        end
        out << decode_symbols(pending) unless pending.empty?
        out
      end

      # Batch encode.
      # @param texts [Array<String>]
      # @param add_special_tokens [Boolean]
      # @return [Array<Array<Integer>>]
      def encode_batch(texts, add_special_tokens: true)
        texts.map { |t| encode(t, add_special_tokens: add_special_tokens) }
      end

      # Batch decode.
      # @param id_sequences [Array<Array<Integer>>]
      # @param skip_special_tokens [Boolean]
      # @return [Array<String>]
      def decode_batch(id_sequences, skip_special_tokens: true)
        id_sequences.map { |ids| decode(ids, skip_special_tokens: skip_special_tokens) }
      end

      # Encode and return GPU Tensor (int32).
      # @param text [String]
      # @param device_id [Integer]
      # @return [Tensor]
      def encode_to_tensor(text, device_id: 0)
        ids = encode(text)
        nv = Ignis::Shared::NvArray.new(shape: [1, ids.length], dtype: :int32, device_id: device_id)
        nv.from_host(ids)
        Tensor.new(data: nv, requires_grad: false)
      end

      private

      # Load vocab, merges, added/special tokens and pre-tokenizer mode from tokenizer.json.
      # @param config [Hash] parsed tokenizer.json
      # @return [void]
      def load_from_config(config)
        model = config["model"] || {}
        @token_to_id = {}
        @id_to_token = {}
        @merges = []
        @merge_ranks = {}

        (model["vocab"] || {}).each do |token, id|
          @token_to_id[token] = id
          @id_to_token[id] = token
        end

        # Merges are "a b" strings in older files and ["a", "b"] arrays in newer ones (the
        # array form allows tokens containing spaces). Ranks are keyed by the pair itself so
        # no separator can make two pairs collide.
        (model["merges"] || []).each_with_index do |merge, rank|
          pair = merge.is_a?(Array) ? merge.first(2) : merge.split(" ", 2)
          @merges << pair
          @merge_ranks[pair] = rank
        end

        @vocab_size = @token_to_id.size

        @special_tokens = {}
        @special_token_ids = Set.new
        @added_tokens = {}
        @added_token_ids = Set.new
        (config["added_tokens"] || []).each do |info|
          content = info["content"]
          id = info["id"]
          if info["special"]
            @special_tokens[content] = id
            @special_token_ids << id
          end
          @token_to_id[content] = id
          @id_to_token[id] = content
          # An empty pattern would match everywhere when splitting text.
          next if content.nil? || content.empty?

          @added_tokens[content] = id
          @added_token_ids << id
        end
        @added_token_regex = build_added_token_regex

        # Resolved after added tokens, since the unk token may be declared there.
        unk = model["unk_token"]
        @unk_id = unk && @token_to_id[unk]

        @byte_level = byte_level_pre_tokenizer?(config["pre_tokenizer"])
        # Build byte-level lookup tables (GPT-2 style)
        build_byte_encoder if @byte_level
      end

      # Whether the pre-tokenizer is ByteLevel itself or a Sequence containing one at any
      # position (e.g. Llama-3 style [Split, ByteLevel]).
      # @param pre_tok [Hash, nil]
      # @return [Boolean]
      def byte_level_pre_tokenizer?(pre_tok)
        return false unless pre_tok.is_a?(Hash)
        return true if pre_tok["type"] == "ByteLevel"

        Array(pre_tok["pretokenizers"]).any? { |p| p.is_a?(Hash) && p["type"] == "ByteLevel" }
      end

      # Regex with one capture group matching any added token, longest first so a token that
      # is a prefix of another cannot shadow it.
      # @return [Regexp, nil] nil when there are no added tokens
      def build_added_token_regex
        return nil if @added_tokens.empty?

        alternation = Regexp.union(@added_tokens.keys.sort_by { |c| -c.length })
        /(#{alternation})/
      end

      # Ruby BPE encode (fallback when no native DLL).
      # Added tokens in the text are matched literally first, as the native library does.
      # @param text [String]
      # @param add_special_tokens [Boolean] ignored: no post-processor is applied on this path
      # @return [Array<Integer>]
      def ruby_encode(text, add_special_tokens)
        split_on_added_tokens(text).flat_map do |piece|
          added_id = @added_tokens[piece]
          next [added_id] if added_id

          pre_tokenize(piece).flat_map { |word| bpe_encode(word) }
        end
      end

      # Split text into added-token matches and the regular text between them.
      # @param text [String]
      # @return [Array<String>] non-empty pieces in order
      def split_on_added_tokens(text)
        return [text] unless @added_token_regex

        # The capture group makes String#split keep the matched tokens.
        text.split(@added_token_regex).reject(&:empty?)
      end

      # Pre-tokenize text into words for BPE.
      # Byte-level: GPT-2 regex split, then each word mapped to byte-level symbols.
      # Otherwise: split on whitespace, keeping the whitespace runs.
      # @param text [String]
      # @return [Array<String>]
      def pre_tokenize(text)
        if @byte_level
          text.scan(GPT2_SPLIT_PATTERN).map { |word| encode_byte_level(word) }
        else
          text.split(/(\s+)/).reject(&:empty?)
        end
      end

      # BPE encode a single pre-tokenized word by repeatedly merging the lowest-rank pair.
      # @param word [String]
      # @return [Array<Integer>]
      def bpe_encode(word)
        symbols = word.chars
        while symbols.length > 1
          pair = lowest_rank_pair(symbols)
          break unless pair

          symbols = apply_merge(symbols, pair)
        end

        symbols.flat_map do |symbol|
          id = @token_to_id[symbol]
          id ? [id] : encode_unknown(symbol)
        end
      end

      # @param symbols [Array<String>]
      # @return [Array<String>, nil] adjacent pair with the lowest merge rank, or nil if none
      def lowest_rank_pair(symbols)
        get_pairs(symbols).select { |pair| @merge_ranks.key?(pair) }.min_by { |pair| @merge_ranks[pair] }
      end

      # @param tokens [Array<String>]
      # @return [Array<Array<String>>] distinct adjacent pairs in order of first appearance
      def get_pairs(tokens)
        tokens.each_cons(2).to_a.uniq
      end

      # Merge every non-overlapping occurrence of pair (left to right).
      # @param tokens [Array<String>]
      # @param pair [Array<String>]
      # @return [Array<String>]
      def apply_merge(tokens, pair)
        merged = pair.join
        result = []
        i = 0
        while i < tokens.length
          if i < tokens.length - 1 && tokens[i] == pair[0] && tokens[i + 1] == pair[1]
            result << merged
            i += 2
          else
            result << tokens[i]
            i += 1
          end
        end
        result
      end

      # Ids for a symbol missing from the vocab.
      # Byte-level symbols are already mapped characters, so each one is looked up directly
      # rather than re-encoding its UTF-8 bytes. Anything still unknown becomes the unk
      # token if the model defines one, otherwise it is dropped.
      # @param token [String]
      # @return [Array<Integer>]
      def encode_unknown(token)
        if @byte_level
          token.chars.map { |symbol| @token_to_id[symbol] || @unk_id }.compact
        else
          @unk_id ? [@unk_id] : []
        end
      end

      # Build the GPT-2 byte <-> printable-character tables.
      # @return [void]
      def build_byte_encoder
        @byte_encoder = {}
        @byte_decoder = {}

        printable = (33..126).to_a + (161..172).to_a + (174..255).to_a
        bytes = printable.dup
        codepoints = printable.dup
        shift = 0
        (0..255).each do |b|
          next if printable.include?(b)

          bytes << b
          codepoints << 256 + shift
          shift += 1
        end

        bytes.zip(codepoints).each do |b, cp|
          char = cp.chr(Encoding::UTF_8)
          @byte_encoder[b] = char
          @byte_decoder[char] = b
        end
      end

      # @param text [String]
      # @return [String] text's bytes mapped to byte-level symbols
      def encode_byte_level(text)
        text.bytes.map { |b| @byte_encoder[b] }.join
      end

      # Join regular (non-added) tokens and undo byte-level mapping if needed.
      # @param tokens [Array<String>]
      # @return [String]
      def decode_symbols(tokens)
        text = tokens.join
        @byte_level ? decode_byte_level(text) : text
      end

      # Map byte-level symbols back to raw bytes. Characters outside the table keep their own
      # UTF-8 bytes instead of being truncated to a single byte.
      # @param text [String]
      # @return [String]
      def decode_byte_level(text)
        bytes = text.each_char.flat_map do |c|
          b = @byte_decoder[c]
          b ? [b] : c.bytes
        end
        bytes.pack("C*").force_encoding(Encoding::UTF_8)
      end

      # -------------------------------------------------------------------
      # Native DLL Backend (HuggingFace tokenizers compiled from Rust)
      # -------------------------------------------------------------------

      # Search for native tokenizers library.
      # @param model_dir [String] directory containing the model
      # @return [String, nil] path to DLL, or nil
      def find_native_dll(model_dir)
        native_dll_candidates(model_dir).find { |path| File.exist?(path) }
      end

      # Generate list of candidate DLL paths by platform, in search order.
      # @param model_dir [String]
      # @return [Array<String>]
      def native_dll_candidates(model_dir)
        ext_dir = File.join(__dir__, "..", "..", "..", "ext")

        if Ignis::Platform.windows?
          name = "tokenizers_ruby.dll"
          home = ENV["USERPROFILE"] || "C:\\Users\\#{ENV['USERNAME']}"
          dirs = [model_dir, ext_dir, File.join(home, "vcpkg", "installed", "x64-windows", "bin")]
          dirs.concat((ENV["PATH"] || "").split(";").reject(&:empty?))
        else
          # Linux / macOS
          name = "libtokenizers.#{Ignis::Platform.macos? ? 'dylib' : 'so'}"
          dirs = [model_dir, ext_dir, "/usr/local/lib", "/usr/lib"]
          dirs.concat((ENV["LD_LIBRARY_PATH"] || "").split(":").reject(&:empty?))
        end

        dirs.map { |dir| File.join(dir, name) }
      end

      # Load native backend via Fiddle.
      # NOTE(review): no native free function is bound, so the handle lives for the process.
      # @param dll_path [String]
      # @param config_path [String]
      # @return [void]
      # @raise [RuntimeError] when the native loader returns NULL
      def load_native_backend(dll_path, config_path)
        require "fiddle"
        require "fiddle/import"

        @native_lib = Fiddle.dlopen(dll_path)

        # Bind tokenizer_from_file(path) → handle
        @native_from_file = Fiddle::Function.new(
          @native_lib["tokenizer_from_file"],
          [Fiddle::TYPE_VOIDP],
          Fiddle::TYPE_VOIDP
        )

        # Bind tokenizer_encode(handle, text, len, add_special, out_ids, out_len) → int
        @native_encode_fn = Fiddle::Function.new(
          @native_lib["tokenizer_encode"],
          [Fiddle::TYPE_VOIDP, Fiddle::TYPE_VOIDP, Fiddle::TYPE_INT,
           Fiddle::TYPE_INT, Fiddle::TYPE_VOIDP, Fiddle::TYPE_VOIDP],
          Fiddle::TYPE_INT
        )

        # Bind tokenizer_decode(handle, ids, len, out_text, out_len) → int (currently unused)
        @native_decode_fn = Fiddle::Function.new(
          @native_lib["tokenizer_decode"],
          [Fiddle::TYPE_VOIDP, Fiddle::TYPE_VOIDP, Fiddle::TYPE_INT,
           Fiddle::TYPE_VOIDP, Fiddle::TYPE_VOIDP],
          Fiddle::TYPE_INT
        )

        # Load tokenizer
        config_ptr = Fiddle::Pointer.to_ptr(config_path)
        @native_handle = @native_from_file.call(config_ptr)
        raise "Failed to load tokenizer from #{config_path}" if @native_handle.null?

        @native_encode = true
        @native_decode = true
      end

      # Number of int32 output slots to allocate for a native encode.
      # The native API takes no capacity argument, so this is a heuristic upper bound: four
      # tokens per input byte plus slack for added special tokens. It is never zero, even
      # for empty text.
      # @param text [String]
      # @return [Integer]
      def native_token_capacity(text)
        text.bytesize * 4 + NATIVE_SPECIAL_TOKEN_SLACK
      end

      # Encode via native backend; falls back to Ruby BPE on a non-zero status or an
      # out-of-range token count.
      # @param text [String]
      # @param add_special_tokens [Boolean]
      # @return [Array<Integer>]
      def native_encode(text, add_special_tokens)
        capacity = native_token_capacity(text)
        out_ids = Fiddle::Pointer.malloc(capacity * 4, Fiddle::RUBY_FREE)
        # NOTE(review): the width of out_len is not visible here. 8 bytes are reserved so a
        # size_t write cannot overflow, and the first 32 bits are read (correct for int32,
        # and for size_t on little-endian) — confirm against the native header.
        out_len = Fiddle::Pointer.malloc(8, Fiddle::RUBY_FREE)

        text_ptr = Fiddle::Pointer.to_ptr(text)
        status = @native_encode_fn.call(@native_handle, text_ptr, text.bytesize,
                                        add_special_tokens ? 1 : 0, out_ids, out_len)
        return ruby_encode(text, add_special_tokens) unless status.zero?

        count = out_len[0, 4].unpack1("l")
        # Never read past the buffer that was allocated, whatever count is reported.
        return ruby_encode(text, add_special_tokens) unless count.between?(0, capacity)

        out_ids[0, count * 4].unpack("l*")
      end
    end
  end
end
|