kiribi-gemma4_e2b 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +101 -0
- data/Rakefile +4 -0
- data/ext/kiribi-gemma4_e2b/extconf.rb +84 -0
- data/lib/kiribi/gemma4/e2b/audio_encoder.rb +135 -0
- data/lib/kiribi/gemma4/e2b/model.rb +175 -0
- data/lib/kiribi/gemma4/e2b/version.rb +9 -0
- data/lib/kiribi/gemma4/e2b/vision_encoder.rb +85 -0
- data/lib/kiribi/gemma4/e2b.rb +23 -0
- metadata +88 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: b816119d78e60fc9147954ffde9bdba6723c979fa76fde28c1ae7fe38e77816c
|
|
4
|
+
data.tar.gz: 340bf72b9558d98de1f03506a1f0055f5b17ab4580b296d7d94d3e8f34c5bca1
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: b1c4fae7b415010072e44245110b2e2a7dd2d3232a3c9324e468dc717d6c17e3a7054b3df5490f5e09981de754df360f905a3c22d49b0e723f70aeebec260532
|
|
7
|
+
data.tar.gz: 47def9793155eaa399a29a0650bc3e89627e6babe3778f8abcb7d362daba7bd8dc05225a9ac7b330cea9cc90192c4f5987e464176ab553e300dc1e8959e53882
|
data/README.md
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# kiribi-gemma4_e2b
|
|
2
|
+
|
|
3
|
+
Google Gemma 4 E2B (2.3B parameters) multimodal model for text, image, and audio.
|
|
4
|
+
|
|
5
|
+
Based on [onnx-community/gemma-4-E2B-it-ONNX](https://huggingface.co/onnx-community/gemma-4-E2B-it-ONNX) (ONNX format, FP32).
|
|
6
|
+
|
|
7
|
+
**Caution:** This gem downloads ~22GB of model files from Hugging Face during installation. Make sure you have enough disk space and network bandwidth.
|
|
8
|
+
|
|
9
|
+
## Installation
|
|
10
|
+
|
|
11
|
+
```sh
|
|
12
|
+
gem install kiribi-gemma4_e2b
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
During installation, the model files and the tokenizer (~22GB in total) are downloaded from Hugging Face into the gem's `lib/kiribi-gemma4_e2b/vendor/build` directory.
|
|
16
|
+
|
|
17
|
+
### Requirements
|
|
18
|
+
|
|
19
|
+
- Ruby >= 3.4.0
|
|
20
|
+
- `ffmpeg` / `ffprobe` (optional; the examples below use them for caller-side image and audio preprocessing)
|
|
21
|
+
|
|
22
|
+
## Usage
|
|
23
|
+
|
|
24
|
+
### Text generation
|
|
25
|
+
|
|
26
|
+
```ruby
|
|
27
|
+
require "kiribi/gemma4/e2b"
|
|
28
|
+
|
|
29
|
+
model = Kiribi::Gemma4::E2B.load
|
|
30
|
+
model.generate("Hello!")
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
### Multi-turn chat
|
|
34
|
+
|
|
35
|
+
```ruby
|
|
36
|
+
model.chat([
|
|
37
|
+
{ role: "system", content: "You are a helpful assistant." },
|
|
38
|
+
{ role: "user", content: "What is Ruby?" },
|
|
39
|
+
{ role: "model", content: "Ruby is a dynamic programming language." },
|
|
40
|
+
{ role: "user", content: "Who created it?" },
|
|
41
|
+
])
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Image understanding
|
|
45
|
+
|
|
46
|
+
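Preprocessing is the caller's responsibility. The example below uses `ffprobe` to read the original size, then `ffmpeg` to resize the image to the size returned by `input_size_of` and output raw RGB24 pixels: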
|
|
47
|
+
|
|
48
|
+
```ruby
|
|
49
|
+
require "kiribi/gemma4/e2b"
|
|
50
|
+
|
|
51
|
+
model = Kiribi::Gemma4::E2B.load
|
|
52
|
+
encoder = model.load_vision_encoder # loads vision_encoder.onnx
|
|
53
|
+
|
|
54
|
+
# 1. Get original dimensions
|
|
55
|
+
info = IO.popen(["ffprobe", "-v", "error", "-select_streams", "v:0", "-show_entries", "stream=width,height", "-of", "csv=p=0", "photo.png"], &:read)
|
|
56
|
+
original_width, original_height = info.strip.split(",").map(&:to_i)
|
|
57
|
+
|
|
58
|
+
# 2. Compute the size to resize to
|
|
59
|
+
input_width, input_height = encoder.input_size_of(original_width, original_height)
|
|
60
|
+
|
|
61
|
+
# 3. Resize (caller's choice of tool)
|
|
62
|
+
blob = IO.popen(["ffmpeg", "-i", "photo.png", "-vf", "scale=#{input_width}:#{input_height}:flags=bicubic", "-f", "rawvideo", "-pix_fmt", "rgb24", "-v", "error", "-"], "rb", &:read)
|
|
63
|
+
|
|
64
|
+
# 4. Encode
|
|
65
|
+
features = encoder.encode(blob, input_width, input_height)
|
|
66
|
+
|
|
67
|
+
model.chat([
|
|
68
|
+
{ role: "user", content: [
|
|
69
|
+
{ type: "image", features: },
|
|
70
|
+
{ type: "text", text: "What is in this image?" },
|
|
71
|
+
] },
|
|
72
|
+
])
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
### Audio transcription
|
|
76
|
+
|
|
77
|
+
```ruby
|
|
78
|
+
require "kiribi/gemma4/e2b"
|
|
79
|
+
|
|
80
|
+
model = Kiribi::Gemma4::E2B.load
|
|
81
|
+
encoder = model.load_audio_encoder # loads audio_encoder.onnx
|
|
82
|
+
|
|
83
|
+
# 1. Decode to 16kHz mono f32le PCM
|
|
84
|
+
pcm = IO.popen(["ffmpeg", "-i", "audio.mp3", "-f", "f32le", "-acodec", "pcm_f32le", "-ar", "16000", "-ac", "1", "-", err: "/dev/null"], "rb", &:read)
|
|
85
|
+
|
|
86
|
+
# 2. Encode
|
|
87
|
+
features = encoder.encode(pcm)
|
|
88
|
+
|
|
89
|
+
model.chat([
|
|
90
|
+
{ role: "user", content: [
|
|
91
|
+
{ type: "audio", features: },
|
|
92
|
+
{ type: "text", text: "Transcribe the following speech segment in its original language." },
|
|
93
|
+
] },
|
|
94
|
+
])
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
## License
|
|
98
|
+
|
|
99
|
+
This gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
|
|
100
|
+
|
|
101
|
+
The model weights are licensed under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) by Google.
|
data/Rakefile
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "fileutils"
|
|
4
|
+
require "net/http"
|
|
5
|
+
|
|
6
|
+
GEM_NAME = "kiribi-gemma4_e2b"
HF_REPO = "matsudai17/gemma-4-E2B-it-ONNX"
HF_BASE_URL = "https://huggingface.co/#{HF_REPO}/resolve/main/onnx"

# ONNX graphs plus their external-data shards, all fetched from HF_BASE_URL.
MODEL_FILES = %w[
  embed_tokens.onnx
  embed_tokens.onnx_data
  embed_tokens.onnx_data_1
  decoder_model_merged.onnx
  decoder_model_merged.onnx_data
  decoder_model_merged.onnx_data_1
  decoder_model_merged.onnx_data_2
  decoder_model_merged.onnx_data_3
  decoder_model_merged.onnx_data_4
  vision_encoder.onnx
  vision_encoder.onnx_data
  audio_encoder.onnx
  audio_encoder.onnx_data
].freeze
TOKENIZER_FILE = "tokenizer.json"
TOKENIZER_URL = "https://huggingface.co/#{HF_REPO}/resolve/main/#{TOKENIZER_FILE}"

# lib/kiribi-gemma4_e2b/vendor/build inside the gem; the runtime classes resolve
# their model and tokenizer paths against this same directory.
BUILD_DIRPATH = File.expand_path(File.join(__dir__, "../../lib/#{GEM_NAME}/vendor/build"))

MAX_REDIRECTS = 10

# Downloads +url+ to +dest+, following up to MAX_REDIRECTS redirects.
#
# The body is streamed into "<dest>.part" and only renamed to +dest+ once the
# transfer has finished, so an interrupted download never leaves a truncated
# file at +dest+ that a later run would treat as complete. The partial file is
# removed on failure (the shards are several GB each).
#
# @param url [String] source URL
# @param dest [String] destination path
# @raise [RuntimeError] on too many redirects, a redirect without a Location
#   header, or any other non-success HTTP status
def download_file(url, dest)
  part_path = "#{dest}.part"
  redirect_count = 0
  loop do
    raise "Too many redirects" if redirect_count >= MAX_REDIRECTS

    uri = URI.parse(url)
    http = Net::HTTP.new(uri.host, uri.port)
    http.use_ssl = (uri.scheme == "https")
    request = Net::HTTP::Get.new(uri.request_uri)

    http.request(request) do |resp|
      case resp
      when Net::HTTPSuccess
        FileUtils.mkdir_p(File.dirname(dest))
        begin
          File.open(part_path, "wb") do |f|
            resp.read_body { |chunk| f.write(chunk) }
          end
          File.rename(part_path, dest)
        ensure
          # No-op after a successful rename; otherwise discards the partial file.
          FileUtils.rm_f(part_path)
        end
        return
      when Net::HTTPRedirection
        location = resp["Location"]
        raise "Redirect without Location header from #{url}" unless location

        # Location may be relative; resolve it against the URL that was requested.
        url = URI.join(url, location).to_s
        redirect_count += 1
      else
        raise "HTTP request failed for #{url} (status code: #{resp.code})"
      end
    end
  end
end

# Each file is checked individually rather than skipping everything when the
# directory exists, so re-running after an interrupted install fetches only
# what is still missing.
FileUtils.mkdir_p(BUILD_DIRPATH)

# Download model files
MODEL_FILES.each do |filename|
  dest = File.join(BUILD_DIRPATH, filename)
  if File.exist?(dest)
    puts "#{filename} already exists, skipping."
  else
    puts "Downloading #{filename}..."
    download_file("#{HF_BASE_URL}/#{filename}", dest)
    puts " -> #{dest}"
  end
end

# Download tokenizer
tokenizer_dest = File.join(BUILD_DIRPATH, TOKENIZER_FILE)
unless File.exist?(tokenizer_dest)
  puts "Downloading #{TOKENIZER_FILE}..."
  download_file(TOKENIZER_URL, tokenizer_dest)
  puts " -> #{tokenizer_dest}"
end

# Nothing to compile: emit a no-op Makefile so `gem install` succeeds.
File.write("Makefile", "all install clean:\n\t@echo \"Nothing to do for $(TARGET)\"\n")
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "onnxruntime"
|
|
4
|
+
|
|
5
|
+
module Kiribi
  module Gemma4
    module E2B
      AUDIO_ENCODER_FILEPATH = File.expand_path(File.join(__dir__, "../../../kiribi-gemma4_e2b/vendor/build/audio_encoder.onnx"))

      # Turns 16 kHz mono PCM into log-mel frames and runs them through
      # audio_encoder.onnx.
      class AudioEncoder
        SAMPLE_RATE = 16_000
        FRAME_LENGTH = 320
        HOP_LENGTH = 160
        FFT_LENGTH = 512
        NUM_MELS = 128
        MEL_FLOOR = 0.001
        MEL_MAX_FREQ = 8000.0
        # The frame axis is padded (with mask = false) up to a multiple of this.
        FRAME_PAD_MULTIPLE = 128

        def initialize
          @model = OnnxRuntime::Model.new(AUDIO_ENCODER_FILEPATH)
        end

        # Encodes an audio clip.
        #
        # @param pcm_samples [String, Array<Float>] 16 kHz mono float32 PCM, either as
        #   raw little-endian f32 bytes or as an array of samples
        # @return the model's "audio_features" output
        # @raise [ArgumentError] if the clip is too short to produce a single frame
        def encode(pcm_samples)
          pcm = pcm_samples.is_a?(String) ? pcm_samples.unpack("e*") : pcm_samples

          # Half a frame of leading silence; the mask marks those samples as not real audio.
          pad_left = FRAME_LENGTH / 2
          padded = Array.new(pad_left, 0.0) + pcm
          mask_raw = Array.new(pad_left, false) + Array.new(pcm.length, true)

          # Each frame spans FRAME_LENGTH + 1 samples: the first FRAME_LENGTH are
          # windowed, and the validity of the last one becomes the frame's mask bit.
          frame_size = FRAME_LENGTH + 1
          num_frames = (padded.length - frame_size) / HOP_LENGTH + 1
          if num_frames <= 0
            raise ArgumentError,
                  "audio too short: need at least #{frame_size - pad_left} samples, got #{pcm.length}"
          end

          window = hann_window
          filters = mel_filters
          input_features = []
          input_features_mask = []

          num_frames.times do |frame_idx|
            start = frame_idx * HOP_LENGTH
            windowed = Array.new(FRAME_LENGTH) { |k| padded[start + k] * window[k] }
            mag = rfft_magnitude(windowed, FFT_LENGTH)

            mel = Array.new(NUM_MELS) do |m|
              sum = 0.0
              mag.each_with_index { |v, bin| sum += v * filters[bin][m] }
              Math.log(sum + MEL_FLOOR)
            end

            end_idx = start + frame_size - 1
            valid = end_idx < mask_raw.length && mask_raw[end_idx]
            input_features << (valid ? mel : Array.new(NUM_MELS, 0.0))
            input_features_mask << valid
          end

          padded_frames = ((input_features.length + FRAME_PAD_MULTIPLE - 1) / FRAME_PAD_MULTIPLE) * FRAME_PAD_MULTIPLE
          (padded_frames - input_features.length).times do
            input_features << Array.new(NUM_MELS, 0.0)
            input_features_mask << false
          end

          @model.predict({
            "input_features" => [input_features],
            "input_features_mask" => [input_features_mask],
          })["audio_features"]
        end

        private

        # Periodic Hann window of FRAME_LENGTH samples, built once per encoder.
        def hann_window
          @hann_window ||= Array.new(FRAME_LENGTH) { |n| 0.5 - 0.5 * Math.cos(2.0 * Math::PI * n / FRAME_LENGTH) }
        end

        # Mel filterbank for FFT_LENGTH-point spectra, built once per encoder.
        def mel_filters
          @mel_filters ||= build_mel_filterbank(FFT_LENGTH / 2 + 1, NUM_MELS, 0.0, MEL_MAX_FREQ, SAMPLE_RATE)
        end

        # Triangular filters on the HTK mel scale (2595 * log10(1 + f / 700)).
        #
        # @return [Array<Array<Float>>] indexed as filters[fft_bin][mel_index]
        def build_mel_filterbank(num_fft_bins, num_mel_filters, min_freq, max_freq, sample_rate)
          fft_freqs = (0...num_fft_bins).map { |bin| bin.to_f * sample_rate / ((num_fft_bins - 1) * 2) }
          mel_min = 2595.0 * Math.log10(1.0 + min_freq / 700.0)
          mel_max = 2595.0 * Math.log10(1.0 + max_freq / 700.0)
          mel_points = (0..num_mel_filters + 1).map { |k| mel_min + k * (mel_max - mel_min) / (num_mel_filters + 1) }
          hz_points = mel_points.map { |mel| 700.0 * (10.0**(mel / 2595.0) - 1.0) }

          filters = Array.new(num_fft_bins) { Array.new(num_mel_filters, 0.0) }
          num_mel_filters.times do |m|
            lower = hz_points[m]
            center = hz_points[m + 1]
            upper = hz_points[m + 2]
            fft_freqs.each_with_index do |f, bin|
              if f >= lower && f <= center && center > lower
                filters[bin][m] = (f - lower) / (center - lower)
              elsif f > center && f <= upper && upper > center
                filters[bin][m] = (upper - f) / (upper - center)
              end
            end
          end
          filters
        end

        # Magnitudes of the first n / 2 + 1 DFT bins of real_signal, zero-padded
        # (or truncated) to n points. n must be a power of two for #fft.
        def rfft_magnitude(real_signal, n)
          padded = Array.new(n, 0.0)
          real_signal.each_with_index { |v, k| padded[k] = v if k < n }
          re, im = fft(padded, Array.new(n, 0.0))
          Array.new(n / 2 + 1) { |k| Math.sqrt(re[k]**2 + im[k]**2) }
        end

        # Recursive radix-2 Cooley-Tukey FFT on separate real/imaginary arrays.
        # @return [Array(Array<Float>, Array<Float>)] [real, imaginary]
        def fft(x_real, x_imag)
          n = x_real.length
          return [x_real.dup, x_imag.dup] if n <= 1

          half = n / 2
          even_r, even_i = fft(Array.new(half) { |k| x_real[k * 2] }, Array.new(half) { |k| x_imag[k * 2] })
          odd_r, odd_i = fft(Array.new(half) { |k| x_real[k * 2 + 1] }, Array.new(half) { |k| x_imag[k * 2 + 1] })

          result_r = Array.new(n)
          result_i = Array.new(n)
          half.times do |k|
            angle = -2.0 * Math::PI * k / n
            tr = Math.cos(angle) * odd_r[k] - Math.sin(angle) * odd_i[k]
            ti = Math.sin(angle) * odd_r[k] + Math.cos(angle) * odd_i[k]
            result_r[k] = even_r[k] + tr
            result_i[k] = even_i[k] + ti
            result_r[k + half] = even_r[k] - tr
            result_i[k + half] = even_i[k] - ti
          end
          [result_r, result_i]
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "onnxruntime"
|
|
4
|
+
require "tokenizers"
|
|
5
|
+
|
|
6
|
+
module Kiribi
  module Gemma4
    module E2B
      TOKENIZER_FILEPATH = File.expand_path(File.join(__dir__, "../../../kiribi-gemma4_e2b/vendor/build/tokenizer.json"))
      EMBED_MODEL_FILEPATH = File.expand_path(File.join(__dir__, "../../../kiribi-gemma4_e2b/vendor/build/embed_tokens.onnx"))
      DECODER_MODEL_FILEPATH = File.expand_path(File.join(__dir__, "../../../kiribi-gemma4_e2b/vendor/build/decoder_model_merged.onnx"))

      # Gemma 4 E2B chat model run on two ONNX graphs: embed_tokens.onnx (token ids
      # -> embeddings) and decoder_model_merged.onnx (embeddings + KV cache -> logits).
      # Decoding is greedy (argmax).
      class Model
        # Generation stops as soon as one of these ids is the argmax.
        EOS_TOKEN_IDS = [1, 106, 50].freeze
        # Placeholder token ids whose embeddings are overwritten with image / audio features.
        IMAGE_TOKEN_ID = 258_880
        AUDIO_TOKEN_ID = 258_881

        attr_reader :tokenizer

        def initialize
          @tokenizer = Tokenizers.from_file(TOKENIZER_FILEPATH)
          @embed_model = OnnxRuntime::Model.new(EMBED_MODEL_FILEPATH)
          @decoder_model = OnnxRuntime::Model.new(DECODER_MODEL_FILEPATH)

          # Per-layer KV head dims come from the decoder's input signature. Reuse the
          # session inside @decoder_model: opening a second InferenceSession on the
          # same file would load the multi-GB decoder weights a second time.
          @head_dims = @decoder_model.inputs
            .select { |input| input[:name].match?(/\Apast_key_values\.\d+\.key\z/) }
            .sort_by { |input| input[:name][/\d+/].to_i }
            .map { |input| input[:shape].last }
          @num_layers = @head_dims.length

          # Scalar int64 tensor holding 1 (only the last position's logits are needed).
          @num_logits_to_keep_1 = OnnxRuntime::OrtValue.from_shape_and_type([], :int64)
          @num_logits_to_keep_1.data_ptr.write_int64(1)
        end

        # -------------------------------------------------
        # Lazy loading (the method names make the init cost explicit)
        # -------------------------------------------------

        # Builds the VisionEncoder on first call and memoizes it.
        def load_vision_encoder
          @vision_encoder ||= VisionEncoder.new
        end

        # Builds the AudioEncoder on first call and memoizes it.
        def load_audio_encoder
          @audio_encoder ||= AudioEncoder.new
        end

        # -------------------------------------------------
        # Low-level API: ONNX calls only
        # -------------------------------------------------

        # Runs embed_tokens.onnx on one sequence of token ids and returns its output
        # hash (#chat reads "inputs_embeds" and "per_layer_inputs").
        def embed(input_ids)
          @embed_model.predict({"input_ids" => [input_ids]})
        end

        # Runs one decoder pass.
        #
        # @param past_key_values [Hash, nil] cache from a previous call; nil starts empty
        # @return [Hash] {logits:, past_key_values:}, where past_key_values holds this
        #   pass's present.* outputs renamed to the past_key_values.* input names
        def forward(inputs_embeds:, per_layer_inputs:, attention_mask:, position_ids:, past_key_values: nil)
          input = {
            "inputs_embeds" => inputs_embeds,
            "attention_mask" => attention_mask,
            "position_ids" => position_ids,
            "num_logits_to_keep" => @num_logits_to_keep_1,
            "per_layer_inputs" => per_layer_inputs,
          }
          input.merge!(past_key_values || init_kv_cache)
          out = @decoder_model.predict(input)

          new_kv = {}
          @num_layers.times do |i|
            new_kv["past_key_values.#{i}.key"] = out["present.#{i}.key"]
            new_kv["past_key_values.#{i}.value"] = out["present.#{i}.value"]
          end

          {logits: out["logits"], past_key_values: new_kv}
        end

        # Empty KV cache: zero-length [1, 1, 0, head_dim] float tensors for every layer.
        def init_kv_cache
          kv = {}
          @num_layers.times do |i|
            kv["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type([1, 1, 0, @head_dims[i]], :float)
            kv["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type([1, 1, 0, @head_dims[i]], :float)
          end
          kv
        end

        # -------------------------------------------------
        # High-level API
        # -------------------------------------------------

        # Single-turn shortcut for #chat with one user message.
        def generate(prompt, max_new_tokens: 256)
          chat([{role: "user", content: prompt}], max_new_tokens:)
        end

        # Generates the model's reply to a conversation.
        #
        # @param messages [Array<Hash>] each {role:, content:}; content is a String or an
        #   Array of parts: {type: "text", text:}, {type: "image", features:} or
        #   {type: "audio", features:} (features from the matching encoder's #encode)
        # @param max_new_tokens [Integer] upper bound on generated tokens
        # @return [String] decoded reply (the stop token is not included)
        # @raise [ArgumentError] on unsupported message content or part types
        def chat(messages, max_new_tokens: 256)
          prompt, encoded_media = build_prompt(messages)
          # The prompt already begins with an explicit "<bos>", so keep the tokenizer's
          # post-processor from adding special tokens (e.g. a second BOS) on top of it.
          input_ids = tokenizer.encode(prompt, add_special_tokens: false).ids
          media_embeds = place_media(input_ids, encoded_media)

          past_kv = nil
          generated = []

          max_new_tokens.times do |step|
            # Step 0 feeds the whole prompt; later steps feed only the newest token and
            # rely on the KV cache for everything before it.
            cur_ids = step.zero? ? input_ids : [generated.last]
            seq_len = cur_ids.length
            total_len = input_ids.length + generated.length

            embed_out = embed(cur_ids)
            inputs_embeds = embed_out["inputs_embeds"]
            per_layer_inputs = embed_out["per_layer_inputs"]
            media_embeds.each { |pos, feat| inputs_embeds[0][pos] = feat } if step.zero?

            result = forward(
              inputs_embeds:,
              per_layer_inputs:,
              attention_mask: [Array.new(total_len, 1)],
              position_ids: [(total_len - seq_len...total_len).to_a],
              past_key_values: past_kv,
            )
            past_kv = result[:past_key_values]

            scores = result[:logits][0][-1]
            next_token = scores.each_with_index.max_by { |score, _| score }[1]
            break if EOS_TOKEN_IDS.include?(next_token)

            generated << next_token
          end

          tokenizer.decode(generated)
        end

        private

        # Renders messages in the Gemma turn format, expanding each media part into one
        # placeholder token per feature vector.
        #
        # @return [Array(String, Array<Hash>)] prompt text and, in prompt order,
        #   [{token_id:, features:}, ...] for every media part
        def build_prompt(messages)
          parts = ["<bos>"]
          media = []

          messages.each do |msg|
            parts << "<|turn>#{msg[:role]}\n"
            content = msg[:content]
            case content
            when String
              parts << content
            when Array
              content.each do |part|
                case part[:type]
                when "text"
                  parts << part[:text]
                when "image"
                  features = part[:features]
                  parts << "<|image>" + "<|image|>" * features.length + "<image|>\n"
                  media << {token_id: IMAGE_TOKEN_ID, features:}
                when "audio"
                  features = part[:features]
                  parts << "<|audio>" + "<|audio|>" * features.length + "<audio|>\n"
                  media << {token_id: AUDIO_TOKEN_ID, features:}
                else
                  raise ArgumentError, "unsupported content part type: #{part[:type].inspect}"
                end
              end
            else
              raise ArgumentError, "unsupported message content: #{content.class}"
            end
            parts << "<turn|>\n"
          end

          parts << "<|turn>model\n"
          [parts.join, media]
        end

        # Pairs every media feature vector with the next unused placeholder token of
        # the same type, in prompt order. Features beyond the available placeholders
        # are dropped.
        #
        # @return [Array<Array>] [[token_position, feature_vector], ...]
        def place_media(input_ids, encoded_media)
          free_slots = {IMAGE_TOKEN_ID => [], AUDIO_TOKEN_ID => []}
          input_ids.each_with_index { |id, pos| free_slots[id]&.push(pos) }

          encoded_media.flat_map do |media|
            slots = free_slots.fetch(media[:token_id])
            media[:features].take(slots.length).map { |feat| [slots.shift, feat] }
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "onnxruntime"
|
|
4
|
+
|
|
5
|
+
module Kiribi
  module Gemma4
    module E2B
      VISION_ENCODER_FILEPATH = File.expand_path(File.join(__dir__, "../../../kiribi-gemma4_e2b/vendor/build/vision_encoder.onnx"))

      # Cuts a pre-resized RGB24 image into PATCH_SIZE x PATCH_SIZE patches and runs
      # them through vision_encoder.onnx.
      class VisionEncoder
        PATCH_SIZE = 16
        RESCALE_FACTOR = 1.0 / 255
        MAX_SOFT_TOKENS = 280
        POOLING_KERNEL = 3
        MAX_PATCHES = MAX_SOFT_TOKENS * POOLING_KERNEL**2
        SIDE_MULT = POOLING_KERNEL * PATCH_SIZE

        def initialize
          @model = OnnxRuntime::Model.new(VISION_ENCODER_FILEPATH)
        end

        # Computes the size an image must be resized to before #encode.
        #
        # Scales the image, keeping its aspect ratio, so its area is at most
        # MAX_PATCHES patches, then floors each side to a multiple of SIDE_MULT.
        # If one side collapses to zero, that side becomes SIDE_MULT and the other
        # is capped at MAX_SOFT_TOKENS * SIDE_MULT.
        #
        # @param original_width [Integer] source width in pixels
        # @param original_height [Integer] source height in pixels
        # @return [Array(Integer, Integer)] [width, height]
        # @raise [ArgumentError] if either dimension is not positive
        # @raise [RuntimeError] if both sides round down to zero
        def input_size_of(original_width, original_height)
          unless original_width.positive? && original_height.positive?
            raise ArgumentError, "image dimensions must be positive, got #{original_width}x#{original_height}"
          end

          target_px = MAX_PATCHES * PATCH_SIZE**2
          factor = Math.sqrt(target_px.to_f / (original_height * original_width))

          width = (factor * original_width / SIDE_MULT).floor * SIDE_MULT
          height = (factor * original_height / SIDE_MULT).floor * SIDE_MULT

          if width == 0 && height == 0
            raise "Image too small to resize"
          elsif height == 0
            height = SIDE_MULT
            width = [(original_width / original_height) * SIDE_MULT, MAX_SOFT_TOKENS * SIDE_MULT].min
          elsif width == 0
            width = SIDE_MULT
            height = [(original_height / original_width) * SIDE_MULT, MAX_SOFT_TOKENS * SIDE_MULT].min
          end

          [width, height]
        end

        # Encodes an image that has already been resized to width x height
        # (see #input_size_of).
        #
        # @param blob_rgb [String, Array<Integer>] RGB24 pixels, row-major, 3 bytes per pixel
        # @param width [Integer] image width in pixels
        # @param height [Integer] image height in pixels
        # @return the model's "image_features" output
        # @raise [ArgumentError] if blob_rgb holds fewer than width * height * 3 bytes
        def encode(blob_rgb, width, height)
          blob = blob_rgb.is_a?(String) ? blob_rgb.bytes : blob_rgb
          expected_bytes = width * height * 3
          if blob.length < expected_bytes
            raise ArgumentError,
                  "blob_rgb has #{blob.length} bytes, expected #{expected_bytes} for #{width}x#{height} RGB24"
          end

          patches_w = width / PATCH_SIZE
          patches_h = height / PATCH_SIZE
          patch_row_bytes = PATCH_SIZE * 3

          pixel_values = []
          pixel_position_ids = []

          # Patches are emitted column by column (col outer, row inner) with position
          # [col, row]; inside a patch, pixels are row-major with interleaved RGB.
          patches_w.times do |col|
            patches_h.times do |row|
              patch = []
              PATCH_SIZE.times do |dy|
                # One patch row is PATCH_SIZE consecutive pixels of a source scanline.
                offset = ((row * PATCH_SIZE + dy) * width + col * PATCH_SIZE) * 3
                blob[offset, patch_row_bytes].each { |byte| patch << byte * RESCALE_FACTOR }
              end
              pixel_values << patch
              pixel_position_ids << [col, row]
            end
          end

          # Pad to MAX_PATCHES with all-zero patches at position [-1, -1].
          (MAX_PATCHES - pixel_values.length).times do
            pixel_values << Array.new(PATCH_SIZE**2 * 3, 0.0)
            pixel_position_ids << [-1, -1]
          end

          @model.predict({
            "pixel_values" => [pixel_values],
            "pixel_position_ids" => [pixel_position_ids],
          })["image_features"]
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "e2b/version"
|
|
4
|
+
require_relative "e2b/vision_encoder"
|
|
5
|
+
require_relative "e2b/audio_encoder"
|
|
6
|
+
require_relative "e2b/model"
|
|
7
|
+
require "kiribi"
|
|
8
|
+
|
|
9
|
+
module Kiribi
  module Gemma4
    extend Kiribi::Loader

    # Entry point for the Gemma 4 E2B model. Kiribi::Loader presumably builds the
    # public `load` (used in the README) on top of `instantiate` — TODO confirm.
    module E2B
      extend Kiribi::Loader

      class << self
        # Builds a fresh Model, which opens the tokenizer and ONNX sessions.
        def instantiate = Model.new
      end
    end
  end
end

Kiribi.register(Kiribi::Gemma4::E2B, order: 100_300_100)
|
metadata
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
|
2
|
+
name: kiribi-gemma4_e2b
|
|
3
|
+
version: !ruby/object:Gem::Version
|
|
4
|
+
version: 0.0.1
|
|
5
|
+
platform: ruby
|
|
6
|
+
authors:
|
|
7
|
+
- matsudai
|
|
8
|
+
bindir: bin
|
|
9
|
+
cert_chain: []
|
|
10
|
+
date: 1980-01-02 00:00:00.000000000 Z
|
|
11
|
+
dependencies:
|
|
12
|
+
- !ruby/object:Gem::Dependency
|
|
13
|
+
name: kiribi
|
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
|
15
|
+
requirements:
|
|
16
|
+
- - ">="
|
|
17
|
+
- !ruby/object:Gem::Version
|
|
18
|
+
version: 0.0.1
|
|
19
|
+
type: :runtime
|
|
20
|
+
prerelease: false
|
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
22
|
+
requirements:
|
|
23
|
+
- - ">="
|
|
24
|
+
- !ruby/object:Gem::Version
|
|
25
|
+
version: 0.0.1
|
|
26
|
+
- !ruby/object:Gem::Dependency
|
|
27
|
+
name: onnxruntime
|
|
28
|
+
requirement: !ruby/object:Gem::Requirement
|
|
29
|
+
requirements:
|
|
30
|
+
- - ">="
|
|
31
|
+
- !ruby/object:Gem::Version
|
|
32
|
+
version: 0.10.0
|
|
33
|
+
type: :runtime
|
|
34
|
+
prerelease: false
|
|
35
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
36
|
+
requirements:
|
|
37
|
+
- - ">="
|
|
38
|
+
- !ruby/object:Gem::Version
|
|
39
|
+
version: 0.10.0
|
|
40
|
+
- !ruby/object:Gem::Dependency
|
|
41
|
+
name: tokenizers
|
|
42
|
+
requirement: !ruby/object:Gem::Requirement
|
|
43
|
+
requirements:
|
|
44
|
+
- - ">="
|
|
45
|
+
- !ruby/object:Gem::Version
|
|
46
|
+
version: 0.6.0
|
|
47
|
+
type: :runtime
|
|
48
|
+
prerelease: false
|
|
49
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
50
|
+
requirements:
|
|
51
|
+
- - ">="
|
|
52
|
+
- !ruby/object:Gem::Version
|
|
53
|
+
version: 0.6.0
|
|
54
|
+
executables: []
|
|
55
|
+
extensions:
|
|
56
|
+
- ext/kiribi-gemma4_e2b/extconf.rb
|
|
57
|
+
extra_rdoc_files: []
|
|
58
|
+
files:
|
|
59
|
+
- README.md
|
|
60
|
+
- Rakefile
|
|
61
|
+
- ext/kiribi-gemma4_e2b/extconf.rb
|
|
62
|
+
- lib/kiribi/gemma4/e2b.rb
|
|
63
|
+
- lib/kiribi/gemma4/e2b/audio_encoder.rb
|
|
64
|
+
- lib/kiribi/gemma4/e2b/model.rb
|
|
65
|
+
- lib/kiribi/gemma4/e2b/version.rb
|
|
66
|
+
- lib/kiribi/gemma4/e2b/vision_encoder.rb
|
|
67
|
+
homepage: https://github.com/matsudai/kiribi
|
|
68
|
+
licenses:
|
|
69
|
+
- MIT
|
|
70
|
+
metadata: {}
|
|
71
|
+
rdoc_options: []
|
|
72
|
+
require_paths:
|
|
73
|
+
- lib
|
|
74
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
|
75
|
+
requirements:
|
|
76
|
+
- - ">="
|
|
77
|
+
- !ruby/object:Gem::Version
|
|
78
|
+
version: 3.4.0
|
|
79
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
80
|
+
requirements:
|
|
81
|
+
- - ">="
|
|
82
|
+
- !ruby/object:Gem::Version
|
|
83
|
+
version: '0'
|
|
84
|
+
requirements: []
|
|
85
|
+
rubygems_version: 4.0.6
|
|
86
|
+
specification_version: 4
|
|
87
|
+
summary: Easy to use some onnx models.
|
|
88
|
+
test_files: []
|