gte 0.0.5 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: ae83f737b57f798d39cf1fdc895d67948de27d36b46ea02c211a440d3acaa8c9
- data.tar.gz: 9eaf9651b2ccf1fdb93efe4666ed70537628453a8cf92e234b454560560a83e8
+ metadata.gz: 29659e3ab6072d858b1710a779c3d2e5981f7749782182d141ccd5e9790a1fbb
+ data.tar.gz: c42d51cfa1a2ba6a2e83249e8a725c978b11c7ef80c6d69f09a64e884be42031
  SHA512:
- metadata.gz: a262194a53bf804e47b0ef9c5910c1e2b814a9824823a92a73867a631c7b26310b3163e61997d9c163dab402a40d49946b76a64cc0421741ae235f623180cb95
- data.tar.gz: 6acf5b58140012df9fa25971ed0f1fdfa707cc3efbe5f7f22104e35ad57877778a08cf9f8b311017be8f40e255289e3249e35c1e3780ae231f9f66e08cbb6ac3
+ metadata.gz: ff2c2b1450a6e82c07aacd2ec98437f03678d56eef9c5516f904021a54f59b2ba5c42b8f6af22b5c4b2dacea98615b99bc54d2c7cdc4e8fbccc1abc195fe9975
+ data.tar.gz: 04ca056458d40e2ba7fabcdbcab415a087d54802fb3bd86748dc901c2cf0ecb44072fd1820a73e3dcaca097f165df3e70bab747b38340cd738876af5f0ea7645
data/README.md CHANGED
@@ -41,6 +41,7 @@ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
  config.with(
  output_tensor: "last_hidden_state",
  max_length: 256,
+ padding: "batch_longest",
  optimization_level: 3
  )
  end
@@ -55,12 +56,22 @@ Config fields and defaults:
  - `normalize`: `true` (L2 normalization at Ruby-facing API)
  - `output_tensor`: `nil` (auto-select output tensor)
  - `max_length`: `nil` (uses tokenizer/model defaults)
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
+ - `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)

  Notes:

  - Return a `Config::Text` from the block (for example, `config.with(...)`).
  - Model instances are cached by full config key; different config values create different cached instances.

+ Low-level embedder setup (without model cache):
+
+ ```ruby
+ embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
+ config.with(threads: 0, execution_providers: "cpu")
+ end
+ ```
+
  ## Reranker

  Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
@@ -97,6 +108,8 @@ Reranker config fields and defaults:
  - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
  - `output_tensor`: `nil`
  - `max_length`: `nil`
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
+ - `execution_providers`: `nil`

  ## Runtime + Result Examples

@@ -123,14 +136,30 @@ Input policy is text-only. Graphs requiring unsupported multimodal inputs (such

  ## Execution Providers

- Default execution provider is `xnnpack` on all platforms (including macOS arm64).
+ Default behavior is CPU fallback via ONNX Runtime's default provider (no explicit provider registration).
+
+ Configure providers with `GTE_EXECUTION_PROVIDERS` (comma-separated, case-insensitive).
+ Supported values:
+
+ - `cpu` or `none`: CPU fallback (skip explicit provider registration)
+ - `xnnpack`
+ - `coreml`

- To opt in to CoreML explicitly:
+ Examples:

  ```bash
+ export GTE_EXECUTION_PROVIDERS=cpu
  export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
  ```

+ Ruby per-instance override (takes precedence over `GTE_EXECUTION_PROVIDERS`):
+
+ ```ruby
+ model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
+ config.with(execution_providers: "cpu")
+ end
+ ```
+
  ## Development

  Run commands inside `nix develop` via Make targets:
@@ -154,8 +183,13 @@ nix develop -c bundle exec rake bench:matrix_sweep
  nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
  ```

- For release tracking and regression detection, record a run entry in `RUNS.md`:
+ To run benchmark + append a `RUNS.md` entry + enforce goal checks:

  ```bash
  make bench-record
  ```
+
+ `bench/runs_ledger.rb check` is goal-focused by default:
+
+ - Enforces goal metric (`response_time_p95` ratio threshold).
+ - Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
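Taken together, the README changes above add the same two knobs to the embedder and reranker configs: `padding` (`auto`, `batch_longest`, or `fixed`) and `execution_providers` (same values as `GTE_EXECUTION_PROVIDERS`). A minimal sketch combining both, assuming a model directory in `GTE_MODEL_DIR` as in the README examples; the input strings are placeholders, and the `score(query, candidates)` call mirrors the reranker binding shown further down in this diff:

```ruby
# Cached embedding model with explicit padding and provider selection.
model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
  config.with(padding: "batch_longest", execution_providers: "cpu")
end

# Cross-encoder reranker with the same new options plus bounded scores.
reranker = GTE::Reranker.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
  config.with(padding: "fixed", execution_providers: "xnnpack", sigmoid: true)
end

# Placeholder inputs; returns one score per candidate.
scores = reranker.score("query: best ruby gem", ["gte is fast", "unrelated text"])
```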
data/Rakefile CHANGED
@@ -56,7 +56,7 @@ namespace :bench do
  )
  end

- desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal/regression checks'
+ desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal checks'
  task :record_run do
  run_in_nix(
  'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
@@ -74,7 +74,7 @@ namespace :bench do
  )
  end

- desc 'Validate current Puma benchmark output against 2x goal and regression policy'
+ desc 'Validate current Puma benchmark output against 2x goal only'
  task :check_goal do
  run_in_nix(
  'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
data/VERSION CHANGED
@@ -1 +1 @@
- 0.0.5
+ 0.0.7
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
  [package]
  name = "gte"
- version = "0.0.5"
+ version = "0.0.7"
  edition = "2021"
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
  license = "MIT"
@@ -1,12 +1,12 @@
  use crate::error::{GteError, Result};
- use crate::model_config::{ExtractorMode, ModelConfig};
+ use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
  use crate::model_profile::{
- has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
- resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
+ has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
+ resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
  };
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
  use crate::session::{build_session, run_session};
- use crate::tokenizer::{Tokenized, Tokenizer};
+ use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
  use ndarray::Array2;
  use ort::session::Session;
  use std::path::Path;
@@ -23,7 +23,13 @@ impl Embedder {
  P1: AsRef<Path>,
  P2: AsRef<Path>,
  {
- let tokenizer = Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids)?;
+ let tokenizer = Tokenizer::new(
+ tokenizer_path,
+ config.max_length,
+ config.with_type_ids,
+ config.padding_mode,
+ None,
+ )?;
  let session = build_session(model_path, &config)?;
  Ok(Self {
  tokenizer,
@@ -36,9 +42,7 @@ impl Embedder {
  dir: P,
  num_threads: usize,
  optimization_level: u8,
- model_name: Option<&str>,
- output_tensor_override: Option<&str>,
- max_length_override: Option<usize>,
+ overrides: ModelLoadOverrides<'_>,
  ) -> Result<Self> {
  const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
  "pooler_output",
@@ -49,30 +53,35 @@ impl Embedder {

  let dir = dir.as_ref();
  let tokenizer_path = resolve_tokenizer_path(dir)?;
- let model_path = match model_name.filter(|s| !s.is_empty()) {
+ let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
  Some(name) => resolve_named_model(dir, name)?,
  None => resolve_default_text_model(dir)?,
  };

- let max_length = if let Some(override_value) = max_length_override {
+ let tokenizer_profile = read_tokenizer_profile(dir);
+ let max_length = if let Some(override_value) = overrides.max_length {
  if override_value == 0 {
  return Err(GteError::Inference(
  "max_length override must be greater than 0".to_string(),
  ));
  }
- override_value
+ override_value.min(tokenizer_profile.safe_max_length)
  } else {
- read_max_length(dir)
+ tokenizer_profile.default_max_length
  };
+ let padding_mode =
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);

  let session_config = ModelConfig {
  max_length,
+ padding_mode,
  output_tensor: String::new(),
  mode: ExtractorMode::Raw,
  with_type_ids: false,
  with_attention_mask: true,
  num_threads,
  optimization_level,
+ execution_providers: overrides.execution_providers.map(str::to_string),
  };
  let session = build_session(&model_path, &session_config)?;

@@ -80,7 +89,7 @@ impl Embedder {
  let with_type_ids = has_input(&session, "token_type_ids");
  let with_attention_mask = has_input(&session, "attention_mask");
  let output_tensor =
- select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
+ select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
  return Err(GteError::Inference(
@@ -90,15 +99,23 @@ impl Embedder {

  let config = ModelConfig {
  max_length,
+ padding_mode,
  output_tensor,
  mode,
  with_type_ids,
  with_attention_mask,
  num_threads,
  optimization_level,
+ execution_providers: overrides.execution_providers.map(str::to_string),
  };

- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
+ let tokenizer = Tokenizer::new(
+ &tokenizer_path,
+ config.max_length,
+ config.with_type_ids,
+ config.padding_mode,
+ tokenizer_profile.fixed_padding_length,
+ )?;

  Ok(Self {
  tokenizer,
@@ -5,13 +5,32 @@ pub enum ExtractorMode {
  Raw,
  }

+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+ pub enum PaddingMode {
+ #[default]
+ Auto,
+ BatchLongest,
+ Fixed,
+ }
+
  #[derive(Debug, Clone)]
  pub struct ModelConfig {
  pub max_length: usize,
+ pub padding_mode: PaddingMode,
  pub output_tensor: String,
  pub mode: ExtractorMode,
  pub with_type_ids: bool,
  pub with_attention_mask: bool,
  pub num_threads: usize,
  pub optimization_level: u8,
+ pub execution_providers: Option<String>,
+ }
+
+ #[derive(Debug, Clone, Copy, Default)]
+ pub struct ModelLoadOverrides<'a> {
+ pub model_name: Option<&'a str>,
+ pub output_tensor: Option<&'a str>,
+ pub max_length: Option<usize>,
+ pub padding: Option<&'a str>,
+ pub execution_providers: Option<&'a str>,
  }
@@ -1,9 +1,19 @@
  use crate::error::{GteError, Result};
  use crate::model_config::ExtractorMode;
  use ort::session::Session;
+ use serde_json::Value;
  use std::path::{Path, PathBuf};

  const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
+ const DEFAULT_MAX_LENGTH: usize = 512;
+ const MAX_SUPPORTED_LENGTH: usize = 8192;
+
+ #[derive(Debug, Clone, Copy)]
+ pub struct TokenizerProfile {
+ pub default_max_length: usize,
+ pub safe_max_length: usize,
+ pub fixed_padding_length: Option<usize>,
+ }

  pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
  let tokenizer_path = dir.join("tokenizer.json");
@@ -48,19 +58,78 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
  )))
  }

- pub fn read_max_length(dir: &Path) -> usize {
- (|| -> Option<usize> {
- let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
- let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
- let v = json.get("model_max_length")?;
- let n = v.as_u64().or_else(|| {
- v.as_f64()
- .filter(|&f| f > 0.0 && f < 1e15)
- .map(|f| f as u64)
- })?;
- Some((n as usize).min(8192))
- })()
- .unwrap_or(512)
+ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
+ let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
+ let tokenizer_json = read_json(dir.join("tokenizer.json"));
+
+ let fixed_padding_length = tokenizer_json
+ .as_ref()
+ .and_then(parse_fixed_padding_length_from_tokenizer_json);
+
+ let mut candidates = Vec::new();
+ if let Some(config) = tokenizer_config.as_ref() {
+ if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+ }
+ if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+ }
+ }
+
+ if let Some(tokenizer) = tokenizer_json.as_ref() {
+ if let Some(v) = tokenizer
+ .get("truncation")
+ .and_then(|truncation| truncation.get("max_length"))
+ .and_then(parse_positive_usize)
+ {
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+ }
+ }
+
+ if let Some(v) = fixed_padding_length {
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+ }
+
+ let default_max_length = candidates
+ .iter()
+ .copied()
+ .min()
+ .unwrap_or(DEFAULT_MAX_LENGTH)
+ .max(1);
+ let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
+
+ TokenizerProfile {
+ default_max_length,
+ safe_max_length,
+ fixed_padding_length,
+ }
+ }
+
+ fn read_json(path: PathBuf) -> Option<Value> {
+ let contents = std::fs::read_to_string(path).ok()?;
+ serde_json::from_str(&contents).ok()
+ }
+
+ fn parse_positive_usize(value: &Value) -> Option<usize> {
+ let raw = value
+ .as_u64()
+ .or_else(|| {
+ value
+ .as_f64()
+ .filter(|&v| v.is_finite() && v > 0.0)
+ .map(|v| v as u64)
+ })
+ .or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
+ let parsed = usize::try_from(raw).ok()?;
+ (parsed > 0).then_some(parsed)
+ }
+
+ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
+ tokenizer_json
+ .get("padding")
+ .and_then(|padding| padding.get("strategy"))
+ .and_then(|strategy| strategy.get("Fixed"))
+ .and_then(parse_positive_usize)
  }

  pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
@@ -177,3 +246,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
  ))),
  }
  }
+
+ #[cfg(test)]
+ mod tests {
+ use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
+ use serde_json::json;
+
+ #[test]
+ fn parse_positive_usize_handles_integer_float_and_string() {
+ assert_eq!(parse_positive_usize(&json!(64)), Some(64));
+ assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
+ assert_eq!(parse_positive_usize(&json!("64")), Some(64));
+ assert_eq!(parse_positive_usize(&json!(0)), None);
+ }
+
+ #[test]
+ fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
+ let tokenizer_json = json!({
+ "padding": {
+ "strategy": {
+ "Fixed": 64
+ }
+ }
+ });
+ assert_eq!(
+ parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
+ Some(64)
+ );
+ }
+ }
@@ -1,19 +1,20 @@
  use crate::error::{GteError, Result};
+ use crate::model_config::{ModelLoadOverrides, PaddingMode};
  use crate::model_profile::{
- has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
- select_output_tensor, validate_supported_text_inputs,
+ has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
  };
  use crate::pipeline::{extract_output_tensor, InputTensors};
  use crate::postprocess::sigmoid_scores;
  use crate::session::build_session;
- use crate::tokenizer::Tokenizer;
- use ndarray::Array1;
+ use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
  use ort::session::Session;
  use std::path::Path;

  #[derive(Debug, Clone)]
  struct RerankerConfig {
  max_length: usize,
+ padding_mode: PaddingMode,
  output_tensor: String,
  with_type_ids: bool,
  with_attention_mask: bool,
@@ -30,52 +31,62 @@ impl Reranker {
  dir: P,
  num_threads: usize,
  optimization_level: u8,
- model_name: Option<&str>,
- output_tensor_override: Option<&str>,
- max_length_override: Option<usize>,
+ overrides: ModelLoadOverrides<'_>,
  ) -> Result<Self> {
  let dir = dir.as_ref();
  let tokenizer_path = resolve_tokenizer_path(dir)?;
- let model_path = match model_name.filter(|s| !s.is_empty()) {
+ let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
  Some(name) => resolve_named_model(dir, name)?,
  None => resolve_default_text_model(dir)?,
  };

- let max_length = if let Some(override_value) = max_length_override {
+ let tokenizer_profile = read_tokenizer_profile(dir);
+ let max_length = if let Some(override_value) = overrides.max_length {
  if override_value == 0 {
  return Err(GteError::Inference(
  "max_length override must be greater than 0".to_string(),
  ));
  }
- override_value
+ override_value.min(tokenizer_profile.safe_max_length)
  } else {
- read_max_length(dir)
+ tokenizer_profile.default_max_length
  };
+ let padding_mode =
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);

  let probe_config = crate::model_config::ModelConfig {
  max_length,
+ padding_mode,
  output_tensor: String::new(),
  mode: crate::model_config::ExtractorMode::Raw,
  with_type_ids: false,
  with_attention_mask: true,
  num_threads,
  optimization_level,
+ execution_providers: overrides.execution_providers.map(str::to_string),
  };
  let session = build_session(&model_path, &probe_config)?;

  validate_supported_text_inputs(&session, "text reranking")?;
  let with_type_ids = has_input(&session, "token_type_ids");
  let with_attention_mask = has_input(&session, "attention_mask");
- let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
+ let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;

  let config = RerankerConfig {
  max_length,
+ padding_mode,
  output_tensor,
  with_type_ids,
  with_attention_mask,
  };

- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
+ let tokenizer = Tokenizer::new(
+ &tokenizer_path,
+ config.max_length,
+ config.with_type_ids,
+ config.padding_mode,
+ tokenizer_profile.fixed_padding_length,
+ )?;

  Ok(Self {
  tokenizer,
@@ -84,14 +95,27 @@ impl Reranker {
  })
  }

- pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
+ pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
  let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
- let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
+ self.score_tokenized(&tokenized, apply_sigmoid)
+ }
+
+ pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
+ let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
+ self.score_tokenized(&tokenized, apply_sigmoid)
+ }
+
+ fn score_tokenized(
+ &self,
+ tokenized: &crate::tokenizer::Tokenized,
+ apply_sigmoid: bool,
+ ) -> Result<Vec<f32>> {
+ let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
  let outputs = self.session.run(input_tensors.inputs)?;
  let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;

  let mut scores = match array.ndim() {
- 1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
+ 1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
  2 => {
  let shape = array.shape();
  if shape[1] == 0 {
@@ -100,7 +124,7 @@ impl Reranker {
  self.config.output_tensor, shape
  )));
  }
- array.slice(ndarray::s![.., 0]).into_owned()
+ array.slice(ndarray::s![.., 0]).to_vec()
  }
  n => {
  return Err(GteError::Inference(format!(
@@ -111,10 +135,9 @@ impl Reranker {
  };

  if apply_sigmoid {
- sigmoid_scores(scores.view_mut());
+ sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
  }

  Ok(scores)
  }
-
  }
@@ -2,6 +2,7 @@

  use crate::embedder::{normalize_l2, Embedder};
  use crate::error::GteError;
+ use crate::model_config::ModelLoadOverrides;
  use crate::reranker::Reranker;
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
  use std::os::raw::c_void;
@@ -38,7 +39,8 @@ unsafe impl Send for InferArgs {}

  struct ScoreArgs {
  reranker: *const Reranker,
- pairs: *const Vec<(String, String)>,
+ query: *const String,
+ candidates: *const Vec<String>,
  apply_sigmoid: bool,
  result: Option<Result<Vec<f32>, GteError>>,
  }
@@ -85,13 +87,15 @@ fn infer_without_gvl(

  fn score_without_gvl(
  reranker: &Arc<Reranker>,
- pairs: Vec<(String, String)>,
+ query: String,
+ candidates: Vec<String>,
  apply_sigmoid: bool,
  ) -> Result<Vec<f32>, Error> {
  let scores = unsafe {
  let mut args = ScoreArgs {
  reranker: Arc::as_ptr(reranker),
- pairs: &pairs as *const Vec<(String, String)>,
+ query: &query as *const String,
+ candidates: &candidates as *const Vec<String>,
  apply_sigmoid,
  result: None,
  };
@@ -135,8 +139,7 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
  unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
  let args = &mut *(ptr as *mut ScoreArgs);
  let run_result = catch_unwind(AssertUnwindSafe(|| {
- let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
- Ok(scores.to_vec())
+ (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
  }));
  args.result = Some(match run_result {
  Ok(result) => result,
@@ -171,6 +174,8 @@ impl RbEmbedder {
  normalize: bool,
  output_tensor: String,
  max_length: usize,
+ padding: String,
+ execution_providers: String,
  ) -> Result<Self, Error> {
  let name = if model_name.is_empty() {
  None
@@ -187,13 +192,28 @@ impl RbEmbedder {
  } else {
  Some(max_length)
  };
+ let execution_providers_override = if execution_providers.is_empty() {
+ None
+ } else {
+ Some(execution_providers.as_str())
+ };
+ let padding_override = if padding.is_empty() {
+ None
+ } else {
+ Some(padding.as_str())
+ };
+ let overrides = ModelLoadOverrides {
+ model_name: name,
+ output_tensor: output_override,
+ max_length: max_length_override,
+ padding: padding_override,
+ execution_providers: execution_providers_override,
+ };
  let embedder = Embedder::from_dir(
  &dir_path,
  num_threads,
  optimization_level,
- name,
- output_override,
- max_length_override,
+ overrides,
  )
  .map_err(magnus::Error::from)?;
  Ok(RbEmbedder {
@@ -224,6 +244,8 @@ impl RbReranker {
  sigmoid: bool,
  output_tensor: String,
  max_length: usize,
+ padding: String,
+ execution_providers: String,
  ) -> Result<Self, Error> {
  let name = if model_name.is_empty() {
  None
@@ -240,13 +262,28 @@ impl RbReranker {
  } else {
  Some(max_length)
  };
+ let execution_providers_override = if execution_providers.is_empty() {
+ None
+ } else {
+ Some(execution_providers.as_str())
+ };
+ let padding_override = if padding.is_empty() {
+ None
+ } else {
+ Some(padding.as_str())
+ };
+ let overrides = ModelLoadOverrides {
+ model_name: name,
+ output_tensor: output_override,
+ max_length: max_length_override,
+ padding: padding_override,
+ execution_providers: execution_providers_override,
+ };
  let reranker = Reranker::from_dir(
  &dir_path,
  num_threads,
  optimization_level,
- name,
- output_override,
- max_length_override,
+ overrides,
  )
  .map_err(magnus::Error::from)?;
  Ok(RbReranker {
@@ -262,11 +299,7 @@ impl RbReranker {
  candidates: RArray,
  ) -> Result<RArray, Error> {
  let candidates: Vec<String> = candidates.to_vec()?;
- let pairs: Vec<(String, String)> = candidates
- .into_iter()
- .map(|candidate| (query.clone(), candidate))
- .collect();
- let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
+ let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;

  let out = ruby.ary_new_capa(scores.len());
  for score in scores {
@@ -362,12 +395,12 @@ impl RbTensor {
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
  let module = ruby.define_module("GTE")?;
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;

  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 7))?;
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;

  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
@@ -22,7 +22,7 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
  .with_optimization_level(opt_level)?
  .with_memory_pattern(true)?;

- let providers = preferred_execution_providers();
+ let providers = preferred_execution_providers(config.execution_providers.as_deref());
  if !providers.is_empty() {
  builder = builder.with_execution_providers(providers)?;
  }
@@ -34,19 +34,40 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
  Ok(builder.commit_from_file(model_path)?)
  }

- fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
- let order = std::env::var("GTE_EXECUTION_PROVIDERS")
- .unwrap_or_else(|_| "xnnpack".to_string())
- .to_ascii_lowercase();
+ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
+ let order = resolve_provider_order(order_override);

  let mut providers = Vec::new();
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
+ for provider in parse_provider_registrations(order.as_str()) {
  match provider {
  "xnnpack" => {
  providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
  }
  "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
- "none" => {}
+ _ => {}
+ }
+ }
+ providers
+ }
+
+ fn resolve_provider_order(order_override: Option<&str>) -> String {
+ let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
+ resolve_provider_order_with_env(order_override, env_order.as_deref())
+ }
+
+ fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
+ order_override
+ .or(env_order)
+ .unwrap_or("cpu")
+ .to_ascii_lowercase()
+ }
+
+ fn parse_provider_registrations(order: &str) -> Vec<&str> {
+ let mut providers = Vec::new();
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
+ match provider {
+ "xnnpack" | "coreml" => providers.push(provider),
+ "none" | "cpu" => {}
  _ => {}
  }
  }
@@ -86,3 +107,42 @@ pub fn run_session(
  ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
  }
  }
+
+ #[cfg(test)]
+ mod tests {
+ use super::{parse_provider_registrations, resolve_provider_order_with_env};
+
+ #[test]
+ fn parse_provider_registrations_keeps_supported_order() {
+ let parsed = parse_provider_registrations("xnnpack,coreml");
+ assert_eq!(parsed, vec!["xnnpack", "coreml"]);
+ }
+
+ #[test]
+ fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
+ assert!(parse_provider_registrations("cpu").is_empty());
+ assert!(parse_provider_registrations("none").is_empty());
+ assert!(parse_provider_registrations("none,cpu").is_empty());
+ }
+
+ #[test]
+ fn parse_provider_registrations_ignores_unknowns_and_empties() {
+ let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
+ assert_eq!(parsed, vec!["xnnpack", "coreml"]);
+ }
+
+ #[test]
+ fn resolve_provider_order_prefers_override() {
+ assert_eq!(
+ resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
+ "xnnpack"
+ );
+ assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
+ }
+
+ #[test]
+ fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
+ assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
+ assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
+ }
+ }
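The provider-resolution helpers and tests above encode a simple precedence: a per-instance `execution_providers` value wins over `GTE_EXECUTION_PROVIDERS`, which wins over the CPU default, and `cpu`/`none` or unknown entries register no explicit provider. A minimal Ruby-side sketch of that precedence, with the model directory and values as placeholders:

```ruby
# Process-wide default order for every session built afterwards.
ENV["GTE_EXECUTION_PROVIDERS"] = "xnnpack,coreml"

# This instance ignores the env var: the per-instance override takes precedence,
# and "cpu" means no explicit provider is registered (ONNX Runtime CPU fallback).
model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
  config.with(execution_providers: "cpu")
end
```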
@@ -1,4 +1,5 @@
  use crate::error::{GteError, Result};
+ use crate::model_config::PaddingMode;
  use std::path::Path;
  use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};

@@ -20,6 +21,8 @@ impl Tokenizer {
  tokenizer_path: P,
  max_length: usize,
  with_type_ids: bool,
+ padding_mode: PaddingMode,
+ fixed_padding_length: Option<usize>,
  ) -> Result<Self> {
  let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
@@ -33,7 +36,7 @@ impl Tokenizer {
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;

  let padding = PaddingParams {
- strategy: PaddingStrategy::BatchLongest,
+ strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
  ..Default::default()
  };
  tokenizer.with_padding(Some(padding));
@@ -73,6 +76,56 @@ impl Tokenizer {
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
  build_tokenized(&encodings, self.with_type_ids)
  }
+
+ pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
+ let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
+ .iter()
+ .map(|candidate| (query, candidate.as_str()).into())
+ .collect();
+ let encodings = self
+ .tokenizer
+ .encode_batch_fast(encode_inputs, true)
+ .map_err(|e| GteError::Tokenizer(e.to_string()))?;
+ build_tokenized(&encodings, self.with_type_ids)
+ }
+ }
+
+ pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
+ let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
+ return Ok(None);
+ };
+
+ let normalized = raw.to_ascii_lowercase().replace('-', "_");
+ let parsed = match normalized.as_str() {
+ "auto" => PaddingMode::Auto,
+ "batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
+ "fixed" => PaddingMode::Fixed,
+ _ => {
+ return Err(GteError::Inference(format!(
+ "invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
+ raw
+ )))
+ }
+ };
+ Ok(Some(parsed))
+ }
+
+ fn resolve_padding_strategy(
+ padding_mode: PaddingMode,
+ max_length: usize,
+ fixed_padding_length: Option<usize>,
+ ) -> PaddingStrategy {
+ match padding_mode {
+ PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
+ PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
+ PaddingMode::Auto => {
+ if fixed_padding_length.is_some() {
+ PaddingStrategy::Fixed(max_length)
+ } else {
+ PaddingStrategy::BatchLongest
+ }
+ }
+ }
  }

  fn build_tokenized_single(
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
  let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));

  for encoding in encodings {
- input_ids.extend(encoding.get_ids().iter().map(|&value| i64::from(value)));
- attn_masks.extend(
- encoding
- .get_attention_mask()
- .iter()
- .map(|&value| i64::from(value)),
- );
+ for &value in encoding.get_ids() {
+ input_ids.push(i64::from(value));
+ }
+ for &value in encoding.get_attention_mask() {
+ attn_masks.push(i64::from(value));
+ }

  if let Some(type_ids) = type_ids.as_mut() {
- type_ids.extend(
- encoding
- .get_type_ids()
- .iter()
- .map(|&value| i64::from(value)),
- );
+ for &value in encoding.get_type_ids() {
+ type_ids.push(i64::from(value));
+ }
  }
  }

@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
  type_ids,
  })
  }
+
+ #[cfg(test)]
+ mod tests {
+ use super::{parse_padding_mode_override, resolve_padding_strategy};
+ use crate::model_config::PaddingMode;
+ use tokenizers::PaddingStrategy;
+
+ #[test]
+ fn parse_padding_mode_override_accepts_expected_values() {
+ assert_eq!(
+ parse_padding_mode_override(Some("auto")).unwrap(),
+ Some(PaddingMode::Auto)
+ );
+ assert_eq!(
+ parse_padding_mode_override(Some("batch-longest")).unwrap(),
+ Some(PaddingMode::BatchLongest)
+ );
+ assert_eq!(
+ parse_padding_mode_override(Some("fixed")).unwrap(),
+ Some(PaddingMode::Fixed)
+ );
+ }
+
+ #[test]
+ fn parse_padding_mode_override_rejects_invalid_values() {
+ assert!(parse_padding_mode_override(Some("unknown")).is_err());
+ }
+
+ #[test]
+ fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
+ match resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)) {
+ PaddingStrategy::Fixed(64) => {}
+ other => panic!("expected Fixed(64), got {:?}", other),
+ }
+ }
+ }
@@ -1,12 +1,13 @@
  use gte::embedder::Embedder;
+ use gte::model_config::ModelLoadOverrides;

  #[test]
  #[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
  fn test_e5_single_embedding_shape() {
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");

- let embedder =
- Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
+ .expect("embedder should initialize");
  let result = embedder
  .embed(vec!["query: Hello world".to_string()])
  .expect("embed should succeed");
@@ -20,8 +21,8 @@ fn test_e5_single_embedding_shape() {
  fn test_clip_single_embedding_shape() {
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");

- let embedder =
- Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
+ .expect("embedder should initialize");
  let result = embedder
  .embed(vec!["a photo of a cat".to_string()])
  .expect("embed should succeed");
@@ -35,8 +36,8 @@ fn test_clip_single_embedding_shape() {
  fn test_e5_batch_embedding_shape() {
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");

- let embedder =
- Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
+ .expect("embedder should initialize");
  let texts = vec![
  "query: first sentence".to_string(),
  "query: second sentence".to_string(),
@@ -54,8 +55,8 @@ fn test_e5_batch_embedding_shape() {
  fn test_e5_long_input_truncation_no_error() {
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");

- let embedder =
- Embedder::from_dir(DIR, 0, 3, None, None, None).expect("embedder should initialize");
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
+ .expect("embedder should initialize");
  let very_long_text = "word ".repeat(1000);
  let result = embedder
  .embed(vec![very_long_text])
@@ -1,3 +1,4 @@
+ use gte::model_config::PaddingMode;
  use gte::tokenizer::Tokenizer;

  #[test]
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
  "/tests/fixtures/e5/tokenizer.json"
  );

- let tokenizer = Tokenizer::new(TOKENIZER, 512, true).expect("tokenizer should load");
+ let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
+ .expect("tokenizer should load");
  let texts = vec![
  "Hello, world!".to_string(),
  "A second, longer sentence to test padding behavior.".to_string(),
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
  "/tests/fixtures/e5/tokenizer.json"
  );

- let tokenizer = Tokenizer::new(TOKENIZER, 16, false).expect("tokenizer should load");
+ let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
+ .expect("tokenizer should load");
  let long_text = "word ".repeat(200);
  let tokenized = tokenizer
  .tokenize(&[long_text])
data/lib/gte/config.rb CHANGED
@@ -4,12 +4,12 @@ module GTE
  module Config
  Text = Data.define(
  :model_dir, :threads, :optimization_level,
- :model_name, :normalize, :output_tensor, :max_length
+ :model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
  )

  Reranker = Data.define(
  :model_dir, :threads, :optimization_level,
- :model_name, :sigmoid, :output_tensor, :max_length
+ :model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
  )
  end
  end
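`Config::Text` and `Config::Reranker` are plain `Data` classes, so the two new members can also be set by building a config by hand instead of using the block form. A minimal sketch with a placeholder model path; the keyword list mirrors `default_config` in the `GTE::Embedder` helper added below:

```ruby
cfg = GTE::Config::Text.new(
  model_dir: File.expand_path("./models/e5"), # placeholder path
  threads: 3,
  optimization_level: 3,
  model_name: nil,
  normalize: true,
  output_tensor: nil,
  max_length: nil,
  padding: "batch_longest",      # added in this diff
  execution_providers: "xnnpack" # added in this diff
)

embedder = GTE::Embedder.from_config(cfg)
```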
@@ -0,0 +1,43 @@
+ # frozen_string_literal: true
+
+ module GTE
+ class Embedder
+ class << self
+ def config(model_dir)
+ cfg = default_config(model_dir)
+ cfg = yield(cfg) if block_given?
+ from_config(cfg)
+ end
+
+ def from_config(config)
+ new(
+ config.model_dir,
+ config.threads,
+ config.optimization_level,
+ config.model_name.to_s,
+ config.normalize,
+ config.output_tensor.to_s,
+ config.max_length || 0,
+ config.padding.to_s,
+ config.execution_providers.to_s
+ )
+ end
+
+ private
+
+ def default_config(model_dir)
+ Config::Text.new(
+ model_dir: File.expand_path(model_dir),
+ threads: 3,
+ optimization_level: 3,
+ model_name: nil,
+ normalize: true,
+ output_tensor: nil,
+ max_length: nil,
+ padding: nil,
+ execution_providers: nil
+ )
+ end
+ end
+ end
+ end
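The new class-level helpers give the low-level embedder the same block-configured entry point the README documents. A short usage sketch, with a placeholder model directory and input strings; `embed` and `embed_one` are the instance methods registered by the native bindings earlier in this diff:

```ruby
embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
  config.with(threads: 0, execution_providers: "cpu", padding: "auto")
end

# Batch and single-text entry points from the native extension.
batch  = embedder.embed(["query: first sentence", "query: second sentence"])
single = embedder.embed_one("query: hello world")
```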
data/lib/gte/model.rb CHANGED
@@ -8,15 +8,7 @@ module GTE
  raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)

  @config = config
- @embedder = GTE::Embedder.new(
- config.model_dir,
- config.threads,
- config.optimization_level,
- config.model_name.to_s,
- config.normalize,
- config.output_tensor.to_s,
- config.max_length || 0
- )
+ @embedder = GTE::Embedder.from_config(config)
  end

  def embed(texts)
data/lib/gte/reranker.rb CHANGED
@@ -24,7 +24,9 @@ module GTE
  model_name: nil,
  sigmoid: false,
  output_tensor: nil,
- max_length: nil
+ max_length: nil,
+ padding: nil,
+ execution_providers: nil
  )
  end

@@ -36,7 +38,9 @@ module GTE
  cfg.model_name.to_s,
  cfg.sigmoid,
  cfg.output_tensor.to_s,
- cfg.max_length || 0
+ cfg.max_length || 0,
+ cfg.padding.to_s,
+ cfg.execution_providers.to_s
  )
  end
  end
data/lib/gte.rb CHANGED
@@ -9,6 +9,7 @@ rescue LoadError
  end

  require 'gte/config'
+ require 'gte/embedder'
  require 'gte/model'
  require 'gte/reranker'

@@ -25,7 +26,9 @@ module GTE
  model_name: nil,
  normalize: true,
  output_tensor: nil,
- max_length: nil
+ max_length: nil,
+ padding: nil,
+ execution_providers: nil
  )

  cfg = yield(cfg) if block_given?
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: gte
  version: !ruby/object:Gem::Version
- version: 0.0.5
+ version: 0.0.7
  platform: ruby
  authors:
  - elcuervo
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2026-04-15 00:00:00.000000000 Z
+ date: 2026-04-16 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rake
@@ -114,6 +114,7 @@ files:
  - ext/gte/tests/tokenizer_unit_test.rs
  - lib/gte.rb
  - lib/gte/config.rb
+ - lib/gte/embedder.rb
  - lib/gte/model.rb
  - lib/gte/reranker.rb
  - lib/gte/version.rb