gte 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc149108c647dc5b14154bfbdc4975b53670b9ed3cf7d80760cc2b415c935a48
4
- data.tar.gz: 32a682a95d56c8fab8d0d64a7ada0c0347ae796b6aefe6191f9aca8fc96426c2
3
+ metadata.gz: 29659e3ab6072d858b1710a779c3d2e5981f7749782182d141ccd5e9790a1fbb
4
+ data.tar.gz: c42d51cfa1a2ba6a2e83249e8a725c978b11c7ef80c6d69f09a64e884be42031
5
5
  SHA512:
6
- metadata.gz: f5c69d954f51a51521b143b576942a9c0505ad60574c1727f963dd79e0b6c22cacc4e6d9af75394ae06f451521dbc788af51f1e79397a5cc66a41b4ce1b31933
7
- data.tar.gz: 9e75fdbc9b5c8cfdd9d0e377a7e4a944057ec604e38ab23d960c4ed75ec6a72ce1dd27c2dd1bb2802721387babdabe0996e0c42be34d17d98253e0582b375de1
6
+ metadata.gz: ff2c2b1450a6e82c07aacd2ec98437f03678d56eef9c5516f904021a54f59b2ba5c42b8f6af22b5c4b2dacea98615b99bc54d2c7cdc4e8fbccc1abc195fe9975
7
+ data.tar.gz: 04ca056458d40e2ba7fabcdbcab415a087d54802fb3bd86748dc901c2cf0ecb44072fd1820a73e3dcaca097f165df3e70bab747b38340cd738876af5f0ea7645
data/README.md CHANGED
@@ -41,6 +41,7 @@ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
41
41
  config.with(
42
42
  output_tensor: "last_hidden_state",
43
43
  max_length: 256,
44
+ padding: "batch_longest",
44
45
  optimization_level: 3
45
46
  )
46
47
  end
@@ -55,6 +56,7 @@ Config fields and defaults:
55
56
  - `normalize`: `true` (L2 normalization at Ruby-facing API)
56
57
  - `output_tensor`: `nil` (auto-select output tensor)
57
58
  - `max_length`: `nil` (uses tokenizer/model defaults)
