gte 0.0.1-arm64-darwin
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Gemfile +17 -0
- data/LICENSE +21 -0
- data/README.md +49 -0
- data/Rakefile +76 -0
- data/VERSION +1 -0
- data/ext/gte/Cargo.toml +37 -0
- data/ext/gte/benches/hot_path.rs +53 -0
- data/ext/gte/build.rs +25 -0
- data/ext/gte/extconf.rb +6 -0
- data/ext/gte/src/embedder.rs +342 -0
- data/ext/gte/src/error.rs +48 -0
- data/ext/gte/src/lib.rs +31 -0
- data/ext/gte/src/model_config.rs +17 -0
- data/ext/gte/src/postprocess.rs +113 -0
- data/ext/gte/src/ruby_embedder.rs +222 -0
- data/ext/gte/src/session.rs +123 -0
- data/ext/gte/src/tokenizer.rs +130 -0
- data/ext/gte/tests/embedder_unit_test.rs +39 -0
- data/ext/gte/tests/inference_integration_test.rs +62 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +44 -0
- data/lib/gte/3.0/gte.bundle +0 -0
- data/lib/gte/3.1/gte.bundle +0 -0
- data/lib/gte/3.2/gte.bundle +0 -0
- data/lib/gte/3.3/gte.bundle +0 -0
- data/lib/gte/3.4/gte.bundle +0 -0
- data/lib/gte/4.0/gte.bundle +0 -0
- data/lib/gte.rb +32 -0
- metadata +144 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 7164d25e0854514fe3e756bb85d6080522f76fcd256c37040b5e72ba77ae636a
|
|
4
|
+
data.tar.gz: eaf3946a41c519ac793935d8abeed283cdd933217a8843d69b1b7a38a4407c03
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: f10612ec4d0eaea1a14f52a6af17dc05fee4a7e78699dd9a5ef2a0ce67d6d4e20ea983aad40e11049e411920e57709dfb101542d87c97d84895afc309434c04f
|
|
7
|
+
data.tar.gz: 9b5408bbe8bf803a04c567709d63e15b0b0efac73a39183a6a43fd0e05ad48ac7242a4a4d8e9e88896c1b30ec04355ad552b18662d0a6273f46c38a3fc328075
|
data/Gemfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# frozen_string_literal: true

source 'https://rubygems.org'

# Pull in the dependencies declared by the gemspec next to this Gemfile.
gemspec

# Build, test, and lint tooling.
gem 'rake'
gem 'rake-compiler'
gem 'rb_sys'
gem 'rspec'
gem 'rspec-benchmark'
gem 'rubocop', require: false

# Extra gems that belong to the :bench group.
group :bench do
  gem 'onnxruntime'
  gem 'tokenizers'
end
|
data/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 elcuervo
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# gte
|
|
2
|
+

|
|
3
|
+
|
|
4
|
+
`gte` is a Ruby gem with a Rust extension for fast text embeddings with ONNX Runtime.
|
|
5
|
+
Inspired by https://github.com/fbilhaut/gte-rs
|
|
6
|
+
|
|
7
|
+
## Quick Start
|
|
8
|
+
|
|
9
|
+
```ruby
|
|
10
|
+
require "gte"
|
|
11
|
+
|
|
12
|
+
model = GTE.new(ENV.fetch("GTE_MODEL_DIR"))
|
|
13
|
+
vector = model["query: hello world"]
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Model Directory
|
|
17
|
+
|
|
18
|
+
A model directory must include `tokenizer.json` and one ONNX model, resolved in this order:
|
|
19
|
+
|
|
20
|
+
1. `onnx/text_model.onnx`
|
|
21
|
+
2. `text_model.onnx`
|
|
22
|
+
3. `onnx/model.onnx`
|
|
23
|
+
4. `model.onnx`
|
|
24
|
+
|
|
25
|
+
## Development
|
|
26
|
+
|
|
27
|
+
Run commands inside `nix develop`.
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
bundle exec rake compile
|
|
31
|
+
cargo test --manifest-path ext/gte/Cargo.toml --no-default-features
|
|
32
|
+
bundle exec rspec
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
## Benchmark
|
|
36
|
+
|
|
37
|
+
The repo includes the following benchmark tasks:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
bundle exec rake bench:pure_compare
|
|
41
|
+
bundle exec rake bench:puma_compare
|
|
42
|
+
bundle exec rake bench:matrix_sweep
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
For release tracking and regression detection, record a run entry in `RUNS.md`:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
bundle exec rake bench:record_run
|
|
49
|
+
```
|
data/Rakefile
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'bundler/gem_tasks'
|
|
4
|
+
require 'rake/extensiontask'
|
|
5
|
+
# RSpec is optional: the :spec task is only defined when the gem can be loaded.
begin
  require 'rspec/core/rake_task'
  RSpec::Core::RakeTask.new(:spec)
