gte 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,13 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::pipeline::{extract_output_tensor, InputTensors};
3
4
  use crate::postprocess::mean_pool;
4
5
  use crate::tokenizer::Tokenized;
5
- use ndarray::{Array2, ArrayView2, Ix2};
6
+ use ndarray::{Array2, Ix2};
6
7
  use ort::execution_providers::{
7
8
  CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
8
9
  };
9
10
  use ort::session::Session;
10
- use ort::session::SessionInputValue;
11
- use ort::value::Value;
12
11
  use std::path::Path;
13
12
 
14
13
  pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
@@ -43,7 +42,9 @@ fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
43
42
  let mut providers = Vec::new();
44
43
  for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
45
44
  match provider {
46
- "xnnpack" => providers.push(XNNPACKExecutionProvider::default().build().fail_silently()),
45
+ "xnnpack" => {
46
+ providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
47
+ }
47
48
  "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
48
49
  "none" => {}
49
50
  _ => {}
@@ -57,40 +58,9 @@ pub fn run_session(
57
58
  tokenized: &Tokenized,
58
59
  config: &ModelConfig,
59
60
  ) -> Result<Array2<f32>> {
60
- let input_ids_view: ArrayView2<'_, i64> =
61
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
62
- let attn_masks_view: ArrayView2<'_, i64> =
63
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
64
-
65
- let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
66
- inputs.push((
67
- "input_ids",
68
- SessionInputValue::from(Value::from_array(input_ids_view)?),
69
- ));
70
- if config.with_attention_mask {
71
- inputs.push((
72
- "attention_mask",
73
- SessionInputValue::from(Value::from_array(attn_masks_view)?),
74
- ));
75
- }
76
- if let Some(type_ids) = tokenized.type_ids.as_deref() {
77
- let type_ids_view: ArrayView2<'_, i64> =
78
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
79
- inputs.push((
80
- "token_type_ids",
81
- SessionInputValue::from(Value::from_array(type_ids_view)?),
82
- ));
83
- }
84
-
85
- let outputs = session.run(inputs)?;
86
- let tensor_value = outputs.get(config.output_tensor.as_str()).ok_or_else(|| {
87
- GteError::Inference(format!(
88
- "output tensor '{}' not found in model outputs",
89
- &config.output_tensor
90
- ))
91
- })?;
92
-
93
- let array = tensor_value.try_extract_tensor::<f32>()?;
61
+ let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
62
+ let outputs = session.run(input_tensors.inputs)?;
63
+ let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
94
64
 
95
65
  match config.mode {
96
66
  ExtractorMode::Token(idx) => {
@@ -111,7 +81,7 @@ pub fn run_session(
111
81
  ndim
112
82
  ))
113
83
  })?;
114
- mean_pool(hidden_states, attn_masks_view)
84
+ mean_pool(hidden_states.view(), input_tensors.attention_mask)
115
85
  }
116
86
  ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
117
87
  }
@@ -61,12 +61,31 @@ impl Tokenizer {
61
61
 
62
62
  build_tokenized(&encodings, self.with_type_ids)
63
63
  }
64
+
65
+ pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
66
+ let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = pairs
67
+ .iter()
68
+ .map(|(left, right)| (left.as_str(), right.as_str()).into())
69
+ .collect();
70
+ let encodings = self
71
+ .tokenizer
72
+ .encode_batch_fast(encode_inputs, true)
73
+ .map_err(|e| GteError::Tokenizer(e.to_string()))?;
74
+ build_tokenized(&encodings, self.with_type_ids)
75
+ }
64
76
  }
65
77
 
