gte 0.0.4 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,13 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::pipeline::{extract_output_tensor, InputTensors};
3
4
  use crate::postprocess::mean_pool;
4
5
  use crate::tokenizer::Tokenized;
5
- use ndarray::{Array2, ArrayView2, Ix2};
6
+ use ndarray::{Array2, Ix2};
6
7
  use ort::execution_providers::{
7
8
  CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
8
9
  };
9
10
  use ort::session::Session;
10
- use ort::session::SessionInputValue;
11
- use ort::value::Value;
12
11
  use std::path::Path;
13
12
 
14
13
  pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
@@ -23,7 +22,7 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
23
22
  .with_optimization_level(opt_level)?
24
23
  .with_memory_pattern(true)?;
25
24
 
26
- let providers = preferred_execution_providers();
25
+ let providers = preferred_execution_providers(config.execution_providers.as_deref());
27
26
  if !providers.is_empty() {
28
27
  builder = builder.with_execution_providers(providers)?;
29
28
  }
@@ -35,17 +34,40 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
35
34
  Ok(builder.commit_from_file(model_path)?)
36
35
  }
37
36
 
38
- fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
39
- let order = std::env::var("GTE_EXECUTION_PROVIDERS")
40
- .unwrap_or_else(|_| "xnnpack".to_string())
41
- .to_ascii_lowercase();
37
+ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
38
+ let order = resolve_provider_order(order_override);
42
39
 
43
40
  let mut providers = Vec::new();
44
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
41
+ for provider in parse_provider_registrations(order.as_str()) {
45
42
  match provider {
46
- "xnnpack" => providers.push(XNNPACKExecutionProvider::default().build().fail_silently()),
43
+ "xnnpack" => {
44
+ providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
45
+ }
47
46
  "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
48
- "none" => {}
47
+ _ => {}
48
+ }
49
+ }
50
+ providers
51
+ }
52
+
53
+ fn resolve_provider_order(order_override: Option<&str>) -> String {
54
+ let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
55
+ resolve_provider_order_with_env(order_override, env_order.as_deref())
56
+ }
57
+
58
+ fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
59
+ order_override
60
+ .or(env_order)
61
+ .unwrap_or("cpu")
62
+ .to_ascii_lowercase()
63
+ }
64
+
65
+ fn parse_provider_registrations(order: &str) -> Vec<&str> {
66
+ let mut providers = Vec::new();
67
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
68
+ match provider {
69
+ "xnnpack" | "coreml" => providers.push(provider),
70
+ "none" | "cpu" => {}
49
71
  _ => {}
50
72
  }
51
73
  }
@@ -57,40 +79,9 @@ pub fn run_session(
57
79
  tokenized: &Tokenized,
58
80
  config: &ModelConfig,
59
81
  ) -> Result<Array2<f32>> {
60
- let input_ids_view: ArrayView2<'_, i64> =
61
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
62
- let attn_masks_view: ArrayView2<'_, i64> =
63
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
64
-
65
- let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
66
- inputs.push((
67
- "input_ids",
68
- SessionInputValue::from(Value::from_array(input_ids_view)?),
69
- ));
70
- if config.with_attention_mask {
71
- inputs.push((
72
- "attention_mask",
73
- SessionInputValue::from(Value::from_array(attn_masks_view)?),
74
- ));
75
- }
76
- if let Some(type_ids) = tokenized.type_ids.as_deref() {
77
- let type_ids_view: ArrayView2<'_, i64> =
78
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
79
- inputs.push((
80
- "token_type_ids",
81
- SessionInputValue::from(Value::from_array(type_ids_view)?),
82
- ));
83
- }
84
-
85
- let outputs = session.run(inputs)?;
86
- let tensor_value = outputs.get(config.output_tensor.as_str()).ok_or_else(|| {
87
- GteError::Inference(format!(
88
- "output tensor '{}' not found in model outputs",
89
- &config.output_tensor
90
- ))
91
- })?;
92
-
93
- let array = tensor_value.try_extract_tensor::<f32>()?;
82
+ let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
83
+ let outputs = session.run(input_tensors.inputs)?;
84
+ let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
94
85
 
95
86
  match config.mode {
96
87
  ExtractorMode::Token(idx) => {
@@ -111,8 +102,47 @@ pub fn run_session(
111
102
  ndim
112
103
  ))
113
104
  })?;
114
- mean_pool(hidden_states, attn_masks_view)
105
+ mean_pool(hidden_states.view(), input_tensors.attention_mask)
115
106
  }