rescue LoadError
  # rspec not available in cross-compile environment
end

spec = Gem::Specification.load('gte.gemspec')

# Native extension task: artifacts go under lib/gte, with cross-compilation
# enabled for the two listed platforms.
Rake::ExtensionTask.new('gte', spec) do |ext|
  ext.lib_dir = 'lib/gte'
  ext.cross_compile = true
  ext.cross_platform = %w[x86_64-linux arm64-darwin]
end

# `rake` with no arguments compiles the extension and then runs the specs.
task default: %i[compile spec]
|
|
21
|
+
|
|
22
|
+
# Runs +command+ through `nix develop -c`, delegating to Rake's +sh+.
def run_in_nix(*command)
  sh(*%w[nix develop -c], *command)
end
|
|
25
|
+
|
|
26
|
+
namespace :bench do
  # Shared pieces of the command lines below, kept in one place so the tasks
  # cannot drift apart.
  latest_result = 'bench/results/puma_compare_latest.json'
  sampling_flags = %w[--iterations 80 --runs 3]
  puma_compare_cmd = ['bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
                      '--output', latest_result, *sampling_flags]
  ledger_cmd = %w[bundle exec ruby bench/runs_ledger.rb]

  desc 'Run pure-Ruby (onnxruntime gem) vs GTE benchmark comparison inside nix develop'
  task :pure_compare do
    run_in_nix(*%w[bundle exec ruby bench/pure_ruby_compare.rb])
  end

  desc 'Run Puma-like concurrent single-request benchmark (GTE vs pure Ruby)'
  task :puma_compare do
    run_in_nix(*puma_compare_cmd)
  end

  desc 'Sweep execution-provider and thread settings for Puma-like benchmark'
  task :matrix_sweep do
    run_in_nix(*%w[bundle exec ruby bench/puma_matrix_sweep.rb], *sampling_flags)
  end

  desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal/regression checks'
  task :record_run do
    # Benchmark first, then record the result, then validate it.
    run_in_nix(*puma_compare_cmd)
    run_in_nix(*ledger_cmd, 'append', '--result', latest_result)
    run_in_nix(*ledger_cmd, 'check', '--result', latest_result)
  end

  desc 'Validate current Puma benchmark output against 2x goal and regression policy'
  task :check_goal do
    run_in_nix(*ledger_cmd, 'check', '--result', latest_result)
  end