66
- fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Result<Tokenized> {
78
+ fn build_tokenized_single(
79
+ encoding: &tokenizers::Encoding,
80
+ with_type_ids: bool,
81
+ ) -> Result<Tokenized> {
67
82
  let cols = encoding.len();
68
83
 
69
- let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&value| i64::from(value)).collect();
84
+ let input_ids: Vec<i64> = encoding
85
+ .get_ids()
86
+ .iter()
87
+ .map(|&value| i64::from(value))
88
+ .collect();
70
89
  let attn_masks: Vec<i64> = encoding
71
90
  .get_attention_mask()
72
91
  .iter()
@@ -5,7 +5,8 @@ use gte::embedder::Embedder;
5
5
  fn test_e5_single_embedding_shape() {
6
6
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
7
7
 
8
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
8
+ let embedder =
9
+ Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
9
10
  let result = embedder
10
11
  .embed(vec!["query: Hello world".to_string()])
11
12
  .expect("embed should succeed");
@@ -19,7 +20,8 @@ fn test_e5_single_embedding_shape() {
19
20
  fn test_clip_single_embedding_shape() {
20
21
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
21
22
 
22
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
23
+ let embedder =
24
+ Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
23
25
  let result = embedder
24
26
  .embed(vec!["a photo of a cat".to_string()])
25
27
  .expect("embed should succeed");
@@ -33,7 +35,8 @@ fn test_clip_single_embedding_shape() {
33
35
  fn test_e5_batch_embedding_shape() {
34
36
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
35
37
 
36
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
38
+ let embedder =
39
+ Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
37
40
  let texts = vec![
38
41
  "query: first sentence".to_string(),
39
42
  "query: second sentence".to_string(),
@@ -51,7 +54,8 @@ fn test_e5_batch_embedding_shape() {
51
54
  fn test_e5_long_input_truncation_no_error() {
52
55
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
53
56
 
54
- let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
57
+ let embedder =
58
+ Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
55
59
  let very_long_text = "word ".repeat(1000);
56
60
  let result = embedder
57
61
  .embed(vec![very_long_text])
@@ -0,0 +1,17 @@
1
+ use gte::postprocess::sigmoid_scores;
2
+ use ndarray::array;
3
+
4
+ #[test]
5
+ fn test_sigmoid_scores_monotonic_and_bounded() {
6
+ let mut scores = array![-10.0f32, -1.0, 0.0, 1.0, 10.0];
7
+ sigmoid_scores(scores.view_mut());
8
+
9
+ assert!(scores[0] < scores[1]);
10
+ assert!(scores[1] < scores[2]);
11
+ assert!(scores[2] < scores[3]);
12
+ assert!(scores[3] < scores[4]);
13
+
14
+ for score in scores.iter() {
15
+ assert!((*score >= 0.0) && (*score <= 1.0));
16
+ }
17
+ }
@@ -40,5 +40,8 @@ fn test_e5_truncation_at_max_length() {
40
40
  .expect("tokenize should not error on long input");
41
41
 
42
42
  assert_eq!(tokenized.rows, 1);
43
- assert_eq!(tokenized.cols, 16, "sequence length should be truncated to max_length");
43
+ assert_eq!(
44
+ tokenized.cols, 16,
45
+ "sequence length should be truncated to max_length"
46
+ );
44
47
  }
data/lib/gte/config.rb ADDED
@@ -0,0 +1,15 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ module Config
5
+ Text = Data.define(
6
+ :model_dir, :threads, :optimization_level,
7
+ :model_name, :normalize, :output_tensor, :max_length
8
+ )
9
+
10
+ Reranker = Data.define(
11
+ :model_dir, :threads, :optimization_level,
12
+ :model_name, :sigmoid, :output_tensor, :max_length
13
+ )
14
+ end
15
+ end
data/lib/gte/model.rb ADDED
@@ -0,0 +1,35 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ class Model
5
+ attr_reader :config
6
+
7
+ def initialize(config)
8
+ raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)
9
+
10
+ @config = config
11
+ @embedder = GTE::Embedder.new(
12
+ config.model_dir,
13
+ config.threads,
14
+ config.optimization_level,
15
+ config.model_name.to_s,
16
+ config.normalize,
17
+ config.output_tensor.to_s,
18
+ config.max_length || 0
19
+ )
20
+ end
21
+
22
+ def embed(texts)
23
+ return @embedder.embed_one(texts) if texts.is_a?(String)
24
+
25
+ @embedder.embed(Array(texts))
26
+ end
27
+
28
+ def [](input)
29
+ case input
30
+ when String then embed(input).row(0)
31
+ when Array then embed(input)
32
+ end
33
+ end
34
+ end
35
+ end
@@ -0,0 +1,54 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ class Reranker
5
+ class << self
6
+ def config(model_dir)
7
+ cfg = default_config(model_dir)
8
+
9
+ if block_given?
10
+ yielded = yield(cfg)
11
+ cfg = yielded if yielded.is_a?(Config::Reranker)
12
+ end
13
+
14
+ build(cfg)
15
+ end
16
+
17
+ private
18
+
19
+ def default_config(model_dir)
20
+ Config::Reranker.new(
21
+ model_dir: File.expand_path(model_dir),
22
+ threads: 3,
23
+ optimization_level: 3,
24
+ model_name: nil,
25
+ sigmoid: false,
26
+ output_tensor: nil,
27
+ max_length: nil
28
+ )
29
+ end
30
+
31
+ def build(cfg)
32
+ new(
33
+ cfg.model_dir,
34
+ cfg.threads,
35
+ cfg.optimization_level,
36
+ cfg.model_name.to_s,
37
+ cfg.sigmoid,
38
+ cfg.output_tensor.to_s,
39
+ cfg.max_length || 0
40
+ )
41
+ end
42
+ end
43
+
44
+ def rerank(query:, candidates:)
45
+ rows = Array(candidates).map(&:to_s)
46
+ scores = score(query.to_s, rows)
47
+
48
+ rows
49
+ .each_with_index
50
+ .map { |text, idx| { index: idx, score: scores[idx], text: text } }
51
+ .sort_by { |row| -row[:score] }
52
+ end
53
+ end
54
+ end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
+ module GTE
4
+ VERSION = File.read(File.expand_path('../../VERSION', __dir__)).strip
5
+ end
data/lib/gte.rb CHANGED
@@ -1,36 +1,44 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'gte/version'
4
+
3
5
  begin
4
6
  require "gte/#{RUBY_VERSION.to_f}/gte"
5
7
  rescue LoadError
6
8
  require 'gte/gte'
7
9
  end
8
10
 
11
+ require 'gte/config'
12
+ require 'gte/model'
13
+ require 'gte/reranker'
14
+
9
15
  module GTE
10
- VERSION = File.read(File.expand_path('../VERSION', __dir__)).strip
16
+ @model_cache_mutex = Mutex.new
17
+ @model_cache = {}
11
18
 
12
- class Model
13
- def initialize(dir, num_threads: 0, optimization_level: 3, model_name: nil)
14
- @embedder = GTE::Embedder.new(dir, num_threads, optimization_level, model_name.to_s)
15
- end
19
+ class << self
20
+ def config(model_dir)
21
+ cfg = Config::Text.new(
22
+ model_dir: File.expand_path(model_dir),
23
+ threads: 3,
24
+ optimization_level: 3,
25
+ model_name: nil,
26
+ normalize: true,
27
+ output_tensor: nil,
28
+ max_length: nil
29
+ )
16
30
 
17
- def embed(texts)
18
- if texts.is_a?(String)
19
- @embedder.embed_one(texts)
20
- else
21
- @embedder.embed(Array(texts))
22
- end
23
- end
31
+ cfg = yield(cfg) if block_given?
24
32
 
25
- def [](input)
26
- case input
27
- when String then embed(input).row(0)
28
- when Array then embed(input)
33
+ @model_cache_mutex.synchronize do
34
+ @model_cache[cache_key(cfg)] ||= Model.new(cfg)
29
35
  end
30
36
  end
31
- end
32
37
 
33
- def self.new(dir, num_threads: 0, optimization_level: 3, model_name: nil)
34
- Model.new(dir, num_threads: num_threads, optimization_level: optimization_level, model_name: model_name)
38
+ private
39
+
40
+ def cache_key(cfg)
41
+ cfg.to_h
42
+ end
35
43
  end
36
44
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.3
4
+ version: 0.0.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2026-04-10 00:00:00.000000000 Z
11
+ date: 2026-04-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rake
@@ -101,14 +101,22 @@ files:
101
101
  - ext/gte/src/error.rs
102
102
  - ext/gte/src/lib.rs
103
103
  - ext/gte/src/model_config.rs
104
+ - ext/gte/src/model_profile.rs
105
+ - ext/gte/src/pipeline.rs
104
106
  - ext/gte/src/postprocess.rs
107
+ - ext/gte/src/reranker.rs
105
108
  - ext/gte/src/ruby_embedder.rs
106
109
  - ext/gte/src/session.rs
107
110
  - ext/gte/src/tokenizer.rs
108
111
  - ext/gte/tests/embedder_unit_test.rs
109
112
  - ext/gte/tests/inference_integration_test.rs
113
+ - ext/gte/tests/postprocess_unit_test.rs
110
114
  - ext/gte/tests/tokenizer_unit_test.rs
111
115
  - lib/gte.rb
116
+ - lib/gte/config.rb
117
+ - lib/gte/model.rb
118
+ - lib/gte/reranker.rb
119
+ - lib/gte/version.rb
112
120
  homepage: https://github.com/elcuervo/gte
113
121
  licenses:
114
122
  - MIT