116
107
  ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
117
108
  }
118
109
  }
110
+
111
+ #[cfg(test)]
112
+ mod tests {
113
+ use super::{parse_provider_registrations, resolve_provider_order_with_env};
114
+
115
+ #[test]
116
+ fn parse_provider_registrations_keeps_supported_order() {
117
+ let parsed = parse_provider_registrations("xnnpack,coreml");
118
+ assert_eq!(parsed, vec!["xnnpack", "coreml"]);
119
+ }
120
+
121
+ #[test]
122
+ fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
123
+ assert!(parse_provider_registrations("cpu").is_empty());
124
+ assert!(parse_provider_registrations("none").is_empty());
125
+ assert!(parse_provider_registrations("none,cpu").is_empty());
126
+ }
127
+
128
+ #[test]
129
+ fn parse_provider_registrations_ignores_unknowns_and_empties() {
130
+ let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
131
+ assert_eq!(parsed, vec!["xnnpack", "coreml"]);
132
+ }
133
+
134
+ #[test]
135
+ fn resolve_provider_order_prefers_override() {
136
+ assert_eq!(
137
+ resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
138
+ "xnnpack"
139
+ );
140
+ assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
141
+ }
142
+
143
+ #[test]
144
+ fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
145
+ assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
146
+ assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
147
+ }
148
+ }
@@ -61,12 +61,31 @@ impl Tokenizer {
61
61
 
62
62
  build_tokenized(&encodings, self.with_type_ids)
63
63
  }
64
+
65
+ pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
66
+ let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = pairs
67
+ .iter()
68
+ .map(|(left, right)| (left.as_str(), right.as_str()).into())
69
+ .collect();
70
+ let encodings = self
71
+ .tokenizer
72
+ .encode_batch_fast(encode_inputs, true)
73
+ .map_err(|e| GteError::Tokenizer(e.to_string()))?;
74
+ build_tokenized(&encodings, self.with_type_ids)
75
+ }
64
76
  }
65
77
 