end
|
data/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.1
|
data/ext/gte/Cargo.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[package]
|
|
2
|
+
name = "gte"
|
|
3
|
+
version = "0.0.1"
|
|
4
|
+
edition = "2021"
|
|
5
|
+
authors = ["elcuervo <elcuervo@elcuervo.net>"]
|
|
6
|
+
license = "MIT"
|
|
7
|
+
publish = false
|
|
8
|
+
build = "build.rs"
|
|
9
|
+
|
|
10
|
+
[lib]
|
|
11
|
+
# cdylib: Ruby FFI extension; rlib: enables integration tests in tests/ to link as external crate
|
|
12
|
+
crate-type = ["cdylib", "rlib"]
|
|
13
|
+
|
|
14
|
+
[features]
|
|
15
|
+
# ruby-ffi: gate magnus + rb-sys (Ruby C symbols) so Rust integration tests can link without Ruby.
|
|
16
|
+
# This feature is enabled by default for the cdylib build (rake compile / extconf.rb).
|
|
17
|
+
# When running `cargo test`, this feature must be excluded: `cargo test --no-default-features`.
|
|
18
|
+
default = ["ruby-ffi"]
|
|
19
|
+
ruby-ffi = ["dep:magnus", "dep:rb-sys"]
|
|
20
|
+
|
|
21
|
+
[dependencies]
|
|
22
|
+
rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
|
|
23
|
+
magnus = { version = "0.8", optional = true }
|
|
24
|
+
ort = { version = "=2.0.0-rc.9", features = ["ndarray"] }
|
|
25
|
+
ort-sys = "=2.0.0-rc.9"
|
|
26
|
+
tokenizers = "0.21.0"
|
|
27
|
+
ndarray = "0.16.0"
|
|
28
|
+
half = "2"
|
|
29
|
+
serde = { version = "1", features = ["derive"] }
|
|
30
|
+
serde_json = "1"
|
|
31
|
+
|
|
32
|
+
[dev-dependencies]
|
|
33
|
+
criterion = "0.5"
|
|
34
|
+
|
|
35
|
+
[[bench]]
|
|
36
|
+
name = "hot_path"
|
|
37
|
+
harness = false
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
|
2
|
+
use gte::postprocess::{mean_pool, normalize_l2};
|
|
3
|
+
use ndarray::{Array2, Array3};
|
|
4
|
+
|
|
5
|
+
fn build_hidden_states(batch: usize, seq: usize, dim: usize) -> Array3<f32> {
|
|
6
|
+
Array3::from_shape_fn((batch, seq, dim), |(b, s, d)| {
|
|
7
|
+
(((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0
|
|
8
|
+
})
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
fn build_attention_mask(batch: usize, seq: usize) -> Array2<i64> {
|
|
12
|
+
Array2::from_shape_fn((batch, seq), |(_, s)| if s % 11 == 10 { 0 } else { 1 })
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
fn bench_mean_pool(c: &mut Criterion) {
|
|
16
|
+
let mut group = c.benchmark_group("mean_pool");
|
|
17
|
+
for (batch, seq, dim) in [(1, 32, 384), (8, 64, 384), (32, 64, 768)] {
|
|
18
|
+
let hidden_states = build_hidden_states(batch, seq, dim);
|
|
19
|
+
let attention_mask = build_attention_mask(batch, seq);
|
|
20
|
+
group.bench_with_input(
|
|
21
|
+
BenchmarkId::from_parameter(format!("{batch}x{seq}x{dim}")),
|
|
22
|
+
&(batch, seq, dim),
|
|
23
|
+
|b, _| {
|
|
24
|
+
b.iter(|| {
|
|
25
|
+
mean_pool(
|
|
26
|
+
black_box(hidden_states.view()),
|
|
27
|
+
black_box(attention_mask.view()),
|
|
28
|
+
)
|
|
29
|
+
.unwrap()
|
|
30
|
+
})
|
|
31
|
+
},
|
|
32
|
+
);
|
|
33
|
+
}
|
|
34
|
+
group.finish();
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
fn bench_normalize_l2(c: &mut Criterion) {
|
|
38
|
+
let mut group = c.benchmark_group("normalize_l2");
|
|
39
|
+
for (rows, dim) in [(1, 384), (8, 384), (32, 768), (128, 768)] {
|
|
40
|
+
let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)| {
|
|
41
|
+
(((row * 19 + col * 7) % 113) as f32) / 113.0
|
|
42
|
+
});
|
|
43
|
+
group.bench_with_input(
|
|
44
|
+
BenchmarkId::from_parameter(format!("{rows}x{dim}")),
|
|
45
|
+
&(rows, dim),
|
|
46
|
+
|b, _| b.iter(|| normalize_l2(black_box(embeddings.clone()))),
|
|
47
|
+
);
|
|
48
|
+
}
|
|
49
|
+
group.finish();
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
// Register both benchmark functions under the `benches` group and generate the
// benchmark entry point from it.
criterion_group!(benches, bench_mean_pool, bench_normalize_l2);
criterion_main!(benches);
|
data/ext/gte/build.rs
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
fn main() {
|
|
2
|
+
let version = std::fs::read_to_string("../../VERSION")
|
|
3
|
+
.expect("VERSION file not found")
|
|
4
|
+
.trim()
|
|
5
|
+
.to_string();
|
|
6
|
+
|
|
7
|
+
let cargo_version = env!("CARGO_PKG_VERSION");
|
|
8
|
+
|
|
9
|
+
assert_eq!(
|
|
10
|
+
version, cargo_version,
|
|
11
|
+
"VERSION file ({}) doesn't match Cargo.toml ({}). Update Cargo.toml to match.",
|
|
12
|
+
version, cargo_version
|
|
13
|
+
);
|
|
14
|
+
|
|
15
|
+
println!("cargo:rerun-if-changed=../../VERSION");
|
|
16
|
+
|
|
17
|
+
// Ensure the ORT shared library can be found at runtime via @rpath on macOS.
|
|
18
|
+
// ORT_LIB_LOCATION is set by the Nix dev shell when ORT_STRATEGY=system.
|
|
19
|
+
if let Ok(ort_lib) = std::env::var("ORT_LIB_LOCATION") {
|
|
20
|
+
let lib_dir = std::path::Path::new(&ort_lib).join("lib");
|
|
21
|
+
if lib_dir.exists() {
|
|
22
|
+
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_dir.display());
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
}
|
data/ext/gte/extconf.rb
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
3
|
+
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
4
|
+
use crate::session::{build_session, run_session};
|
|
5
|
+
use crate::tokenizer::{Tokenized, Tokenizer};
|
|
6
|
+
use ndarray::Array2;
|
|
7
|
+
use ort::session::Session;
|
|
8
|
+
use std::path::{Path, PathBuf};
|
|
9
|
+
|
|
10
|
+
/// Coarse model family, inferred from a graph's input/output signature by
/// `infer_model_family` and used by `tune_num_threads` to pick a thread count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFamily {
    /// `last_hidden_state` output with both `attention_mask` and `token_type_ids` inputs.
    E5Like,
    /// `last_hidden_state` output with `attention_mask` but no `token_type_ids` input.
    SiglipLike,
    /// `text_embeds` output and no `attention_mask` input.
    ClipLike,
    /// Any signature not matched by the variants above.
    Other,
}
|
|
17
|
+
|
|
18
|
+
/// Text embedder that bundles a tokenizer, an ONNX Runtime session, and the
/// model configuration used to run them.
pub struct Embedder {
    // Turns input strings into the `Tokenized` batch passed to the session.
    tokenizer: Tokenizer,
    // ONNX Runtime session built for the model file.
    session: Session,
    // Settings shared by tokenizer construction, session construction, and inference.
    config: ModelConfig,
}
|
|
23
|
+
|
|
24
|
+
impl Embedder {
|
|
25
|
+
pub fn new<P1, P2>(tokenizer_path: P1, model_path: P2, config: ModelConfig) -> Result<Self>
|
|
26
|
+
where
|
|
27
|
+
P1: AsRef<Path>,
|
|
28
|
+
P2: AsRef<Path>,
|
|
29
|
+
{
|
|
30
|
+
let tokenizer = Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids)?;
|
|
31
|
+
let session = build_session(model_path, &config)?;
|
|
32
|
+
Ok(Self {
|
|
33
|
+
tokenizer,
|
|
34
|
+
session,
|
|
35
|
+
config,
|
|
36
|
+
})
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
pub fn from_dir<P: AsRef<Path>>(
|
|
40
|
+
dir: P,
|
|
41
|
+
num_threads: usize,
|
|
42
|
+
optimization_level: u8,
|
|
43
|
+
) -> Result<Self> {
|
|
44
|
+
let dir = dir.as_ref();
|
|
45
|
+
let tokenizer_path = dir.join("tokenizer.json");
|
|
46
|
+
let model_path = resolve_model_path(dir)?;
|
|
47
|
+
|
|
48
|
+
if !tokenizer_path.exists() {
|
|
49
|
+
return Err(GteError::Tokenizer(format!(
|
|
50
|
+
"tokenizer.json not found in {}",
|
|
51
|
+
dir.display()
|
|
52
|
+
)));
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
let max_length = read_max_length(dir);
|
|
56
|
+
let temp_config = ModelConfig {
|
|
57
|
+
max_length,
|
|
58
|
+
output_tensor: String::new(),
|
|
59
|
+
mode: ExtractorMode::Raw,
|
|
60
|
+
with_type_ids: false,
|
|
61
|
+
with_attention_mask: true,
|
|
62
|
+
num_threads,
|
|
63
|
+
optimization_level,
|
|
64
|
+
};
|
|
65
|
+
let session = build_session(&model_path, &temp_config)?;
|
|
66
|
+
|
|
67
|
+
validate_supported_inputs(&session)?;
|
|
68
|
+
let with_type_ids = session.inputs.iter().any(|i| i.name == "token_type_ids");
|
|
69
|
+
let with_attention_mask = session.inputs.iter().any(|i| i.name == "attention_mask");
|
|
70
|
+
let output_tensor = select_output_tensor(&session)?;
|
|
71
|
+
let output_base = output_basename(output_tensor.as_str()).to_string();
|
|
72
|
+
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
73
|
+
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
74
|
+
return Err(GteError::Inference(
|
|
75
|
+
"cannot use mean pooling without attention_mask input".to_string(),
|
|
76
|
+
));
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
let tuned_num_threads = tune_num_threads(
|
|
80
|
+
num_threads,
|
|
81
|
+
with_attention_mask,
|
|
82
|
+
with_type_ids,
|
|
83
|
+
output_base.as_str(),
|
|
84
|
+
);
|
|
85
|
+
|
|
86
|
+
let config = ModelConfig {
|
|
87
|
+
max_length,
|
|
88
|
+
output_tensor,
|
|
89
|
+
mode,
|
|
90
|
+
with_type_ids,
|
|
91
|
+
with_attention_mask,
|
|
92
|
+
num_threads: tuned_num_threads,
|
|
93
|
+
optimization_level,
|
|
94
|
+
};
|
|
95
|
+
|
|
96
|
+
let session = if tuned_num_threads != num_threads {
|
|
97
|
+
build_session(&model_path, &config)?
|
|
98
|
+
} else {
|
|
99
|
+
session
|
|
100
|
+
};
|
|
101
|
+
|
|
102
|
+
let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
|
|
103
|
+
|
|
104
|
+
Ok(Self {
|
|
105
|
+
tokenizer,
|
|
106
|
+
session,
|
|
107
|
+
config,
|
|
108
|
+
})
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
|
|
112
|
+
let tokenized = self.tokenize(&texts)?;
|
|
113
|
+
self.run(&tokenized)
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
pub fn tokenize(&self, texts: &[String]) -> crate::error::Result<Tokenized> {
|
|
117
|
+
self.tokenizer.tokenize(texts)
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
|
|
121
|
+
run_session(&self.session, tokenized, &self.config)
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
fn tune_num_threads(
|
|
127
|
+
requested: usize,
|
|
128
|
+
with_attention_mask: bool,
|
|
129
|
+
with_type_ids: bool,
|
|
130
|
+
output_name: &str,
|
|
131
|
+
) -> usize {
|
|
132
|
+
if requested > 0 {
|
|
133
|
+
return requested;
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
|
|
137
|
+
let target_concurrency = puma_target_concurrency();
|
|
138
|
+
let host_cores = host_parallelism();
|
|
139
|
+
let budgeted_threads = (host_cores / target_concurrency).max(1);
|
|
140
|
+
|
|
141
|
+
match family {
|
|
142
|
+
// Puma-like workloads typically run many concurrent single-item requests where
|
|
143
|
+
// one intra-op thread per request gives the best tail behavior.
|
|
144
|
+
ModelFamily::E5Like | ModelFamily::ClipLike | ModelFamily::SiglipLike => {
|
|
145
|
+
budgeted_threads.min(1)
|
|
146
|
+
}
|
|
147
|
+
ModelFamily::Other => 0,
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
fn infer_model_family(
|
|
152
|
+
with_attention_mask: bool,
|
|
153
|
+
with_type_ids: bool,
|
|
154
|
+
output_name: &str,
|
|
155
|
+
) -> ModelFamily {
|
|
156
|
+
if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
|
|
157
|
+
return ModelFamily::E5Like;
|
|
158
|
+
}
|
|
159
|
+
if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
|
|
160
|
+
return ModelFamily::SiglipLike;
|
|
161
|
+
}
|
|
162
|
+
if output_name == "text_embeds" && !with_attention_mask {
|
|
163
|
+
return ModelFamily::ClipLike;
|
|
164
|
+
}
|
|
165
|
+
ModelFamily::Other
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
/// Reads the target concurrency from the `GTE_PUMA_CONCURRENCY` environment
/// variable, falling back to 16 when it is unset, not a valid `usize`, or zero.
fn puma_target_concurrency() -> usize {
    match std::env::var("GTE_PUMA_CONCURRENCY") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(value) if value > 0 => value,
            _ => 16,
        },
        Err(_) => 16,
    }
}
|
|
175
|
+
|
|
176
|
+
/// Returns `std::thread::available_parallelism()`, or 1 when it cannot be
/// determined.
fn host_parallelism() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
|
|
181
|
+
|
|
182
|
+
fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
|
|
183
|
+
let candidates = [
|
|
184
|
+
dir.join("onnx").join("text_model.onnx"),
|
|
185
|
+
dir.join("text_model.onnx"),
|
|
186
|
+
dir.join("onnx").join("model.onnx"),
|
|
187
|
+
dir.join("model.onnx"),
|
|
188
|
+
];
|
|
189
|
+
for path in &candidates {
|
|
190
|
+
if path.exists() {
|
|
191
|
+
return Ok(path.clone());
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
Err(GteError::Inference(format!(
|
|
195
|
+
"no ONNX model found in {} (checked text_model.onnx and model.onnx)",
|
|
196
|
+
dir.display()
|
|
197
|
+
)))
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
/// Input tensor names accepted by `validate_supported_inputs`; a model with any
/// other input name is rejected.
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
201
|
+
|
|
202
|
+
fn validate_supported_inputs(session: &Session) -> Result<()> {
|
|
203
|
+
let unsupported: Vec<String> = session
|
|
204
|
+
.inputs
|
|
205
|
+
.iter()
|
|
206
|
+
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
|
|
207
|
+
.map(|i| i.name.clone())
|
|
208
|
+
.collect();
|
|
209
|
+
|
|
210
|
+
if unsupported.is_empty() {
|
|
211
|
+
return Ok(());
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
let mut message = format!(
|
|
215
|
+
"unsupported model inputs for text embedding API: {}",
|
|
216
|
+
unsupported.join(", ")
|
|
217
|
+
);
|
|
218
|
+
if unsupported.iter().any(|n| n == "pixel_values") {
|
|
219
|
+
message.push_str(
|
|
220
|
+
". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
|
|
221
|
+
);
|
|
222
|
+
} else {
|
|
223
|
+
message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
|
|
224
|
+
}
|
|
225
|
+
Err(GteError::Inference(message))
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
/// Returns whether `name`, ASCII-lowercased, is exactly `preferred` or ends
/// with `/` followed by `preferred`.
///
/// `preferred` is compared as given, so callers are expected to pass it in
/// lowercase.
fn output_name_matches(name: &str, preferred: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    match lowered.strip_suffix(preferred) {
        // Either nothing precedes the suffix, or the suffix is a full path segment.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('/'),
        None => false,
    }
}
|
|
232
|
+
|
|
233
|
+
fn select_output_tensor(session: &Session) -> Result<String> {
|
|
234
|
+
const PREFERRED: [&str; 4] = [
|
|
235
|
+
"text_embeds",
|
|
236
|
+
"pooler_output",
|
|
237
|
+
"sentence_embedding",
|
|
238
|
+
"last_hidden_state",
|
|
239
|
+
];
|
|
240
|
+
|
|
241
|
+
for preferred in PREFERRED {
|
|
242
|
+
if let Some(output) = session
|
|
243
|
+
.outputs
|
|
244
|
+
.iter()
|
|
245
|
+
.find(|o| output_name_matches(o.name.as_str(), preferred))
|
|
246
|
+
{
|
|
247
|
+
return Ok(output.name.clone());
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
session
|
|
252
|
+
.outputs
|
|
253
|
+
.first()
|
|
254
|
+
.map(|o| o.name.clone())
|
|
255
|
+
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
fn read_max_length(dir: &Path) -> usize {
|
|
259
|
+
(|| -> Option<usize> {
|
|
260
|
+
let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
|
|
261
|
+
let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
|
|
262
|
+
let v = json.get("model_max_length")?;
|
|
263
|
+
let n = v
|
|
264
|
+
.as_u64()
|
|
265
|
+
.or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
|
|
266
|
+
Some((n as usize).min(8192))
|
|
267
|
+
})()
|
|
268
|
+
.unwrap_or(512)
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
// Unit tests for the pure helpers in this file (no model files required).
#[cfg(test)]
mod tests {
    use super::{infer_model_family, tune_num_threads, ModelFamily};

    // Each known (attention_mask, token_type_ids, output name) signature maps
    // to its family; an unlisted output name maps to `Other`.
    #[test]
    fn infer_model_family_recognizes_known_signatures() {
        assert_eq!(
            infer_model_family(true, true, "last_hidden_state"),
            ModelFamily::E5Like
        );
        assert_eq!(
            infer_model_family(true, false, "last_hidden_state"),
            ModelFamily::SiglipLike
        );
        assert_eq!(
            infer_model_family(false, false, "text_embeds"),
            ModelFamily::ClipLike
        );
        assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
    }

    // A non-zero request is passed through untouched.
    #[test]
    fn tune_num_threads_respects_requested_value() {
        assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
    }

    // With no explicit request, an unrecognised family yields 0.
    #[test]
    fn tune_num_threads_returns_ort_default_for_other_family() {
        assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
    }
}
|
|
302
|
+
|
|
303
|
+
/// Returns the part of `name` after its last `/`, or all of `name` when it
/// contains no `/`.
fn output_basename(name: &str) -> &str {
    match name.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => name,
    }
}
|
|
306
|
+
|
|
307
|
+
fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
308
|
+
let output = session
|
|
309
|
+
.outputs
|
|
310
|
+
.iter()
|
|
311
|
+
.find(|o| o.name == output_tensor)
|
|
312
|
+
.ok_or_else(|| {
|
|
313
|
+
GteError::Inference(format!(
|
|
314
|
+
"output tensor '{}' not found in model outputs",
|
|
315
|
+
output_tensor
|
|
316
|
+
))
|
|
317
|
+
})?;
|
|
318
|
+
|
|
319
|
+
let ndims = match &output.output_type {
|
|
320
|
+
ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
|
|
321
|
+
other => {
|
|
322
|
+
return Err(GteError::Inference(format!(
|
|
323
|
+
"output is not a tensor: {:?}",
|
|
324
|
+
other
|
|
325
|
+
)))
|
|
326
|
+
}
|
|
327
|
+
};
|
|
328
|
+
|
|
329
|
+
match (output_basename(output_tensor), ndims) {
|
|
330
|
+
("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
|
|
331
|
+
(_, 2) => Ok(ExtractorMode::Raw),
|
|
332
|
+
(_, 3) => Ok(ExtractorMode::MeanPool),
|
|
333
|
+
(_, n) => Err(GteError::Inference(format!(
|
|
334
|
+
"unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
|
|
335
|
+
n, output_tensor
|
|
336
|
+
))),
|
|
337
|
+
}
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
/// Public wrapper that forwards `embeddings` to
/// `crate::postprocess::normalize_l2` (imported in this file as
/// `normalize_l2_rows`) and returns its result.
pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
    normalize_l2_rows(embeddings)
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
/// Error type of this crate. Every variant carries a human-readable message.
#[derive(Debug)]
pub enum GteError {
    /// Displayed with the prefix "GTE tokenizer error".
    Tokenizer(String),
    /// Displayed with the prefix "GTE inference error".
    Inference(String),
    /// Produced when converting from `ort::Error`.
    Ort(String),
    /// Produced when converting from `ndarray::ShapeError`.
    Shape(String),
}
|
|
8
|
+
|
|
9
|
+
impl std::fmt::Display for GteError {
|
|
10
|
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
11
|
+
match self {
|
|
12
|
+
GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {}", msg),
|
|
13
|
+
GteError::Inference(msg) => write!(f, "GTE inference error: {}", msg),
|
|
14
|
+
GteError::Ort(msg) => write!(f, "GTE ORT error: {}", msg),
|
|
15
|
+
GteError::Shape(msg) => write!(f, "GTE shape error: {}", msg),
|
|
16
|
+
}
|
|
17
|
+
}
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
// No methods are overridden: `GteError` relies on the trait's defaults.
impl std::error::Error for GteError {}
|
|
21
|
+
|
|
22
|
+
impl From<ort::Error> for GteError {
|
|
23
|
+
fn from(e: ort::Error) -> Self {
|
|
24
|
+
GteError::Ort(e.to_string())
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Wraps an ndarray shape error, keeping only its display text.
impl From<ndarray::ShapeError> for GteError {
    fn from(source: ndarray::ShapeError) -> Self {
        Self::Shape(source.to_string())
    }
}
|
|
33
|
+
|
|
34
|
+
/// Result alias whose error type is fixed to [`GteError`].
pub type Result<T> = std::result::Result<T, GteError>;
|
|
35
|
+
|
|
36
|
+
/// Converts a [`GteError`] into a Ruby exception of class `GTE::Error`,
/// carrying the error's display text as the message.
///
/// # Panics
///
/// Panics when no Ruby handle is available on the current thread, when the
/// `GTE` module cannot be obtained, or when `GTE::Error` is not defined as an
/// exception class.
#[cfg(feature = "ruby-ffi")]
impl From<GteError> for magnus::Error {
    fn from(e: GteError) -> Self {
        use magnus::prelude::*;

        let exception_class = magnus::Ruby::get()
            .expect("From<GteError> called from Ruby thread")
            .define_module("GTE")
            .expect("GTE module must exist")
            .const_get::<_, magnus::ExceptionClass>("Error")
            .expect("GTE::Error must be defined before embedder methods are called");

        magnus::Error::new(exception_class, e.to_string())
    }
}
|