gte 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +122 -10
- data/Rakefile +8 -0
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +34 -268
- data/ext/gte/src/lib.rs +3 -0
- data/ext/gte/src/model_profile.rs +179 -0
- data/ext/gte/src/pipeline.rs +60 -0
- data/ext/gte/src/postprocess.rs +25 -2
- data/ext/gte/src/reranker.rs +120 -0
- data/ext/gte/src/ruby_embedder.rs +165 -7
- data/ext/gte/src/session.rs +9 -39
- data/ext/gte/src/tokenizer.rs +21 -2
- data/ext/gte/tests/inference_integration_test.rs +8 -4
- data/ext/gte/tests/postprocess_unit_test.rs +17 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +4 -1
- data/lib/gte/config.rb +15 -0
- data/lib/gte/model.rb +35 -0
- data/lib/gte/reranker.rb +54 -0
- data/lib/gte/version.rb +5 -0
- data/lib/gte.rb +27 -19
- metadata +10 -2
data/ext/gte/src/session.rs
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
3
|
+
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
3
4
|
use crate::postprocess::mean_pool;
|
|
4
5
|
use crate::tokenizer::Tokenized;
|
|
5
|
-
use ndarray::{Array2,
|
|
6
|
+
use ndarray::{Array2, Ix2};
|
|
6
7
|
use ort::execution_providers::{
|
|
7
8
|
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
8
9
|
};
|
|
9
10
|
use ort::session::Session;
|
|
10
|
-
use ort::session::SessionInputValue;
|
|
11
|
-
use ort::value::Value;
|
|
12
11
|
use std::path::Path;
|
|
13
12
|
|
|
14
13
|
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
@@ -43,7 +42,9 @@ fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
|
|
|
43
42
|
let mut providers = Vec::new();
|
|
44
43
|
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
45
44
|
match provider {
|
|
46
|
-
"xnnpack" =>
|
|
45
|
+
"xnnpack" => {
|
|
46
|
+
providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
|
|
47
|
+
}
|
|
47
48
|
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
48
49
|
"none" => {}
|
|
49
50
|
_ => {}
|
|
@@ -57,40 +58,9 @@ pub fn run_session(
|
|
|
57
58
|
tokenized: &Tokenized,
|
|
58
59
|
config: &ModelConfig,
|
|
59
60
|
) -> Result<Array2<f32>> {
|
|
60
|
-
let
|
|
61
|
-
|
|
62
|
-
let
|
|
63
|
-
ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
|
|
64
|
-
|
|
65
|
-
let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
|
|
66
|
-
inputs.push((
|
|
67
|
-
"input_ids",
|
|
68
|
-
SessionInputValue::from(Value::from_array(input_ids_view)?),
|
|
69
|
-
));
|
|
70
|
-
if config.with_attention_mask {
|
|
71
|
-
inputs.push((
|
|
72
|
-
"attention_mask",
|
|
73
|
-
SessionInputValue::from(Value::from_array(attn_masks_view)?),
|
|
74
|
-
));
|
|
75
|
-
}
|
|
76
|
-
if let Some(type_ids) = tokenized.type_ids.as_deref() {
|
|
77
|
-
let type_ids_view: ArrayView2<'_, i64> =
|
|
78
|
-
ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
|
|
79
|
-
inputs.push((
|
|
80
|
-
"token_type_ids",
|
|
81
|
-
SessionInputValue::from(Value::from_array(type_ids_view)?),
|
|
82
|
-
));
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
let outputs = session.run(inputs)?;
|
|
86
|
-
let tensor_value = outputs.get(config.output_tensor.as_str()).ok_or_else(|| {
|
|
87
|
-
GteError::Inference(format!(
|
|
88
|
-
"output tensor '{}' not found in model outputs",
|
|
89
|
-
&config.output_tensor
|
|
90
|
-
))
|
|
91
|
-
})?;
|
|
92
|
-
|
|
93
|
-
let array = tensor_value.try_extract_tensor::<f32>()?;
|
|
61
|
+
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
62
|
+
let outputs = session.run(input_tensors.inputs)?;
|
|
63
|
+
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
94
64
|
|
|
95
65
|
match config.mode {
|
|
96
66
|
ExtractorMode::Token(idx) => {
|
|
@@ -111,7 +81,7 @@ pub fn run_session(
|
|
|
111
81
|
ndim
|
|
112
82
|
))
|
|
113
83
|
})?;
|
|
114
|
-
mean_pool(hidden_states,
|
|
84
|
+
mean_pool(hidden_states.view(), input_tensors.attention_mask)
|
|
115
85
|
}
|
|
116
86
|
ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
|
|
117
87
|
}
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -61,12 +61,31 @@ impl Tokenizer {
|
|
|
61
61
|
|
|
62
62
|
build_tokenized(&encodings, self.with_type_ids)
|
|
63
63
|
}
|
|
64
|
+
|
|
65
|
+
pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
|
|
66
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = pairs
|
|
67
|
+
.iter()
|
|
68
|
+
.map(|(left, right)| (left.as_str(), right.as_str()).into())
|
|
69
|
+
.collect();
|
|
70
|
+
let encodings = self
|
|
71
|
+
.tokenizer
|
|
72
|
+
.encode_batch_fast(encode_inputs, true)
|
|
73
|
+
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
74
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
75
|
+
}
|
|
64
76
|
}
|
|
65
77
|
|
|
66
|
-
fn build_tokenized_single(
|
|
78
|
+
fn build_tokenized_single(
|
|
79
|
+
encoding: &tokenizers::Encoding,
|
|
80
|
+
with_type_ids: bool,
|
|
81
|
+
) -> Result<Tokenized> {
|
|
67
82
|
let cols = encoding.len();
|
|
68
83
|
|
|
69
|
-
let input_ids: Vec<i64> = encoding
|
|
84
|
+
let input_ids: Vec<i64> = encoding
|
|
85
|
+
.get_ids()
|
|
86
|
+
.iter()
|
|
87
|
+
.map(|&value| i64::from(value))
|
|
88
|
+
.collect();
|
|
70
89
|
let attn_masks: Vec<i64> = encoding
|
|
71
90
|
.get_attention_mask()
|
|
72
91
|
.iter()
|
|
@@ -5,7 +5,8 @@ use gte::embedder::Embedder;
|
|
|
5
5
|
fn test_e5_single_embedding_shape() {
|
|
6
6
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
7
|
|
|
8
|
-
let embedder =
|
|
8
|
+
let embedder =
|
|
9
|
+
Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
|
|
9
10
|
let result = embedder
|
|
10
11
|
.embed(vec!["query: Hello world".to_string()])
|
|
11
12
|
.expect("embed should succeed");
|
|
@@ -19,7 +20,8 @@ fn test_e5_single_embedding_shape() {
|
|
|
19
20
|
fn test_clip_single_embedding_shape() {
|
|
20
21
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
21
22
|
|
|
22
|
-
let embedder =
|
|
23
|
+
let embedder =
|
|
24
|
+
Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
|
|
23
25
|
let result = embedder
|
|
24
26
|
.embed(vec!["a photo of a cat".to_string()])
|
|
25
27
|
.expect("embed should succeed");
|
|
@@ -33,7 +35,8 @@ fn test_clip_single_embedding_shape() {
|
|
|
33
35
|
fn test_e5_batch_embedding_shape() {
|
|
34
36
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
35
37
|
|
|
36
|
-
let embedder =
|
|
38
|
+
let embedder =
|
|
39
|
+
Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
|
|
37
40
|
let texts = vec![
|
|
38
41
|
"query: first sentence".to_string(),
|
|
39
42
|
"query: second sentence".to_string(),
|
|
@@ -51,7 +54,8 @@ fn test_e5_batch_embedding_shape() {
|
|
|
51
54
|
fn test_e5_long_input_truncation_no_error() {
|
|
52
55
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
53
56
|
|
|
54
|
-
let embedder =
|
|
57
|
+
let embedder =
|
|
58
|
+
Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
|
|
55
59
|
let very_long_text = "word ".repeat(1000);
|
|
56
60
|
let result = embedder
|
|
57
61
|
.embed(vec![very_long_text])
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
use gte::postprocess::sigmoid_scores;
|
|
2
|
+
use ndarray::array;
|
|
3
|
+
|
|
4
|
+
#[test]
|
|
5
|
+
fn test_sigmoid_scores_monotonic_and_bounded() {
|
|
6
|
+
let mut scores = array![-10.0f32, -1.0, 0.0, 1.0, 10.0];
|
|
7
|
+
sigmoid_scores(scores.view_mut());
|
|
8
|
+
|
|
9
|
+
assert!(scores[0] < scores[1]);
|
|
10
|
+
assert!(scores[1] < scores[2]);
|
|
11
|
+
assert!(scores[2] < scores[3]);
|
|
12
|
+
assert!(scores[3] < scores[4]);
|
|
13
|
+
|
|
14
|
+
for score in scores.iter() {
|
|
15
|
+
assert!((*score >= 0.0) && (*score <= 1.0));
|
|
16
|
+
}
|
|
17
|
+
}
|
|
@@ -40,5 +40,8 @@ fn test_e5_truncation_at_max_length() {
|
|
|
40
40
|
.expect("tokenize should not error on long input");
|
|
41
41
|
|
|
42
42
|
assert_eq!(tokenized.rows, 1);
|
|
43
|
-
assert_eq!(
|
|
43
|
+
assert_eq!(
|
|
44
|
+
tokenized.cols, 16,
|
|
45
|
+
"sequence length should be truncated to max_length"
|
|
46
|
+
);
|
|
44
47
|
}
|
data/lib/gte/config.rb
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GTE
|
|
4
|
+
module Config
|
|
5
|
+
Text = Data.define(
|
|
6
|
+
:model_dir, :threads, :optimization_level,
|
|
7
|
+
:model_name, :normalize, :output_tensor, :max_length
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
Reranker = Data.define(
|
|
11
|
+
:model_dir, :threads, :optimization_level,
|
|
12
|
+
:model_name, :sigmoid, :output_tensor, :max_length
|
|
13
|
+
)
|
|
14
|
+
end
|
|
15
|
+
end
|
data/lib/gte/model.rb
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GTE
|
|
4
|
+
class Model
|
|
5
|
+
attr_reader :config
|
|
6
|
+
|
|
7
|
+
def initialize(config)
|
|
8
|
+
raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)
|
|
9
|
+
|
|
10
|
+
@config = config
|
|
11
|
+
@embedder = GTE::Embedder.new(
|
|
12
|
+
config.model_dir,
|
|
13
|
+
config.threads,
|
|
14
|
+
config.optimization_level,
|
|
15
|
+
config.model_name.to_s,
|
|
16
|
+
config.normalize,
|
|
17
|
+
config.output_tensor.to_s,
|
|
18
|
+
config.max_length || 0
|
|
19
|
+
)
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
def embed(texts)
|
|
23
|
+
return @embedder.embed_one(texts) if texts.is_a?(String)
|
|
24
|
+
|
|
25
|
+
@embedder.embed(Array(texts))
|
|
26
|
+
end
|
|
27
|
+
|
|
28
|
+
def [](input)
|
|
29
|
+
case input
|
|
30
|
+
when String then embed(input).row(0)
|
|
31
|
+
when Array then embed(input)
|
|
32
|
+
end
|
|
33
|
+
end
|
|
34
|
+
end
|
|
35
|
+
end
|
data/lib/gte/reranker.rb
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GTE
|
|
4
|
+
class Reranker
|
|
5
|
+
class << self
|
|
6
|
+
def config(model_dir)
|
|
7
|
+
cfg = default_config(model_dir)
|
|
8
|
+
|
|
9
|
+
if block_given?
|
|
10
|
+
yielded = yield(cfg)
|
|
11
|
+
cfg = yielded if yielded.is_a?(Config::Reranker)
|
|
12
|
+
end
|
|
13
|
+
|
|
14
|
+
build(cfg)
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
private
|
|
18
|
+
|
|
19
|
+
def default_config(model_dir)
|
|
20
|
+
Config::Reranker.new(
|
|
21
|
+
model_dir: File.expand_path(model_dir),
|
|
22
|
+
threads: 3,
|
|
23
|
+
optimization_level: 3,
|
|
24
|
+
model_name: nil,
|
|
25
|
+
sigmoid: false,
|
|
26
|
+
output_tensor: nil,
|
|
27
|
+
max_length: nil
|
|
28
|
+
)
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
def build(cfg)
|
|
32
|
+
new(
|
|
33
|
+
cfg.model_dir,
|
|
34
|
+
cfg.threads,
|
|
35
|
+
cfg.optimization_level,
|
|
36
|
+
cfg.model_name.to_s,
|
|
37
|
+
cfg.sigmoid,
|
|
38
|
+
cfg.output_tensor.to_s,
|
|
39
|
+
cfg.max_length || 0
|
|
40
|
+
)
|
|
41
|
+
end
|
|
42
|
+
end
|
|
43
|
+
|
|
44
|
+
def rerank(query:, candidates:)
|
|
45
|
+
rows = Array(candidates).map(&:to_s)
|
|
46
|
+
scores = score(query.to_s, rows)
|
|
47
|
+
|
|
48
|
+
rows
|
|
49
|
+
.each_with_index
|
|
50
|
+
.map { |text, idx| { index: idx, score: scores[idx], text: text } }
|
|
51
|
+
.sort_by { |row| -row[:score] }
|
|
52
|
+
end
|
|
53
|
+
end
|
|
54
|
+
end
|
data/lib/gte/version.rb
ADDED
data/lib/gte.rb
CHANGED
|
@@ -1,36 +1,44 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
+
require 'gte/version'
|
|
4
|
+
|
|
3
5
|
begin
|
|
4
6
|
require "gte/#{RUBY_VERSION.to_f}/gte"
|
|
5
7
|
rescue LoadError
|
|
6
8
|
require 'gte/gte'
|
|
7
9
|
end
|
|
8
10
|
|
|
11
|
+
require 'gte/config'
|
|
12
|
+
require 'gte/model'
|
|
13
|
+
require 'gte/reranker'
|
|
14
|
+
|
|
9
15
|
module GTE
|
|
10
|
-
|
|
16
|
+
@model_cache_mutex = Mutex.new
|
|
17
|
+
@model_cache = {}
|
|
11
18
|
|
|
12
|
-
class
|
|
13
|
-
def
|
|
14
|
-
|
|
15
|
-
|
|
19
|
+
class << self
|
|
20
|
+
def config(model_dir)
|
|
21
|
+
cfg = Config::Text.new(
|
|
22
|
+
model_dir: File.expand_path(model_dir),
|
|
23
|
+
threads: 3,
|
|
24
|
+
optimization_level: 3,
|
|
25
|
+
model_name: nil,
|
|
26
|
+
normalize: true,
|
|
27
|
+
output_tensor: nil,
|
|
28
|
+
max_length: nil
|
|
29
|
+
)
|
|
16
30
|
|
|
17
|
-
|
|
18
|
-
if texts.is_a?(String)
|
|
19
|
-
@embedder.embed_one(texts)
|
|
20
|
-
else
|
|
21
|
-
@embedder.embed(Array(texts))
|
|
22
|
-
end
|
|
23
|
-
end
|
|
31
|
+
cfg = yield(cfg) if block_given?
|
|
24
32
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
when String then embed(input).row(0)
|
|
28
|
-
when Array then embed(input)
|
|
33
|
+
@model_cache_mutex.synchronize do
|
|
34
|
+
@model_cache[cache_key(cfg)] ||= Model.new(cfg)
|
|
29
35
|
end
|
|
30
36
|
end
|
|
31
|
-
end
|
|
32
37
|
|
|
33
|
-
|
|
34
|
-
|
|
38
|
+
private
|
|
39
|
+
|
|
40
|
+
def cache_key(cfg)
|
|
41
|
+
cfg.to_h
|
|
42
|
+
end
|
|
35
43
|
end
|
|
36
44
|
end
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.
|
|
4
|
+
version: 0.0.5
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2026-04-
|
|
11
|
+
date: 2026-04-15 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: rake
|
|
@@ -101,14 +101,22 @@ files:
|
|
|
101
101
|
- ext/gte/src/error.rs
|
|
102
102
|
- ext/gte/src/lib.rs
|
|
103
103
|
- ext/gte/src/model_config.rs
|
|
104
|
+
- ext/gte/src/model_profile.rs
|
|
105
|
+
- ext/gte/src/pipeline.rs
|
|
104
106
|
- ext/gte/src/postprocess.rs
|
|
107
|
+
- ext/gte/src/reranker.rs
|
|
105
108
|
- ext/gte/src/ruby_embedder.rs
|
|
106
109
|
- ext/gte/src/session.rs
|
|
107
110
|
- ext/gte/src/tokenizer.rs
|
|
108
111
|
- ext/gte/tests/embedder_unit_test.rs
|
|
109
112
|
- ext/gte/tests/inference_integration_test.rs
|
|
113
|
+
- ext/gte/tests/postprocess_unit_test.rs
|
|
110
114
|
- ext/gte/tests/tokenizer_unit_test.rs
|
|
111
115
|
- lib/gte.rb
|
|
116
|
+
- lib/gte/config.rb
|
|
117
|
+
- lib/gte/model.rb
|
|
118
|
+
- lib/gte/reranker.rb
|
|
119
|
+
- lib/gte/version.rb
|
|
112
120
|
homepage: https://github.com/elcuervo/gte
|
|
113
121
|
licenses:
|
|
114
122
|
- MIT
|