66
- fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Result<Tokenized> {
78
+ fn build_tokenized_single(
79
+ encoding: &tokenizers::Encoding,
80
+ with_type_ids: bool,
81
+ ) -> Result<Tokenized> {
67
82
  let cols = encoding.len();
68
83
 
69
- let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&value| i64::from(value)).collect();
84
+ let input_ids: Vec<i64> = encoding
85
+ .get_ids()
86
+ .iter()
87
+ .map(|&value| i64::from(value))
88
+ .collect();
70
89
  let attn_masks: Vec<i64> = encoding
71
90
  .get_attention_mask()
72
91
  .iter()
@@ -5,7 +5,8 @@ use gte::embedder::Embedder;
5
5
  fn test_e5_single_embedding_shape() {
6
6
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
7
7
 
8
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
8
+ let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
9
+ .expect("embedder should initialize");
9
10
  let result = embedder
10
11
  .embed(vec!["query: Hello world".to_string()])
11
12
  .expect("embed should succeed");
@@ -19,7 +20,8 @@ fn test_e5_single_embedding_shape() {
19
20
  fn test_clip_single_embedding_shape() {
20
21
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
21
22
 
22
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
23
+ let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
24
+ .expect("embedder should initialize");
23
25
  let result = embedder
24
26
  .embed(vec!["a photo of a cat".to_string()])
25
27
  .expect("embed should succeed");
@@ -33,7 +35,8 @@ fn test_clip_single_embedding_shape() {
33
35
  fn test_e5_batch_embedding_shape() {
34
36
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
35
37
 
36
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
38
+ let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
39
+ .expect("embedder should initialize");
37
40
  let texts = vec![
38
41
  "query: first sentence".to_string(),
39
42
  "query: second sentence".to_string(),
@@ -51,7 +54,8 @@ fn test_e5_batch_embedding_shape() {
51
54
  fn test_e5_long_input_truncation_no_error() {
52
55
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
53
56
 
54
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
57
+ let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
58
+ .expect("embedder should initialize");
55
59
  let very_long_text = "word ".repeat(1000);
56
60
  let result = embedder
57
61
  .embed(vec![very_long_text])
@@ -0,0 +1,17 @@
1
+ use gte::postprocess::sigmoid_scores;
2
+ use ndarray::array;
3
+
4
+ #[test]
5
+ fn test_sigmoid_scores_monotonic_and_bounded() {
6
+ let mut scores = array![-10.0f32, -1.0, 0.0, 1.0, 10.0];
7
+ sigmoid_scores(scores.view_mut());
8
+
9
+ assert!(scores[0] < scores[1]);
10
+ assert!(scores[1] < scores[2]);
11
+ assert!(scores[2] < scores[3]);
12
+ assert!(scores[3] < scores[4]);
13
+
14
+ for score in scores.iter() {
15
+ assert!((*score >= 0.0) && (*score <= 1.0));
16
+ }
17
+ }
@@ -40,5 +40,8 @@ fn test_e5_truncation_at_max_length() {
40
40
  .expect("tokenize should not error on long input");
41
41
 
42
42
  assert_eq!(tokenized.rows, 1);
43
- assert_eq!(tokenized.cols, 16, "sequence length should be truncated to max_length");
43
+ assert_eq!(
44
+ tokenized.cols, 16,
45
+ "sequence length should be truncated to max_length"
46
+ );
44
47
  }
data/lib/gte/config.rb ADDED
@@ -0,0 +1,15 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ module Config
5
+ Text = Data.define(
6
+ :model_dir, :threads, :optimization_level,
7
+ :model_name, :normalize, :output_tensor, :max_length, :execution_providers
8
+ )
9
+
10
+ Reranker = Data.define(
11
+ :model_dir, :threads, :optimization_level,
12
+ :model_name, :sigmoid, :output_tensor, :max_length, :execution_providers
13
+ )
14
+ end
15
+ end
@@ -0,0 +1,41 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ class Embedder
5
+ class << self
6
+ def config(model_dir)
7
+ cfg = default_config(model_dir)
8
+ cfg = yield(cfg) if block_given?
9
+ from_config(cfg)
10
+ end
11
+
12
+ def from_config(config)
13
+ new(
14
+ config.model_dir,
15
+ config.threads,
16
+ config.optimization_level,
17
+ config.model_name.to_s,
18
+ config.normalize,
19
+ config.output_tensor.to_s,
20
+ config.max_length || 0,
21
+ config.execution_providers.to_s
22
+ )
23
+ end
24
+
25
+ private
26
+
27
+ def default_config(model_dir)
28
+ Config::Text.new(
29
+ model_dir: File.expand_path(model_dir),
30
+ threads: 3,
31
+ optimization_level: 3,
32
+ model_name: nil,
33
+ normalize: true,
34
+ output_tensor: nil,
35
+ max_length: nil,
36
+ execution_providers: nil
37
+ )
38
+ end
39
+ end
40
+ end
41
+ end
data/lib/gte/model.rb ADDED
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ class Model
5
+ attr_reader :config
6
+
7
+ def initialize(config)
8
+ raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)
9
+
10
+ @config = config
11
+ @embedder = GTE::Embedder.from_config(config)
12
+ end
13
+
14
+ def embed(texts)
15
+ return @embedder.embed_one(texts) if texts.is_a?(String)
16
+
17
+ @embedder.embed(Array(texts))
18
+ end
19
+
20
+ def [](input)
21
+ case input
22
+ when String then embed(input).row(0)
23
+ when Array then embed(input)
24
+ end
25
+ end
26
+ end
27
+ end
@@ -0,0 +1,56 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ class Reranker
5
+ class << self
6
+ def config(model_dir)
7
+ cfg = default_config(model_dir)
8
+
9
+ if block_given?
10
+ yielded = yield(cfg)
11
+ cfg = yielded if yielded.is_a?(Config::Reranker)
12
+ end
13
+
14
+ build(cfg)
15
+ end
16
+
17
+ private
18
+
19
+ def default_config(model_dir)
20
+ Config::Reranker.new(
21
+ model_dir: File.expand_path(model_dir),
22
+ threads: 3,
23
+ optimization_level: 3,
24
+ model_name: nil,
25
+ sigmoid: false,
26
+ output_tensor: nil,
27
+ max_length: nil,
28
+ execution_providers: nil
29
+ )
30
+ end
31
+
32
+ def build(cfg)
33
+ new(
34
+ cfg.model_dir,
35
+ cfg.threads,
36
+ cfg.optimization_level,
37
+ cfg.model_name.to_s,
38
+ cfg.sigmoid,
39
+ cfg.output_tensor.to_s,
40
+ cfg.max_length || 0,
41
+ cfg.execution_providers.to_s
42
+ )
43
+ end
44
+ end
45
+
46
+ def rerank(query:, candidates:)
47
+ rows = Array(candidates).map(&:to_s)
48
+ scores = score(query.to_s, rows)
49
+
50
+ rows
51
+ .each_with_index
52
+ .map { |text, idx| { index: idx, score: scores[idx], text: text } }
53
+ .sort_by { |row| -row[:score] }
54
+ end
55
+ end
56
+ end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ VERSION = File.read(File.expand_path('../../VERSION', __dir__)).strip
5
+ end
data/lib/gte.rb CHANGED
@@ -1,55 +1,46 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'gte/version'
4
+
3
5
  begin
4
6
  require "gte/#{RUBY_VERSION.to_f}/gte"
5
7
  rescue LoadError
6
8
  require 'gte/gte'
7
9
  end
8
10
 
9
- module GTE
10
- VERSION = File.read(File.expand_path('../VERSION', __dir__)).strip
11
+ require 'gte/config'
12
+ require 'gte/embedder'
13
+ require 'gte/model'
14
+ require 'gte/reranker'
11
15
 
16
+ module GTE
12
17
  @model_cache_mutex = Mutex.new
13
18
  @model_cache = {}
14
19
 
15
- class Model
16
- def initialize(dir, num_threads: 0, optimization_level: 3, model_name: nil)
17
- @embedder = GTE::Embedder.new(dir, num_threads, optimization_level, model_name.to_s)
18
- end
20
+ class << self
21
+ def config(model_dir)
22
+ cfg = Config::Text.new(
23
+ model_dir: File.expand_path(model_dir),
24
+ threads: 3,
25
+ optimization_level: 3,
26
+ model_name: nil,
27
+ normalize: true,
28
+ output_tensor: nil,
29
+ max_length: nil,
30
+ execution_providers: nil
31
+ )
19
32
 
20
- def embed(texts)
21
- if texts.is_a?(String)
22
- @embedder.embed_one(texts)
23
- else
24
- @embedder.embed(Array(texts))
25
- end
26
- end
33
+ cfg = yield(cfg) if block_given?
27
34
 
28
- def [](input)
29
- case input
30
- when String then embed(input).row(0)
31
- when Array then embed(input)
35
+ @model_cache_mutex.synchronize do
36
+ @model_cache[cache_key(cfg)] ||= Model.new(cfg)
32
37
  end
33
38
  end
34
- end
35
39
 
36
- def self.new(dir, threads: 0, optimization: 3, model_name: nil)
37
- key = [
38
- File.expand_path(dir),
39
- Integer(threads),
40
- Integer(optimization),
41
- model_name.to_s
42
- ].freeze
43
-
44
- @model_cache_mutex.synchronize do
45
- @model_cache[key] ||= Model.new(
46
- key[0],
47
- num_threads: key[1],
48
- optimization_level: key[2],
49
- model_name: key[3].empty? ? nil : key[3]
50
- )
40
+ private
41
+
42
+ def cache_key(cfg)
43
+ cfg.to_h
51
44
  end
52
45
  end
53
-
54
- def self.fetch(*) = new(*)
55
46
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.4
4
+ version: 0.0.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2026-04-13 00:00:00.000000000 Z
11
+ date: 2026-04-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rake
@@ -101,14 +101,23 @@ files:
101
101
  - ext/gte/src/error.rs
102
102
  - ext/gte/src/lib.rs
103
103
  - ext/gte/src/model_config.rs
104
+ - ext/gte/src/model_profile.rs
105
+ - ext/gte/src/pipeline.rs
104
106
  - ext/gte/src/postprocess.rs
107
+ - ext/gte/src/reranker.rs
105
108
  - ext/gte/src/ruby_embedder.rs
106
109
  - ext/gte/src/session.rs
107
110
  - ext/gte/src/tokenizer.rs
108
111
  - ext/gte/tests/embedder_unit_test.rs
109
112
  - ext/gte/tests/inference_integration_test.rs
113
+ - ext/gte/tests/postprocess_unit_test.rs
110
114
  - ext/gte/tests/tokenizer_unit_test.rs
111
115
  - lib/gte.rb
116
+ - lib/gte/config.rb
117
+ - lib/gte/embedder.rb
118
+ - lib/gte/model.rb
119
+ - lib/gte/reranker.rb
120
+ - lib/gte/version.rb
112
121
  homepage: https://github.com/elcuervo/gte
113
122
  licenses:
114
123
  - MIT