59
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
58
60
  - `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
59
61
 
60
62
  Notes:
@@ -106,6 +108,7 @@ Reranker config fields and defaults:
106
108
  - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
107
109
  - `output_tensor`: `nil`
108
110
  - `max_length`: `nil`
111
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
109
112
  - `execution_providers`: `nil`
110
113
 
111
114
  ## Runtime + Result Examples
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.6
1
+ 0.0.7
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.6"
3
+ version = "0.0.7"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -1,12 +1,12 @@
1
1
  use crate::error::{GteError, Result};
2
- use crate::model_config::{ExtractorMode, ModelConfig};
2
+ use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
3
3
  use crate::model_profile::{
4
- has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
5
- resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
4
+ has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
5
+ resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
8
8
  use crate::session::{build_session, run_session};
9
- use crate::tokenizer::{Tokenized, Tokenizer};
9
+ use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
10
10
  use ndarray::Array2;
11
11
  use ort::session::Session;
12
12
  use std::path::Path;
@@ -23,7 +23,13 @@ impl Embedder {
23
23
  P1: AsRef<Path>,
24
24
  P2: AsRef<Path>,
25
25
  {
26
- let tokenizer = Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids)?;
26
+ let tokenizer = Tokenizer::new(
27
+ tokenizer_path,
28
+ config.max_length,
29
+ config.with_type_ids,
30
+ config.padding_mode,
31
+ None,
32
+ )?;
27
33
  let session = build_session(model_path, &config)?;
28
34
  Ok(Self {
29
35
  tokenizer,
@@ -36,10 +42,7 @@ impl Embedder {
36
42
  dir: P,
37
43
  num_threads: usize,
38
44
  optimization_level: u8,
39
- model_name: Option<&str>,
40
- output_tensor_override: Option<&str>,
41
- max_length_override: Option<usize>,
42
- execution_providers_override: Option<&str>,
45
+ overrides: ModelLoadOverrides<'_>,
43
46
  ) -> Result<Self> {
44
47
  const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
45
48
  "pooler_output",
@@ -50,31 +53,35 @@ impl Embedder {
50
53
 
51
54
  let dir = dir.as_ref();
52
55
  let tokenizer_path = resolve_tokenizer_path(dir)?;
53
- let model_path = match model_name.filter(|s| !s.is_empty()) {
56
+ let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
54
57
  Some(name) => resolve_named_model(dir, name)?,
55
58
  None => resolve_default_text_model(dir)?,
56
59
  };
57
60
 
58
- let max_length = if let Some(override_value) = max_length_override {
61
+ let tokenizer_profile = read_tokenizer_profile(dir);
62
+ let max_length = if let Some(override_value) = overrides.max_length {
59
63
  if override_value == 0 {
60
64
  return Err(GteError::Inference(
61
65
  "max_length override must be greater than 0".to_string(),
62
66
  ));
63
67
  }
64
- override_value
68
+ override_value.min(tokenizer_profile.safe_max_length)
65
69
  } else {
66
- read_max_length(dir)
70
+ tokenizer_profile.default_max_length
67
71
  };
72
+ let padding_mode =
73
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
68
74
 
69
75
  let session_config = ModelConfig {
70
76
  max_length,
77
+ padding_mode,
71
78
  output_tensor: String::new(),
72
79
  mode: ExtractorMode::Raw,
73
80
  with_type_ids: false,
74
81
  with_attention_mask: true,
75
82
  num_threads,
76
83
  optimization_level,
77
- execution_providers: execution_providers_override.map(str::to_string),
84
+ execution_providers: overrides.execution_providers.map(str::to_string),
78
85
  };
79
86
  let session = build_session(&model_path, &session_config)?;
80
87
 
@@ -82,7 +89,7 @@ impl Embedder {
82
89
  let with_type_ids = has_input(&session, "token_type_ids");
83
90
  let with_attention_mask = has_input(&session, "attention_mask");
84
91
  let output_tensor =
85
- select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
92
+ select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
86
93
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
87
94
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
88
95
  return Err(GteError::Inference(
@@ -92,16 +99,23 @@ impl Embedder {
92
99
 
93
100
  let config = ModelConfig {
94
101
  max_length,
102
+ padding_mode,
95
103
  output_tensor,
96
104
  mode,
97
105
  with_type_ids,
98
106
  with_attention_mask,
99
107
  num_threads,
100
108
  optimization_level,
101
- execution_providers: execution_providers_override.map(str::to_string),
109
+ execution_providers: overrides.execution_providers.map(str::to_string),
102
110
  };
103
111
 
104
- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
112
+ let tokenizer = Tokenizer::new(
113
+ &tokenizer_path,
114
+ config.max_length,
115
+ config.with_type_ids,
116
+ config.padding_mode,
117
+ tokenizer_profile.fixed_padding_length,
118
+ )?;
105
119
 
106
120
  Ok(Self {
107
121
  tokenizer,
@@ -5,9 +5,18 @@ pub enum ExtractorMode {
5
5
  Raw,
6
6
  }
7
7
 
8
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
9
+ pub enum PaddingMode {
10
+ #[default]
11
+ Auto,
12
+ BatchLongest,
13
+ Fixed,
14
+ }
15
+
8
16
  #[derive(Debug, Clone)]
9
17
  pub struct ModelConfig {
10
18
  pub max_length: usize,
19
+ pub padding_mode: PaddingMode,
11
20
  pub output_tensor: String,
12
21
  pub mode: ExtractorMode,
13
22
  pub with_type_ids: bool,
@@ -16,3 +25,12 @@ pub struct ModelConfig {
16
25
  pub optimization_level: u8,
17
26
  pub execution_providers: Option<String>,
18
27
  }
28
+
29
+ #[derive(Debug, Clone, Copy, Default)]
30
+ pub struct ModelLoadOverrides<'a> {
31
+ pub model_name: Option<&'a str>,
32
+ pub output_tensor: Option<&'a str>,
33
+ pub max_length: Option<usize>,
34
+ pub padding: Option<&'a str>,
35
+ pub execution_providers: Option<&'a str>,
36
+ }
@@ -1,9 +1,19 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::ExtractorMode;
3
3
  use ort::session::Session;
4
+ use serde_json::Value;
4
5
  use std::path::{Path, PathBuf};
5
6
 
6
7
  const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
8
+ const DEFAULT_MAX_LENGTH: usize = 512;
9
+ const MAX_SUPPORTED_LENGTH: usize = 8192;
10
+
11
+ #[derive(Debug, Clone, Copy)]
12
+ pub struct TokenizerProfile {
13
+ pub default_max_length: usize,
14
+ pub safe_max_length: usize,
15
+ pub fixed_padding_length: Option<usize>,
16
+ }
7
17
 
8
18
  pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
9
19
  let tokenizer_path = dir.join("tokenizer.json");
@@ -48,19 +58,78 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
48
58
  )))
49
59
  }
50
60
 
51
- pub fn read_max_length(dir: &Path) -> usize {
52
- (|| -> Option<usize> {
53
- let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
54
- let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
55
- let v = json.get("model_max_length")?;
56
- let n = v.as_u64().or_else(|| {
57
- v.as_f64()
58
- .filter(|&f| f > 0.0 && f < 1e15)
59
- .map(|f| f as u64)
60
- })?;
61
- Some((n as usize).min(8192))
62
- })()
63
- .unwrap_or(512)
61
+ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
62
+ let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
63
+ let tokenizer_json = read_json(dir.join("tokenizer.json"));
64
+
65
+ let fixed_padding_length = tokenizer_json
66
+ .as_ref()
67
+ .and_then(parse_fixed_padding_length_from_tokenizer_json);
68
+
69
+ let mut candidates = Vec::new();
70
+ if let Some(config) = tokenizer_config.as_ref() {
71
+ if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
72
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
73
+ }
74
+ if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
75
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
76
+ }
77
+ }
78
+
79
+ if let Some(tokenizer) = tokenizer_json.as_ref() {
80
+ if let Some(v) = tokenizer
81
+ .get("truncation")
82
+ .and_then(|truncation| truncation.get("max_length"))
83
+ .and_then(parse_positive_usize)
84
+ {
85
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
86
+ }
87
+ }
88
+
89
+ if let Some(v) = fixed_padding_length {
90
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
91
+ }
92
+
93
+ let default_max_length = candidates
94
+ .iter()
95
+ .copied()
96
+ .min()
97
+ .unwrap_or(DEFAULT_MAX_LENGTH)
98
+ .max(1);
99
+ let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
100
+
101
+ TokenizerProfile {
102
+ default_max_length,
103
+ safe_max_length,
104
+ fixed_padding_length,
105
+ }
106
+ }
107
+
108
+ fn read_json(path: PathBuf) -> Option<Value> {
109
+ let contents = std::fs::read_to_string(path).ok()?;
110
+ serde_json::from_str(&contents).ok()
111
+ }
112
+
113
+ fn parse_positive_usize(value: &Value) -> Option<usize> {
114
+ let raw = value
115
+ .as_u64()
116
+ .or_else(|| {
117
+ value
118
+ .as_f64()
119
+ .filter(|&v| v.is_finite() && v > 0.0)
120
+ .map(|v| v as u64)
121
+ })
122
+ .or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
123
+ let parsed = usize::try_from(raw).ok()?;
124
+ (parsed > 0).then_some(parsed)
125
+ }
126
+
127
+ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
128
+ tokenizer_json
129
+ .get("padding")
130
+ .and_then(|padding| padding.get("strategy"))
131
+ .and_then(|strategy| strategy.get("Fixed"))
132
+ .and_then(parse_positive_usize)
64
133
  }
65
134
 
66
135
  pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
@@ -177,3 +246,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
177
246
  ))),
178
247
  }
179
248
  }
249
+
250
+ #[cfg(test)]
251
+ mod tests {
252
+ use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
253
+ use serde_json::json;
254
+
255
+ #[test]
256
+ fn parse_positive_usize_handles_integer_float_and_string() {
257
+ assert_eq!(parse_positive_usize(&json!(64)), Some(64));
258
+ assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
259
+ assert_eq!(parse_positive_usize(&json!("64")), Some(64));
260
+ assert_eq!(parse_positive_usize(&json!(0)), None);
261
+ }
262
+
263
+ #[test]
264
+ fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
265
+ let tokenizer_json = json!({
266
+ "padding": {
267
+ "strategy": {
268
+ "Fixed": 64
269
+ }
270
+ }
271
+ });
272
+ assert_eq!(
273
+ parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
274
+ Some(64)
275
+ );
276
+ }
277
+ }
@@ -1,19 +1,20 @@
1
1
  use crate::error::{GteError, Result};
2
+ use crate::model_config::{ModelLoadOverrides, PaddingMode};
2
3
  use crate::model_profile::{
3
- has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
4
- select_output_tensor, validate_supported_text_inputs,
4
+ has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
5
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
5
6
  };
6
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
7
8
  use crate::postprocess::sigmoid_scores;
8
9
  use crate::session::build_session;
9
- use crate::tokenizer::Tokenizer;
10
- use ndarray::Array1;
10
+ use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
11
11
  use ort::session::Session;
12
12
  use std::path::Path;
13
13
 
14
14
  #[derive(Debug, Clone)]
15
15
  struct RerankerConfig {
16
16
  max_length: usize,
17
+ padding_mode: PaddingMode,
17
18
  output_tensor: String,
18
19
  with_type_ids: bool,
19
20
  with_attention_mask: bool,
@@ -30,54 +31,62 @@ impl Reranker {
30
31
  dir: P,
31
32
  num_threads: usize,
32
33
  optimization_level: u8,
33
- model_name: Option<&str>,
34
- output_tensor_override: Option<&str>,
35
- max_length_override: Option<usize>,
36
- execution_providers_override: Option<&str>,
34
+ overrides: ModelLoadOverrides<'_>,
37
35
  ) -> Result<Self> {
38
36
  let dir = dir.as_ref();
39
37
  let tokenizer_path = resolve_tokenizer_path(dir)?;
40
- let model_path = match model_name.filter(|s| !s.is_empty()) {
38
+ let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
41
39
  Some(name) => resolve_named_model(dir, name)?,
42
40
  None => resolve_default_text_model(dir)?,
43
41
  };
44
42
 
45
- let max_length = if let Some(override_value) = max_length_override {
43
+ let tokenizer_profile = read_tokenizer_profile(dir);
44
+ let max_length = if let Some(override_value) = overrides.max_length {
46
45
  if override_value == 0 {
47
46
  return Err(GteError::Inference(
48
47
  "max_length override must be greater than 0".to_string(),
49
48
  ));
50
49
  }
51
- override_value
50
+ override_value.min(tokenizer_profile.safe_max_length)
52
51
  } else {
53
- read_max_length(dir)
52
+ tokenizer_profile.default_max_length
54
53
  };
54
+ let padding_mode =
55
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
55
56
 
56
57
  let probe_config = crate::model_config::ModelConfig {
57
58
  max_length,
59
+ padding_mode,
58
60
  output_tensor: String::new(),
59
61
  mode: crate::model_config::ExtractorMode::Raw,
60
62
  with_type_ids: false,
61
63
  with_attention_mask: true,
62
64
  num_threads,
63
65
  optimization_level,
64
- execution_providers: execution_providers_override.map(str::to_string),
66
+ execution_providers: overrides.execution_providers.map(str::to_string),
65
67
  };
66
68
  let session = build_session(&model_path, &probe_config)?;
67
69
 
68
70
  validate_supported_text_inputs(&session, "text reranking")?;
69
71
  let with_type_ids = has_input(&session, "token_type_ids");
70
72
  let with_attention_mask = has_input(&session, "attention_mask");
71
- let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
73
+ let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
72
74
 
73
75
  let config = RerankerConfig {
74
76
  max_length,
77
+ padding_mode,
75
78
  output_tensor,
76
79
  with_type_ids,
77
80
  with_attention_mask,
78
81
  };
79
82
 
80
- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
83
+ let tokenizer = Tokenizer::new(
84
+ &tokenizer_path,
85
+ config.max_length,
86
+ config.with_type_ids,
87
+ config.padding_mode,
88
+ tokenizer_profile.fixed_padding_length,
89
+ )?;
81
90
 
82
91
  Ok(Self {
83
92
  tokenizer,
@@ -86,14 +95,27 @@ impl Reranker {
86
95
  })
87
96
  }
88
97
 
89
- pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
98
+ pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
90
99
  let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
91
- let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
100
+ self.score_tokenized(&tokenized, apply_sigmoid)
101
+ }
102
+
103
+ pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
104
+ let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
105
+ self.score_tokenized(&tokenized, apply_sigmoid)
106
+ }
107
+
108
+ fn score_tokenized(
109
+ &self,
110
+ tokenized: &crate::tokenizer::Tokenized,
111
+ apply_sigmoid: bool,
112
+ ) -> Result<Vec<f32>> {
113
+ let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
92
114
  let outputs = self.session.run(input_tensors.inputs)?;
93
115
  let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
94
116
 
95
117
  let mut scores = match array.ndim() {
96
- 1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
118
+ 1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
97
119
  2 => {
98
120
  let shape = array.shape();
99
121
  if shape[1] == 0 {
@@ -102,7 +124,7 @@ impl Reranker {
102
124
  self.config.output_tensor, shape
103
125
  )));
104
126
  }
105
- array.slice(ndarray::s![.., 0]).into_owned()
127
+ array.slice(ndarray::s![.., 0]).to_vec()
106
128
  }
107
129
  n => {
108
130
  return Err(GteError::Inference(format!(
@@ -113,10 +135,9 @@ impl Reranker {
113
135
  };
114
136
 
115
137
  if apply_sigmoid {
116
- sigmoid_scores(scores.view_mut());
138
+ sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
117
139
  }
118
140
 
119
141
  Ok(scores)
120
142
  }
121
-
122
143
  }
@@ -2,6 +2,7 @@
2
2
 
3
3
  use crate::embedder::{normalize_l2, Embedder};
4
4
  use crate::error::GteError;
5
+ use crate::model_config::ModelLoadOverrides;
5
6
  use crate::reranker::Reranker;
6
7
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
7
8
  use std::os::raw::c_void;
@@ -38,7 +39,8 @@ unsafe impl Send for InferArgs {}
38
39
 
39
40
  struct ScoreArgs {
40
41
  reranker: *const Reranker,
41
- pairs: *const Vec<(String, String)>,
42
+ query: *const String,
43
+ candidates: *const Vec<String>,
42
44
  apply_sigmoid: bool,
43
45
  result: Option<Result<Vec<f32>, GteError>>,
44
46
  }
@@ -85,13 +87,15 @@ fn infer_without_gvl(
85
87
 
86
88
  fn score_without_gvl(
87
89
  reranker: &Arc<Reranker>,
88
- pairs: Vec<(String, String)>,
90
+ query: String,
91
+ candidates: Vec<String>,
89
92
  apply_sigmoid: bool,
90
93
  ) -> Result<Vec<f32>, Error> {
91
94
  let scores = unsafe {
92
95
  let mut args = ScoreArgs {
93
96
  reranker: Arc::as_ptr(reranker),
94
- pairs: &pairs as *const Vec<(String, String)>,
97
+ query: &query as *const String,
98
+ candidates: &candidates as *const Vec<String>,
95
99
  apply_sigmoid,
96
100
  result: None,
97
101
  };
@@ -135,8 +139,7 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
135
139
  unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
136
140
  let args = &mut *(ptr as *mut ScoreArgs);
137
141
  let run_result = catch_unwind(AssertUnwindSafe(|| {
138
- let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
139
- Ok(scores.to_vec())
142
+ (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
140
143
  }));
141
144
  args.result = Some(match run_result {
142
145
  Ok(result) => result,
@@ -171,6 +174,7 @@ impl RbEmbedder {
171
174
  normalize: bool,
172
175
  output_tensor: String,
173
176
  max_length: usize,
177
+ padding: String,
174
178
  execution_providers: String,
175
179
  ) -> Result<Self, Error> {
176
180
  let name = if model_name.is_empty() {
@@ -193,14 +197,23 @@ impl RbEmbedder {
193
197
  } else {
194
198
  Some(execution_providers.as_str())
195
199
  };
200
+ let padding_override = if padding.is_empty() {
201
+ None
202
+ } else {
203
+ Some(padding.as_str())
204
+ };
205
+ let overrides = ModelLoadOverrides {
206
+ model_name: name,
207
+ output_tensor: output_override,
208
+ max_length: max_length_override,
209
+ padding: padding_override,
210
+ execution_providers: execution_providers_override,
211
+ };
196
212
  let embedder = Embedder::from_dir(
197
213
  &dir_path,
198
214
  num_threads,
199
215
  optimization_level,
200
- name,
201
- output_override,
202
- max_length_override,
203
- execution_providers_override,
216
+ overrides,
204
217
  )
205
218
  .map_err(magnus::Error::from)?;
206
219
  Ok(RbEmbedder {
@@ -231,6 +244,7 @@ impl RbReranker {
231
244
  sigmoid: bool,
232
245
  output_tensor: String,
233
246
  max_length: usize,
247
+ padding: String,
234
248
  execution_providers: String,
235
249
  ) -> Result<Self, Error> {
236
250
  let name = if model_name.is_empty() {
@@ -253,14 +267,23 @@ impl RbReranker {
253
267
  } else {
254
268
  Some(execution_providers.as_str())
255
269
  };
270
+ let padding_override = if padding.is_empty() {
271
+ None
272
+ } else {
273
+ Some(padding.as_str())
274
+ };
275
+ let overrides = ModelLoadOverrides {
276
+ model_name: name,
277
+ output_tensor: output_override,
278
+ max_length: max_length_override,
279
+ padding: padding_override,
280
+ execution_providers: execution_providers_override,
281
+ };
256
282
  let reranker = Reranker::from_dir(
257
283
  &dir_path,
258
284
  num_threads,
259
285
  optimization_level,
260
- name,
261
- output_override,
262
- max_length_override,
263
- execution_providers_override,
286
+ overrides,
264
287
  )
265
288
  .map_err(magnus::Error::from)?;
266
289
  Ok(RbReranker {
@@ -276,11 +299,7 @@ impl RbReranker {
276
299
  candidates: RArray,
277
300
  ) -> Result<RArray, Error> {
278
301
  let candidates: Vec<String> = candidates.to_vec()?;
279
- let pairs: Vec<(String, String)> = candidates
280
- .into_iter()
281
- .map(|candidate| (query.clone(), candidate))
282
- .collect();
283
- let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
302
+ let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
284
303
 
285
304
  let out = ruby.ary_new_capa(scores.len());
286
305
  for score in scores {
@@ -376,12 +395,12 @@ impl RbTensor {
376
395
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
377
396
  let module = ruby.define_module("GTE")?;
378
397
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
379
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
398
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
380
399
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
381
400
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
382
401
 
383
402
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
384
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
403
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
385
404
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
386
405
 
387
406
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
@@ -1,4 +1,5 @@
1
1
  use crate::error::{GteError, Result};
2
+ use crate::model_config::PaddingMode;
2
3
  use std::path::Path;
3
4
  use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
4
5
 
@@ -20,6 +21,8 @@ impl Tokenizer {
20
21
  tokenizer_path: P,
21
22
  max_length: usize,
22
23
  with_type_ids: bool,
24
+ padding_mode: PaddingMode,
25
+ fixed_padding_length: Option<usize>,
23
26
  ) -> Result<Self> {
24
27
  let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
25
28
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
@@ -33,7 +36,7 @@ impl Tokenizer {
33
36
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
34
37
 
35
38
  let padding = PaddingParams {
36
- strategy: PaddingStrategy::BatchLongest,
39
+ strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
37
40
  ..Default::default()
38
41
  };
39
42
  tokenizer.with_padding(Some(padding));
@@ -73,6 +76,56 @@ impl Tokenizer {
73
76
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
74
77
  build_tokenized(&encodings, self.with_type_ids)
75
78
  }
79
+
80
+ pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
81
+ let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
82
+ .iter()
83
+ .map(|candidate| (query, candidate.as_str()).into())
84
+ .collect();
85
+ let encodings = self
86
+ .tokenizer
87
+ .encode_batch_fast(encode_inputs, true)
88
+ .map_err(|e| GteError::Tokenizer(e.to_string()))?;
89
+ build_tokenized(&encodings, self.with_type_ids)
90
+ }
91
+ }
92
+
93
+ pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
94
+ let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
95
+ return Ok(None);
96
+ };
97
+
98
+ let normalized = raw.to_ascii_lowercase().replace('-', "_");
99
+ let parsed = match normalized.as_str() {
100
+ "auto" => PaddingMode::Auto,
101
+ "batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
102
+ "fixed" => PaddingMode::Fixed,
103
+ _ => {
104
+ return Err(GteError::Inference(format!(
105
+ "invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
106
+ raw
107
+ )))
108
+ }
109
+ };
110
+ Ok(Some(parsed))
111
+ }
112
+
113
+ fn resolve_padding_strategy(
114
+ padding_mode: PaddingMode,
115
+ max_length: usize,
116
+ fixed_padding_length: Option<usize>,
117
+ ) -> PaddingStrategy {
118
+ match padding_mode {
119
+ PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
120
+ PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
121
+ PaddingMode::Auto => {
122
+ if fixed_padding_length.is_some() {
123
+ PaddingStrategy::Fixed(max_length)
124
+ } else {
125
+ PaddingStrategy::BatchLongest
126
+ }
127
+ }
128
+ }
76
129
  }
77
130
 
78
131
  fn build_tokenized_single(
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
121
174
  let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
122
175
 
123
176
  for encoding in encodings {
124
- input_ids.extend(encoding.get_ids().iter().map(|&value| i64::from(value)));
125
- attn_masks.extend(
126
- encoding
127
- .get_attention_mask()
128
- .iter()
129
- .map(|&value| i64::from(value)),
130
- );
177
+ for &value in encoding.get_ids() {
178
+ input_ids.push(i64::from(value));
179
+ }
180
+ for &value in encoding.get_attention_mask() {
181
+ attn_masks.push(i64::from(value));
182
+ }
131
183
 
132
184
  if let Some(type_ids) = type_ids.as_mut() {
133
- type_ids.extend(
134
- encoding
135
- .get_type_ids()
136
- .iter()
137
- .map(|&value| i64::from(value)),
138
- );
185
+ for &value in encoding.get_type_ids() {
186
+ type_ids.push(i64::from(value));
187
+ }
139
188
  }
140
189
  }
141
190
 
@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
147
196
  type_ids,
148
197
  })
149
198
  }
199
+
200
+ #[cfg(test)]
201
+ mod tests {
202
+ use super::{parse_padding_mode_override, resolve_padding_strategy};
203
+ use crate::model_config::PaddingMode;
204
+ use tokenizers::PaddingStrategy;
205
+
206
+ #[test]
207
+ fn parse_padding_mode_override_accepts_expected_values() {
208
+ assert_eq!(
209
+ parse_padding_mode_override(Some("auto")).unwrap(),
210
+ Some(PaddingMode::Auto)
211
+ );
212
+ assert_eq!(
213
+ parse_padding_mode_override(Some("batch-longest")).unwrap(),
214
+ Some(PaddingMode::BatchLongest)
215
+ );
216
+ assert_eq!(
217
+ parse_padding_mode_override(Some("fixed")).unwrap(),
218
+ Some(PaddingMode::Fixed)
219
+ );
220
+ }
221
+
222
+ #[test]
223
+ fn parse_padding_mode_override_rejects_invalid_values() {
224
+ assert!(parse_padding_mode_override(Some("unknown")).is_err());
225
+ }
226
+
227
+ #[test]
228
+ fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
229
+ match resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)) {
230
+ PaddingStrategy::Fixed(64) => {}
231
+ other => panic!("expected Fixed(64), got {:?}", other),
232
+ }
233
+ }
234
+ }
@@ -1,11 +1,12 @@
1
1
  use gte::embedder::Embedder;
2
+ use gte::model_config::ModelLoadOverrides;
2
3
 
3
4
  #[test]
4
5
  #[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
5
6
  fn test_e5_single_embedding_shape() {
6
7
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
7
8
 
8
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
9
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
9
10
  .expect("embedder should initialize");
10
11
  let result = embedder
11
12
  .embed(vec!["query: Hello world".to_string()])
@@ -20,7 +21,7 @@ fn test_e5_single_embedding_shape() {
20
21
  fn test_clip_single_embedding_shape() {
21
22
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
22
23
 
23
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
24
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
24
25
  .expect("embedder should initialize");
25
26
  let result = embedder
26
27
  .embed(vec!["a photo of a cat".to_string()])
@@ -35,7 +36,7 @@ fn test_clip_single_embedding_shape() {
35
36
  fn test_e5_batch_embedding_shape() {
36
37
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
37
38
 
38
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
39
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
39
40
  .expect("embedder should initialize");
40
41
  let texts = vec![
41
42
  "query: first sentence".to_string(),
@@ -54,7 +55,7 @@ fn test_e5_batch_embedding_shape() {
54
55
  fn test_e5_long_input_truncation_no_error() {
55
56
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
56
57
 
57
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
58
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
58
59
  .expect("embedder should initialize");
59
60
  let very_long_text = "word ".repeat(1000);
60
61
  let result = embedder
@@ -1,3 +1,4 @@
1
+ use gte::model_config::PaddingMode;
1
2
  use gte::tokenizer::Tokenizer;
2
3
 
3
4
  #[test]
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
8
9
  "/tests/fixtures/e5/tokenizer.json"
9
10
  );
10
11
 
11
- let tokenizer = Tokenizer::new(TOKENIZER, 512, true).expect("tokenizer should load");
12
+ let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
13
+ .expect("tokenizer should load");
12
14
  let texts = vec![
13
15
  "Hello, world!".to_string(),
14
16
  "A second, longer sentence to test padding behavior.".to_string(),
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
33
35
  "/tests/fixtures/e5/tokenizer.json"
34
36
  );
35
37
 
36
- let tokenizer = Tokenizer::new(TOKENIZER, 16, false).expect("tokenizer should load");
38
+ let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
39
+ .expect("tokenizer should load");
37
40
  let long_text = "word ".repeat(200);
38
41
  let tokenized = tokenizer
39
42
  .tokenize(&[long_text])
data/lib/gte/config.rb CHANGED
@@ -4,12 +4,12 @@ module GTE
4
4
  module Config
5
5
  Text = Data.define(
6
6
  :model_dir, :threads, :optimization_level,
7
- :model_name, :normalize, :output_tensor, :max_length, :execution_providers
7
+ :model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
8
8
  )
9
9
 
10
10
  Reranker = Data.define(
11
11
  :model_dir, :threads, :optimization_level,
12
- :model_name, :sigmoid, :output_tensor, :max_length, :execution_providers
12
+ :model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
13
13
  )
14
14
  end
15
15
  end
data/lib/gte/embedder.rb CHANGED
@@ -18,6 +18,7 @@ module GTE
18
18
  config.normalize,
19
19
  config.output_tensor.to_s,
20
20
  config.max_length || 0,
21
+ config.padding.to_s,
21
22
  config.execution_providers.to_s
22
23
  )
23
24
  end
@@ -33,6 +34,7 @@ module GTE
33
34
  normalize: true,
34
35
  output_tensor: nil,
35
36
  max_length: nil,
37
+ padding: nil,
36
38
  execution_providers: nil
37
39
  )
38
40
  end
data/lib/gte/reranker.rb CHANGED
@@ -25,6 +25,7 @@ module GTE
25
25
  sigmoid: false,
26
26
  output_tensor: nil,
27
27
  max_length: nil,
28
+ padding: nil,
28
29
  execution_providers: nil
29
30
  )
30
31
  end
@@ -38,6 +39,7 @@ module GTE
38
39
  cfg.sigmoid,
39
40
  cfg.output_tensor.to_s,
40
41
  cfg.max_length || 0,
42
+ cfg.padding.to_s,
41
43
  cfg.execution_providers.to_s
42
44
  )
43
45
  end
data/lib/gte.rb CHANGED
@@ -27,6 +27,7 @@ module GTE
27
27
  normalize: true,
28
28
  output_tensor: nil,
29
29
  max_length: nil,
30
+ padding: nil,
30
31
  execution_providers: nil
31
32
  )
32
33
 
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.6
4
+ version: 0.0.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo