gte 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 566cb32a193255c0cf3d087a1a907cdbf0a96292ee58d4676fd18745d55ec1b2
4
+ data.tar.gz: 56479bb218282bf189ace46852a42766d4c2211230be635b6a79414cf3eb82c8
5
+ SHA512:
6
+ metadata.gz: f873b83b16e4cf2685f26b84d26f4c3b8abd90a9d32d5e60046d1aad577d409704838d3da1329f741535fef1f1a90f1edeeeb6d4fbab559569117869ae42677e
7
+ data.tar.gz: d5d5b49b8f51cbf3222409b941fa39b3f00074bf34aafdfaf74d4a3fd37ab99ec8a90bba9be641efb42176b182fe191925042ae09d324c851324b27bd62031ce
data/Gemfile ADDED
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
+ source 'https://rubygems.org'
4
+
5
+ gemspec
6
+
7
+ gem 'rake'
8
+ gem 'rake-compiler'
9
+ gem 'rb_sys'
10
+ gem 'rspec'
11
+ gem 'rspec-benchmark'
12
+ gem 'rubocop', require: false
13
+
14
+ group :bench do
15
+ gem 'onnxruntime'
16
+ gem 'tokenizers'
17
+ end
data/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 elcuervo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,49 @@
1
+ # gte
2
+ ![](https://images.unsplash.com/photo-1551225183-94acb7d595b6?q=80&w=2274&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D)
3
+
4
+ `gte` is a Ruby gem with a Rust extension for fast text embeddings with ONNX Runtime.
5
+ Inspired by https://github.com/fbilhaut/gte-rs
6
+
7
+ ## Quick Start
8
+
9
+ ```ruby
10
+ require "gte"
11
+
12
+ model = GTE.new(ENV.fetch("GTE_MODEL_DIR"))
13
+ vector = model["query: hello world"]
14
+ ```
15
+
16
+ ## Model Directory
17
+
18
+ A model directory must include `tokenizer.json` and one ONNX model, resolved in this order:
19
+
20
+ 1. `onnx/text_model.onnx`
21
+ 2. `text_model.onnx`
22
+ 3. `onnx/model.onnx`
23
+ 4. `model.onnx`
24
+
25
+ ## Development
26
+
27
+ Run commands inside `nix develop`.
28
+
29
+ ```bash
30
+ bundle exec rake compile
31
+ cargo test --manifest-path ext/gte/Cargo.toml --no-default-features
32
+ bundle exec rspec
33
+ ```
34
+
35
+ ## Benchmark
36
+
37
+ The repo includes the following benchmark tasks:
38
+
39
+ ```bash
40
+ bundle exec rake bench:pure_compare
41
+ bundle exec rake bench:puma_compare
42
+ bundle exec rake bench:matrix_sweep
43
+ ```
44
+
45
+ For release tracking and regression detection, record a run entry in `RUNS.md`:
46
+
47
+ ```bash
48
+ bundle exec rake bench:record_run
49
+ ```
data/Rakefile ADDED
@@ -0,0 +1,76 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'bundler/gem_tasks'
4
+ require 'rake/extensiontask'
5
+ begin
6
+ require 'rspec/core/rake_task'
7
+ RSpec::Core::RakeTask.new(:spec)
8
+ rescue LoadError
9
+ # rspec not available in cross-compile environment
10
+ end
11
+
12
+ spec = Gem::Specification.load('gte.gemspec')
13
+
14
+ Rake::ExtensionTask.new('gte', spec) do |ext|
15
+ ext.lib_dir = 'lib/gte'
16
+ ext.cross_compile = true
17
+ ext.cross_platform = %w[x86_64-linux arm64-darwin]
18
+ end
19
+
20
+ task default: %i[compile spec]
21
+
22
+ def run_in_nix(*command)
23
+ sh('nix', 'develop', '-c', *command)
24
+ end
25
+
26
+ namespace :bench do
27
+ desc 'Run pure-Ruby (onnxruntime gem) vs GTE benchmark comparison inside nix develop'
28
+ task :pure_compare do
29
+ run_in_nix('bundle', 'exec', 'ruby', 'bench/pure_ruby_compare.rb')
30
+ end
31
+
32
+ desc 'Run Puma-like concurrent single-request benchmark (GTE vs pure Ruby)'
33
+ task :puma_compare do
34
+ run_in_nix(
35
+ 'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
36
+ '--output', 'bench/results/puma_compare_latest.json',
37
+ '--iterations', '80',
38
+ '--runs', '3'
39
+ )
40
+ end
41
+
42
+ desc 'Sweep execution-provider and thread settings for Puma-like benchmark'
43
+ task :matrix_sweep do
44
+ run_in_nix(
45
+ 'bundle', 'exec', 'ruby', 'bench/puma_matrix_sweep.rb',
46
+ '--iterations', '80',
47
+ '--runs', '3'
48
+ )
49
+ end
50
+
51
+ desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal/regression checks'
52
+ task :record_run do
53
+ run_in_nix(
54
+ 'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
55
+ '--output', 'bench/results/puma_compare_latest.json',
56
+ '--iterations', '80',
57
+ '--runs', '3'
58
+ )
59
+ run_in_nix(
60
+ 'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'append',
61
+ '--result', 'bench/results/puma_compare_latest.json'
62
+ )
63
+ run_in_nix(
64
+ 'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
65
+ '--result', 'bench/results/puma_compare_latest.json'
66
+ )
67
+ end
68
+
69
+ desc 'Validate current Puma benchmark output against 2x goal and regression policy'
70
+ task :check_goal do
71
+ run_in_nix(
72
+ 'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
73
+ '--result', 'bench/results/puma_compare_latest.json'
74
+ )
75
+ end
76
+ end
data/VERSION ADDED
@@ -0,0 +1 @@
1
+ 0.0.1
@@ -0,0 +1,37 @@
1
+ [package]
2
+ name = "gte"
3
+ version = "0.0.1"
4
+ edition = "2021"
5
+ authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
+ license = "MIT"
7
+ publish = false
8
+ build = "build.rs"
9
+
10
+ [lib]
11
+ # cdylib: Ruby FFI extension; rlib: enables integration tests in tests/ to link as external crate
12
+ crate-type = ["cdylib", "rlib"]
13
+
14
+ [features]
15
+ # ruby-ffi: gate magnus + rb-sys (Ruby C symbols) so Rust integration tests can link without Ruby.
16
+ # This feature is enabled by default for the cdylib build (rake compile / extconf.rb).
17
+ # When running `cargo test`, this feature must be excluded: `cargo test --no-default-features`.
18
+ default = ["ruby-ffi"]
19
+ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
20
+
21
+ [dependencies]
22
+ rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
23
+ magnus = { version = "0.8", optional = true }
24
+ ort = { version = "=2.0.0-rc.9", features = ["ndarray"] }
25
+ ort-sys = "=2.0.0-rc.9"
26
+ tokenizers = "0.21.0"
27
+ ndarray = "0.16.0"
28
+ half = "2"
29
+ serde = { version = "1", features = ["derive"] }
30
+ serde_json = "1"
31
+
32
+ [dev-dependencies]
33
+ criterion = "0.5"
34
+
35
+ [[bench]]
36
+ name = "hot_path"
37
+ harness = false
@@ -0,0 +1,53 @@
1
+ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
2
+ use gte::postprocess::{mean_pool, normalize_l2};
3
+ use ndarray::{Array2, Array3};
4
+
5
+ fn build_hidden_states(batch: usize, seq: usize, dim: usize) -> Array3<f32> {
6
+ Array3::from_shape_fn((batch, seq, dim), |(b, s, d)| {
7
+ (((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0
8
+ })
9
+ }
10
+
11
+ fn build_attention_mask(batch: usize, seq: usize) -> Array2<i64> {
12
+ Array2::from_shape_fn((batch, seq), |(_, s)| if s % 11 == 10 { 0 } else { 1 })
13
+ }
14
+
15
+ fn bench_mean_pool(c: &mut Criterion) {
16
+ let mut group = c.benchmark_group("mean_pool");
17
+ for (batch, seq, dim) in [(1, 32, 384), (8, 64, 384), (32, 64, 768)] {
18
+ let hidden_states = build_hidden_states(batch, seq, dim);
19
+ let attention_mask = build_attention_mask(batch, seq);
20
+ group.bench_with_input(
21
+ BenchmarkId::from_parameter(format!("{batch}x{seq}x{dim}")),
22
+ &(batch, seq, dim),
23
+ |b, _| {
24
+ b.iter(|| {
25
+ mean_pool(
26
+ black_box(hidden_states.view()),
27
+ black_box(attention_mask.view()),
28
+ )
29
+ .unwrap()
30
+ })
31
+ },
32
+ );
33
+ }
34
+ group.finish();
35
+ }
36
+
37
+ fn bench_normalize_l2(c: &mut Criterion) {
38
+ let mut group = c.benchmark_group("normalize_l2");
39
+ for (rows, dim) in [(1, 384), (8, 384), (32, 768), (128, 768)] {
40
+ let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)| {
41
+ (((row * 19 + col * 7) % 113) as f32) / 113.0
42
+ });
43
+ group.bench_with_input(
44
+ BenchmarkId::from_parameter(format!("{rows}x{dim}")),
45
+ &(rows, dim),
46
+ |b, _| b.iter(|| normalize_l2(black_box(embeddings.clone()))),
47
+ );
48
+ }
49
+ group.finish();
50
+ }
51
+
52
+ criterion_group!(benches, bench_mean_pool, bench_normalize_l2);
53
+ criterion_main!(benches);
data/ext/gte/build.rs ADDED
@@ -0,0 +1,25 @@
1
+ fn main() {
2
+ let version = std::fs::read_to_string("../../VERSION")
3
+ .expect("VERSION file not found")
4
+ .trim()
5
+ .to_string();
6
+
7
+ let cargo_version = env!("CARGO_PKG_VERSION");
8
+
9
+ assert_eq!(
10
+ version, cargo_version,
11
+ "VERSION file ({}) doesn't match Cargo.toml ({}). Update Cargo.toml to match.",
12
+ version, cargo_version
13
+ );
14
+
15
+ println!("cargo:rerun-if-changed=../../VERSION");
16
+
17
+ // Ensure the ORT shared library can be found at runtime via @rpath on macOS.
18
+ // ORT_LIB_LOCATION is set by the Nix dev shell when ORT_STRATEGY=system.
19
+ if let Ok(ort_lib) = std::env::var("ORT_LIB_LOCATION") {
20
+ let lib_dir = std::path::Path::new(&ort_lib).join("lib");
21
+ if lib_dir.exists() {
22
+ println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_dir.display());
23
+ }
24
+ }
25
+ }
@@ -0,0 +1,6 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'mkmf'
4
+ require 'rb_sys/mkmf'
5
+
6
+ create_rust_makefile('gte/gte')
@@ -0,0 +1,342 @@
1
+ use crate::error::{GteError, Result};
2
+ use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::postprocess::normalize_l2 as normalize_l2_rows;
4
+ use crate::session::{build_session, run_session};
5
+ use crate::tokenizer::{Tokenized, Tokenizer};
6
+ use ndarray::Array2;
7
+ use ort::session::Session;
8
+ use std::path::{Path, PathBuf};
9
+
10
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
11
+ pub enum ModelFamily {
12
+ E5Like,
13
+ SiglipLike,
14
+ ClipLike,
15
+ Other,
16
+ }
17
+
18
+ pub struct Embedder {
19
+ tokenizer: Tokenizer,
20
+ session: Session,
21
+ config: ModelConfig,
22
+ }
23
+
24
+ impl Embedder {
25
+ pub fn new<P1, P2>(tokenizer_path: P1, model_path: P2, config: ModelConfig) -> Result<Self>
26
+ where
27
+ P1: AsRef<Path>,
28
+ P2: AsRef<Path>,
29
+ {
30
+ let tokenizer = Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids)?;
31
+ let session = build_session(model_path, &config)?;
32
+ Ok(Self {
33
+ tokenizer,
34
+ session,
35
+ config,
36
+ })
37
+ }
38
+
39
+ pub fn from_dir<P: AsRef<Path>>(
40
+ dir: P,
41
+ num_threads: usize,
42
+ optimization_level: u8,
43
+ ) -> Result<Self> {
44
+ let dir = dir.as_ref();
45
+ let tokenizer_path = dir.join("tokenizer.json");
46
+ let model_path = resolve_model_path(dir)?;
47
+
48
+ if !tokenizer_path.exists() {
49
+ return Err(GteError::Tokenizer(format!(
50
+ "tokenizer.json not found in {}",
51
+ dir.display()
52
+ )));
53
+ }
54
+
55
+ let max_length = read_max_length(dir);
56
+ let temp_config = ModelConfig {
57
+ max_length,
58
+ output_tensor: String::new(),
59
+ mode: ExtractorMode::Raw,
60
+ with_type_ids: false,
61
+ with_attention_mask: true,
62
+ num_threads,
63
+ optimization_level,
64
+ };
65
+ let session = build_session(&model_path, &temp_config)?;
66
+
67
+ validate_supported_inputs(&session)?;
68
+ let with_type_ids = session.inputs.iter().any(|i| i.name == "token_type_ids");
69
+ let with_attention_mask = session.inputs.iter().any(|i| i.name == "attention_mask");
70
+ let output_tensor = select_output_tensor(&session)?;
71
+ let output_base = output_basename(output_tensor.as_str()).to_string();
72
+ let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
73
+ if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
74
+ return Err(GteError::Inference(
75
+ "cannot use mean pooling without attention_mask input".to_string(),
76
+ ));
77
+ }
78
+
79
+ let tuned_num_threads = tune_num_threads(
80
+ num_threads,
81
+ with_attention_mask,
82
+ with_type_ids,
83
+ output_base.as_str(),
84
+ );
85
+
86
+ let config = ModelConfig {
87
+ max_length,
88
+ output_tensor,
89
+ mode,
90
+ with_type_ids,
91
+ with_attention_mask,
92
+ num_threads: tuned_num_threads,
93
+ optimization_level,
94
+ };
95
+
96
+ let session = if tuned_num_threads != num_threads {
97
+ build_session(&model_path, &config)?
98
+ } else {
99
+ session
100
+ };
101
+
102
+ let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
103
+
104
+ Ok(Self {
105
+ tokenizer,
106
+ session,
107
+ config,
108
+ })
109
+ }
110
+
111
+ pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
112
+ let tokenized = self.tokenize(&texts)?;
113
+ self.run(&tokenized)
114
+ }
115
+
116
+ pub fn tokenize(&self, texts: &[String]) -> crate::error::Result<Tokenized> {
117
+ self.tokenizer.tokenize(texts)
118
+ }
119
+
120
+ pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
121
+ run_session(&self.session, tokenized, &self.config)
122
+ }
123
+
124
+ }
125
+
126
+ fn tune_num_threads(
127
+ requested: usize,
128
+ with_attention_mask: bool,
129
+ with_type_ids: bool,
130
+ output_name: &str,
131
+ ) -> usize {
132
+ if requested > 0 {
133
+ return requested;
134
+ }
135
+
136
+ let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
137
+ let target_concurrency = puma_target_concurrency();
138
+ let host_cores = host_parallelism();
139
+ let budgeted_threads = (host_cores / target_concurrency).max(1);
140
+
141
+ match family {
142
+ // Puma-like workloads typically run many concurrent single-item requests where
143
+ // one intra-op thread per request gives the best tail behavior.
144
+ ModelFamily::E5Like | ModelFamily::ClipLike | ModelFamily::SiglipLike => {
145
+ budgeted_threads.min(1)
146
+ }
147
+ ModelFamily::Other => 0,
148
+ }
149
+ }
150
+
151
+ fn infer_model_family(
152
+ with_attention_mask: bool,
153
+ with_type_ids: bool,
154
+ output_name: &str,
155
+ ) -> ModelFamily {
156
+ if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
157
+ return ModelFamily::E5Like;
158
+ }
159
+ if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
160
+ return ModelFamily::SiglipLike;
161
+ }
162
+ if output_name == "text_embeds" && !with_attention_mask {
163
+ return ModelFamily::ClipLike;
164
+ }
165
+ ModelFamily::Other
166
+ }
167
+
168
+ fn puma_target_concurrency() -> usize {
169
+ std::env::var("GTE_PUMA_CONCURRENCY")
170
+ .ok()
171
+ .and_then(|raw| raw.parse::<usize>().ok())
172
+ .filter(|value| *value > 0)
173
+ .unwrap_or(16)
174
+ }
175
+
176
+ fn host_parallelism() -> usize {
177
+ std::thread::available_parallelism()
178
+ .map(|n| n.get())
179
+ .unwrap_or(1)
180
+ }
181
+
182
+ fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
183
+ let candidates = [
184
+ dir.join("onnx").join("text_model.onnx"),
185
+ dir.join("text_model.onnx"),
186
+ dir.join("onnx").join("model.onnx"),
187
+ dir.join("model.onnx"),
188
+ ];
189
+ for path in &candidates {
190
+ if path.exists() {
191
+ return Ok(path.clone());
192
+ }
193
+ }
194
+ Err(GteError::Inference(format!(
195
+ "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
196
+ dir.display()
197
+ )))
198
+ }
199
+
200
+ const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
201
+
202
+ fn validate_supported_inputs(session: &Session) -> Result<()> {
203
+ let unsupported: Vec<String> = session
204
+ .inputs
205
+ .iter()
206
+ .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
207
+ .map(|i| i.name.clone())
208
+ .collect();
209
+
210
+ if unsupported.is_empty() {
211
+ return Ok(());
212
+ }
213
+
214
+ let mut message = format!(
215
+ "unsupported model inputs for text embedding API: {}",
216
+ unsupported.join(", ")
217
+ );
218
+ if unsupported.iter().any(|n| n == "pixel_values") {
219
+ message.push_str(
220
+ ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
221
+ );
222
+ } else {
223
+ message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
224
+ }
225
+ Err(GteError::Inference(message))
226
+ }
227
+
228
+ fn output_name_matches(name: &str, preferred: &str) -> bool {
229
+ let lower = name.to_ascii_lowercase();
230
+ lower == preferred || lower.ends_with(&format!("/{}", preferred))
231
+ }
232
+
233
+ fn select_output_tensor(session: &Session) -> Result<String> {
234
+ const PREFERRED: [&str; 4] = [
235
+ "text_embeds",
236
+ "pooler_output",
237
+ "sentence_embedding",
238
+ "last_hidden_state",
239
+ ];
240
+
241
+ for preferred in PREFERRED {
242
+ if let Some(output) = session
243
+ .outputs
244
+ .iter()
245
+ .find(|o| output_name_matches(o.name.as_str(), preferred))
246
+ {
247
+ return Ok(output.name.clone());
248
+ }
249
+ }
250
+
251
+ session
252
+ .outputs
253
+ .first()
254
+ .map(|o| o.name.clone())
255
+ .ok_or_else(|| GteError::Inference("model has no outputs".into()))
256
+ }
257
+
258
+ fn read_max_length(dir: &Path) -> usize {
259
+ (|| -> Option<usize> {
260
+ let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
261
+ let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
262
+ let v = json.get("model_max_length")?;
263
+ let n = v
264
+ .as_u64()
265
+ .or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
266
+ Some((n as usize).min(8192))
267
+ })()
268
+ .unwrap_or(512)
269
+ }
270
+
271
+ #[cfg(test)]
272
+ mod tests {
273
+ use super::{infer_model_family, tune_num_threads, ModelFamily};
274
+
275
+ #[test]
276
+ fn infer_model_family_recognizes_known_signatures() {
277
+ assert_eq!(
278
+ infer_model_family(true, true, "last_hidden_state"),
279
+ ModelFamily::E5Like
280
+ );
281
+ assert_eq!(
282
+ infer_model_family(true, false, "last_hidden_state"),
283
+ ModelFamily::SiglipLike
284
+ );
285
+ assert_eq!(
286
+ infer_model_family(false, false, "text_embeds"),
287
+ ModelFamily::ClipLike
288
+ );
289
+ assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
290
+ }
291
+
292
+ #[test]
293
+ fn tune_num_threads_respects_requested_value() {
294
+ assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
295
+ }
296
+
297
+ #[test]
298
+ fn tune_num_threads_returns_ort_default_for_other_family() {
299
+ assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
300
+ }
301
+ }
302
+
303
+ fn output_basename(name: &str) -> &str {
304
+ name.rsplit('/').next().unwrap_or(name)
305
+ }
306
+
307
+ fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
308
+ let output = session
309
+ .outputs
310
+ .iter()
311
+ .find(|o| o.name == output_tensor)
312
+ .ok_or_else(|| {
313
+ GteError::Inference(format!(
314
+ "output tensor '{}' not found in model outputs",
315
+ output_tensor
316
+ ))
317
+ })?;
318
+
319
+ let ndims = match &output.output_type {
320
+ ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
321
+ other => {
322
+ return Err(GteError::Inference(format!(
323
+ "output is not a tensor: {:?}",
324
+ other
325
+ )))
326
+ }
327
+ };
328
+
329
+ match (output_basename(output_tensor), ndims) {
330
+ ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
331
+ (_, 2) => Ok(ExtractorMode::Raw),
332
+ (_, 3) => Ok(ExtractorMode::MeanPool),
333
+ (_, n) => Err(GteError::Inference(format!(
334
+ "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
335
+ n, output_tensor
336
+ ))),
337
+ }
338
+ }
339
+
340
+ pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
341
+ normalize_l2_rows(embeddings)
342
+ }
@@ -0,0 +1,48 @@
1
+ #[derive(Debug)]
2
+ pub enum GteError {
3
+ Tokenizer(String),
4
+ Inference(String),
5
+ Ort(String),
6
+ Shape(String),
7
+ }
8
+
9
+ impl std::fmt::Display for GteError {
10
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11
+ match self {
12
+ GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {}", msg),
13
+ GteError::Inference(msg) => write!(f, "GTE inference error: {}", msg),
14
+ GteError::Ort(msg) => write!(f, "GTE ORT error: {}", msg),
15
+ GteError::Shape(msg) => write!(f, "GTE shape error: {}", msg),
16
+ }
17
+ }
18
+ }
19
+
20
+ impl std::error::Error for GteError {}
21
+
22
+ impl From<ort::Error> for GteError {
23
+ fn from(e: ort::Error) -> Self {
24
+ GteError::Ort(e.to_string())
25
+ }
26
+ }
27
+
28
+ impl From<ndarray::ShapeError> for GteError {
29
+ fn from(e: ndarray::ShapeError) -> Self {
30
+ GteError::Shape(e.to_string())
31
+ }
32
+ }
33
+
34
+ pub type Result<T> = std::result::Result<T, GteError>;
35
+
36
+ #[cfg(feature = "ruby-ffi")]
37
+ impl From<GteError> for magnus::Error {
38
+ fn from(e: GteError) -> Self {
39
+ use magnus::prelude::*;
40
+
41
+ let ruby = magnus::Ruby::get().expect("From<GteError> called from Ruby thread");
42
+ let module = ruby.define_module("GTE").expect("GTE module must exist");
43
+ let gte_error_class = module
44
+ .const_get::<_, magnus::ExceptionClass>("Error")
45
+ .expect("GTE::Error must be defined before embedder methods are called");
46
+ magnus::Error::new(gte_error_class, e.to_string())
47
+ }
48
+ }