gte 0.0.6 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc149108c647dc5b14154bfbdc4975b53670b9ed3cf7d80760cc2b415c935a48
4
- data.tar.gz: 32a682a95d56c8fab8d0d64a7ada0c0347ae796b6aefe6191f9aca8fc96426c2
3
+ metadata.gz: 2c754b4675ee105e9a280cd9deafa00a81b9e02ee629131f3e908400006b6ae4
4
+ data.tar.gz: 40a0d3e04c3d2943ae50910164d644ecb763eac99a02044dc962cc141a0e13c5
5
5
  SHA512:
6
- metadata.gz: f5c69d954f51a51521b143b576942a9c0505ad60574c1727f963dd79e0b6c22cacc4e6d9af75394ae06f451521dbc788af51f1e79397a5cc66a41b4ce1b31933
7
- data.tar.gz: 9e75fdbc9b5c8cfdd9d0e377a7e4a944057ec604e38ab23d960c4ed75ec6a72ce1dd27c2dd1bb2802721387babdabe0996e0c42be34d17d98253e0582b375de1
6
+ metadata.gz: 16614e01e7a33a53339ba9fe7cf32fe7606041518a24177258d7a6e5550516e8cff741d0f0df02b7e5863fc763c02ae81b943dc4b18295701a4cafdec6627cb0
7
+ data.tar.gz: 348e1fd1d9f4c44214b5101ba339109b5ececfbef18b48b7c11324a64481f476d8da831cc5148d17a85c41b525ee753c296d4421a4fb2adda269a3f5fe38cda6
data/README.md CHANGED
@@ -33,14 +33,15 @@ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
33
  config.with(normalize: false)
34
34
  end
35
35
 
36
- full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
- config.with(threads: 0)
36
+ single_thread = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
+ config.with(threads: 1)
38
38
  end
39
39
 
40
40
  custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
41
41
  config.with(
42
42
  output_tensor: "last_hidden_state",
43
43
  max_length: 256,
44
+ padding: "batch_longest",
44
45
  optimization_level: 3
45
46
  )
46
47
  end
@@ -49,12 +50,13 @@ end
49
50
  Config fields and defaults:
50
51
 
51
52
  - `model_dir`: absolute path to model directory
52
- - `threads`: `3` (set `0` for ONNX Runtime full-throttle threadpool)
53
+ - `threads`: `1` (default tuned for p95 latency; use `0` for ONNX Runtime auto-thread mode)
53
54
  - `optimization_level`: `3`
54
55
  - `model_name`: `nil`
55
56
  - `normalize`: `true` (L2 normalization at Ruby-facing API)
56
57
  - `output_tensor`: `nil` (auto-select output tensor)
57
58
  - `max_length`: `nil` (uses tokenizer/model defaults)
59
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
58
60
  - `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
59
61
 
60
62
  Notes:
@@ -66,7 +68,7 @@ Low-level embedder setup (without model cache):
66
68
 
67
69
  ```ruby
68
70
  embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
69
- config.with(threads: 0, execution_providers: "cpu")
71
+ config.with(threads: 1, execution_providers: "cpu")
70
72
  end
71
73
  ```
72
74
 
@@ -76,7 +78,7 @@ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
76
78
 
77
79
  ```ruby
78
80
  reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
79
- config.with(sigmoid: true, threads: 0)
81
+ config.with(sigmoid: true, threads: 1)
80
82
  end
81
83
 
82
84
  query = "how to train a neural network?"
@@ -100,12 +102,13 @@ ranked = reranker.rerank(query: query, candidates: candidates)
100
102
  Reranker config fields and defaults:
101
103
 
102
104
  - `model_dir`: absolute path to model directory
103
- - `threads`: `3`
105
+ - `threads`: `1`
104
106
  - `optimization_level`: `3`
105
107
  - `model_name`: `nil`
106
108
  - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
107
109
  - `output_tensor`: `nil`
108
110
  - `max_length`: `nil`
111
+ - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
109
112
  - `execution_providers`: `nil`
110
113
 
111
114
  ## Runtime + Result Examples
@@ -171,7 +174,7 @@ make ci
171
174
 
172
175
  ## Benchmark
173
176
 
174
- The repo includes two benchmark paths:
177
+ The repo includes a shared multi-runtime benchmark harness:
175
178
 
176
179
  ```bash
177
180
  make bench
@@ -180,6 +183,11 @@ nix develop -c bundle exec rake bench:matrix_sweep
180
183
  nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
181
184
  ```
182
185
 
186
+ - `make bench`: Puma-like single-request comparison at concurrency `16`
187
+ - `rake bench:pure_compare`: batch amortization comparison
188
+ - `rake bench:matrix_sweep`: GTE provider/thread sweep using the shared result schema
189
+ - Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
190
+
183
191
  To run benchmark + append a `RUNS.md` entry + enforce goal checks:
184
192
 
185
193
  ```bash
@@ -188,5 +196,5 @@ make bench-record
188
196
 
189
197
  `bench/runs_ledger.rb check` is goal-focused by default:
190
198
 
191
- - Enforces goal metric (`response_time_p95` ratio threshold).
199
+ - Enforces the goal metric (`response_time_p95`) across every enabled competitor.
192
200
  - Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
data/Rakefile CHANGED
@@ -10,17 +10,52 @@ rescue LoadError
10
10
  end
11
11
 
12
12
  spec = Gem::Specification.load('gte.gemspec')
13
+ cross_target = ENV.fetch('RUBY_TARGET', nil)
13
14
 
14
- Rake::ExtensionTask.new('gte', spec) do |ext|
15
+ if cross_target == 'arm64-darwin'
16
+ # rb-sys-dock's darwin image can expose an unusable default LIBRARY_PATH.
17
+ # Force the compiler-rt darwin runtime directory so -lclang_rt.osx resolves.
18
+ ENV['LIBRARY_PATH'] = '/usr/lib/llvm-10/lib/clang/10.0.0/lib/darwin'
19
+ end
20
+
21
+ extension_task = Rake::ExtensionTask.new('gte', spec) do |ext|
15
22
  ext.lib_dir = 'lib/gte'
16
23
  ext.cross_compile = true
17
- ext.cross_platform = %w[x86_64-linux arm64-darwin]
24
+ # rb-sys-dock invokes `rake native:$RUBY_TARGET gem` without the `cross` task,
25
+ # so scope platforms during dock builds to avoid host-Ruby fallback copy tasks.
26
+ cross_platforms = if cross_target && !cross_target.empty?
27
+ [cross_target]
28
+ else
29
+ %w[x86_64-linux aarch64-linux arm64-darwin]
30
+ end
31
+ ext.cross_platform = cross_platforms
32
+ end
33
+
34
+ if cross_target && !cross_target.empty? && ENV['RUBY_CC_VERSION']
35
+ ruby_version = ENV['RUBY_CC_VERSION'].split(':').first
36
+ lib_binary_path = File.join(extension_task.lib_dir, File.basename(extension_task.binary(cross_target)))
37
+ copy_task = "copy:gte:#{cross_target}:#{ruby_version}"
38
+
39
+ if Rake::Task.task_defined?(lib_binary_path) && Rake::Task.task_defined?(copy_task)
40
+ Rake::Task[lib_binary_path].prerequisites.clear
41
+ Rake::Task[lib_binary_path].enhance([copy_task])
42
+ end
18
43
  end
19
44
 
20
45
  task default: %i[compile spec]
21
46
 
47
+ def bundler_env
48
+ root = File.expand_path(__dir__)
49
+ {
50
+ 'BUNDLE_DISABLE_SHARED_GEMS' => '1',
51
+ 'GEM_HOME' => File.join(root, '.bundle-gems'),
52
+ 'GEM_PATH' => File.join(root, '.bundle-gems'),
53
+ 'BUNDLE_PATH' => File.join(root, 'vendor/bundle')
54
+ }
55
+ end
56
+
22
57
  def run_in_nix(*command)
23
- sh('nix', 'develop', '-c', *command)
58
+ sh(bundler_env, 'nix', 'develop', '-c', *command)
24
59
  end
25
60
 
26
61
  namespace :bench do
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.6
1
+ 0.0.8
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.6"
3
+ version = "0.0.8"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -21,10 +21,10 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
21
21
  [dependencies]
22
22
  rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
23
23
  magnus = { version = "0.8", optional = true }
24
- ort = { version = "=2.0.0-rc.9", features = ["ndarray"] }
25
- ort-sys = "=2.0.0-rc.9"
24
+ ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
25
+ ort-sys = "=2.0.0-rc.12"
26
26
  tokenizers = "0.21.0"
27
- ndarray = "0.16.0"
27
+ ndarray = "0.17"
28
28
  half = "2"
29
29
  serde = { version = "1", features = ["derive"] }
30
30
  serde_json = "1"
@@ -1,19 +1,18 @@
1
1
  use crate::error::{GteError, Result};
2
- use crate::model_config::{ExtractorMode, ModelConfig};
2
+ use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
3
3
  use crate::model_profile::{
4
- has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
5
- resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
4
+ has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
5
+ resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
8
- use crate::session::{build_session, run_session};
9
- use crate::tokenizer::{Tokenized, Tokenizer};
8
+ use crate::session::{build_session, run_session, SessionPool};
9
+ use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
10
10
  use ndarray::Array2;
11
- use ort::session::Session;
12
- use std::path::Path;
11
+ use std::path::{Path, PathBuf};
13
12
 
14
13
  pub struct Embedder {
15
14
  tokenizer: Tokenizer,
16
- session: Session,
15
+ pool: SessionPool,
17
16
  config: ModelConfig,
18
17
  }
19
18
 
@@ -23,23 +22,24 @@ impl Embedder {
23
22
  P1: AsRef<Path>,
24
23
  P2: AsRef<Path>,
25
24
  {
26
- let tokenizer = Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids)?;
27
- let session = build_session(model_path, &config)?;
28
- Ok(Self {
29
- tokenizer,
30
- session,
31
- config,
32
- })
25
+ let tokenizer = Tokenizer::new(
26
+ tokenizer_path,
27
+ config.max_length,
28
+ config.with_type_ids,
29
+ config.padding_mode,
30
+ None,
31
+ )?;
32
+ let model_path = model_path.as_ref().to_path_buf();
33
+ let session = build_session(&model_path, &config)?;
34
+ let pool = SessionPool::new(session, model_path, config.clone());
35
+ Ok(Self { tokenizer, pool, config })
33
36
  }
34
37
 
35
38
  pub fn from_dir<P: AsRef<Path>>(
36
39
  dir: P,
37
40
  num_threads: usize,
38
41
  optimization_level: u8,
39
- model_name: Option<&str>,
40
- output_tensor_override: Option<&str>,
41
- max_length_override: Option<usize>,
42
- execution_providers_override: Option<&str>,
42
+ overrides: ModelLoadOverrides<'_>,
43
43
  ) -> Result<Self> {
44
44
  const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
45
45
  "pooler_output",
@@ -50,31 +50,35 @@ impl Embedder {
50
50
 
51
51
  let dir = dir.as_ref();
52
52
  let tokenizer_path = resolve_tokenizer_path(dir)?;
53
- let model_path = match model_name.filter(|s| !s.is_empty()) {
53
+ let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
54
54
  Some(name) => resolve_named_model(dir, name)?,
55
55
  None => resolve_default_text_model(dir)?,
56
56
  };
57
57
 
58
- let max_length = if let Some(override_value) = max_length_override {
58
+ let tokenizer_profile = read_tokenizer_profile(dir);
59
+ let max_length = if let Some(override_value) = overrides.max_length {
59
60
  if override_value == 0 {
60
61
  return Err(GteError::Inference(
61
62
  "max_length override must be greater than 0".to_string(),
62
63
  ));
63
64
  }
64
- override_value
65
+ override_value.min(tokenizer_profile.safe_max_length)
65
66
  } else {
66
- read_max_length(dir)
67
+ tokenizer_profile.default_max_length
67
68
  };
69
+ let padding_mode =
70
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
68
71
 
69
72
  let session_config = ModelConfig {
70
73
  max_length,
74
+ padding_mode,
71
75
  output_tensor: String::new(),
72
76
  mode: ExtractorMode::Raw,
73
77
  with_type_ids: false,
74
78
  with_attention_mask: true,
75
79
  num_threads,
76
80
  optimization_level,
77
- execution_providers: execution_providers_override.map(str::to_string),
81
+ execution_providers: overrides.execution_providers.map(str::to_string),
78
82
  };
79
83
  let session = build_session(&model_path, &session_config)?;
80
84
 
@@ -82,7 +86,7 @@ impl Embedder {
82
86
  let with_type_ids = has_input(&session, "token_type_ids");
83
87
  let with_attention_mask = has_input(&session, "attention_mask");
84
88
  let output_tensor =
85
- select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
89
+ select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
86
90
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
87
91
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
88
92
  return Err(GteError::Inference(
@@ -92,22 +96,26 @@ impl Embedder {
92
96
 
93
97
  let config = ModelConfig {
94
98
  max_length,
99
+ padding_mode,
95
100
  output_tensor,
96
101
  mode,
97
102
  with_type_ids,
98
103
  with_attention_mask,
99
104
  num_threads,
100
105
  optimization_level,
101
- execution_providers: execution_providers_override.map(str::to_string),
106
+ execution_providers: overrides.execution_providers.map(str::to_string),
102
107
  };
103
108
 
104
- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
109
+ let tokenizer = Tokenizer::new(
110
+ &tokenizer_path,
111
+ config.max_length,
112
+ config.with_type_ids,
113
+ config.padding_mode,
114
+ tokenizer_profile.fixed_padding_length,
115
+ )?;
105
116
 
106
- Ok(Self {
107
- tokenizer,
108
- session,
109
- config,
110
- })
117
+ let pool = SessionPool::new(session, model_path, session_config);
118
+ Ok(Self { tokenizer, pool, config })
111
119
  }
112
120
 
113
121
  pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
@@ -120,7 +128,8 @@ impl Embedder {
120
128
  }
121
129
 
122
130
  pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
123
- run_session(&self.session, tokenized, &self.config)
131
+ let mut session = self.pool.acquire()?;
132
+ run_session(&mut session, tokenized, &self.config)
124
133
  }
125
134
  }
126
135
 
@@ -5,9 +5,18 @@ pub enum ExtractorMode {
5
5
  Raw,
6
6
  }
7
7
 
8
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
9
+ pub enum PaddingMode {
10
+ #[default]
11
+ Auto,
12
+ BatchLongest,
13
+ Fixed,
14
+ }
15
+
8
16
  #[derive(Debug, Clone)]
9
17
  pub struct ModelConfig {
10
18
  pub max_length: usize,
19
+ pub padding_mode: PaddingMode,
11
20
  pub output_tensor: String,
12
21
  pub mode: ExtractorMode,
13
22
  pub with_type_ids: bool,
@@ -16,3 +25,12 @@ pub struct ModelConfig {
16
25
  pub optimization_level: u8,
17
26
  pub execution_providers: Option<String>,
18
27
  }
28
+
29
+ #[derive(Debug, Clone, Copy, Default)]
30
+ pub struct ModelLoadOverrides<'a> {
31
+ pub model_name: Option<&'a str>,
32
+ pub output_tensor: Option<&'a str>,
33
+ pub max_length: Option<usize>,
34
+ pub padding: Option<&'a str>,
35
+ pub execution_providers: Option<&'a str>,
36
+ }
@@ -1,9 +1,19 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::ExtractorMode;
3
3
  use ort::session::Session;
4
+ use serde_json::Value;
4
5
  use std::path::{Path, PathBuf};
5
6
 
6
7
  const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
8
+ const DEFAULT_MAX_LENGTH: usize = 512;
9
+ const MAX_SUPPORTED_LENGTH: usize = 8192;
10
+
11
+ #[derive(Debug, Clone, Copy)]
12
+ pub struct TokenizerProfile {
13
+ pub default_max_length: usize,
14
+ pub safe_max_length: usize,
15
+ pub fixed_padding_length: Option<usize>,
16
+ }
7
17
 
8
18
  pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
9
19
  let tokenizer_path = dir.join("tokenizer.json");
@@ -48,27 +58,84 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
48
58
  )))
49
59
  }
50
60
 
51
- pub fn read_max_length(dir: &Path) -> usize {
52
- (|| -> Option<usize> {
53
- let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
54
- let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
55
- let v = json.get("model_max_length")?;
56
- let n = v.as_u64().or_else(|| {
57
- v.as_f64()
58
- .filter(|&f| f > 0.0 && f < 1e15)
59
- .map(|f| f as u64)
60
- })?;
61
- Some((n as usize).min(8192))
62
- })()
63
- .unwrap_or(512)
61
+ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
62
+ let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
63
+ let tokenizer_json = read_json(dir.join("tokenizer.json"));
64
+
65
+ let fixed_padding_length = tokenizer_json
66
+ .as_ref()
67
+ .and_then(parse_fixed_padding_length_from_tokenizer_json);
68
+
69
+ let mut candidates = Vec::new();
70
+ if let Some(config) = tokenizer_config.as_ref() {
71
+ if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
72
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
73
+ }
74
+ if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
75
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
76
+ }
77
+ }
78
+
79
+ if let Some(tokenizer) = tokenizer_json.as_ref() {
80
+ if let Some(v) = tokenizer
81
+ .get("truncation")
82
+ .and_then(|truncation| truncation.get("max_length"))
83
+ .and_then(parse_positive_usize)
84
+ {
85
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
86
+ }
87
+ }
88
+
89
+ if let Some(v) = fixed_padding_length {
90
+ candidates.push(v.min(MAX_SUPPORTED_LENGTH));
91
+ }
92
+
93
+ let default_max_length = candidates
94
+ .iter()
95
+ .copied()
96
+ .min()
97
+ .unwrap_or(DEFAULT_MAX_LENGTH)
98
+ .max(1);
99
+ let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
100
+
101
+ TokenizerProfile {
102
+ default_max_length,
103
+ safe_max_length,
104
+ fixed_padding_length,
105
+ }
106
+ }
107
+
108
+ fn read_json(path: PathBuf) -> Option<Value> {
109
+ let contents = std::fs::read_to_string(path).ok()?;
110
+ serde_json::from_str(&contents).ok()
111
+ }
112
+
113
+ fn parse_positive_usize(value: &Value) -> Option<usize> {
114
+ let raw = value
115
+ .as_u64()
116
+ .or_else(|| {
117
+ value
118
+ .as_f64()
119
+ .filter(|&v| v.is_finite() && v > 0.0)
120
+ .map(|v| v as u64)
121
+ })
122
+ .or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
123
+ let parsed = usize::try_from(raw).ok()?;
124
+ (parsed > 0).then_some(parsed)
125
+ }
126
+
127
+ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
128
+ tokenizer_json
129
+ .get("padding")
130
+ .and_then(|padding| padding.get("strategy"))
131
+ .and_then(|strategy| strategy.get("Fixed"))
132
+ .and_then(parse_positive_usize)
64
133
  }
65
134
 
66
135
  pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
67
- let unsupported: Vec<String> = session
68
- .inputs
69
- .iter()
70
- .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
71
- .map(|i| i.name.clone())
136
+ let unsupported: Vec<String> = session.inputs().iter()
137
+ .filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
138
+ .map(|i| i.name().to_owned())
72
139
  .collect();
73
140
 
74
141
  if unsupported.is_empty() {
@@ -91,7 +158,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Res
91
158
  }
92
159
 
93
160
  pub fn has_input(session: &Session, name: &str) -> bool {
94
- session.inputs.iter().any(|input| input.name == name)
161
+ session.inputs().iter().any(|input| input.name() == name)
95
162
  }
96
163
 
97
164
  fn output_name_matches(name: &str, preferred: &str) -> bool {
@@ -106,16 +173,16 @@ pub fn select_output_tensor(
106
173
  ) -> Result<String> {
107
174
  if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
108
175
  if let Some(output) = session
109
- .outputs
176
+ .outputs()
110
177
  .iter()
111
- .find(|o| output_name_matches(o.name.as_str(), requested_name))
178
+ .find(|o| output_name_matches(o.name(), requested_name))
112
179
  {
113
- return Ok(output.name.clone());
180
+ return Ok(output.name().to_owned());
114
181
  }
115
182
  let available = session
116
- .outputs
183
+ .outputs()
117
184
  .iter()
118
- .map(|o| o.name.as_str())
185
+ .map(|o| o.name())
119
186
  .collect::<Vec<_>>()
120
187
  .join(", ");
121
188
  return Err(GteError::Inference(format!(
@@ -126,18 +193,18 @@ pub fn select_output_tensor(
126
193
 
127
194
  for preferred in preferred_outputs {
128
195
  if let Some(output) = session
129
- .outputs
196
+ .outputs()
130
197
  .iter()
131
- .find(|o| output_name_matches(o.name.as_str(), preferred))
198
+ .find(|o| output_name_matches(o.name(), preferred))
132
199
  {
133
- return Ok(output.name.clone());
200
+ return Ok(output.name().to_owned());
134
201
  }
135
202
  }
136
203
 
137
204
  session
138
- .outputs
205
+ .outputs()
139
206
  .first()
140
- .map(|o| o.name.clone())
207
+ .map(|o| o.name().to_owned())
141
208
  .ok_or_else(|| GteError::Inference("model has no outputs".into()))
142
209
  }
143
210
 
@@ -147,9 +214,9 @@ fn output_basename(name: &str) -> &str {
147
214
 
148
215
  pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
149
216
  let output = session
150
- .outputs
217
+ .outputs()
151
218
  .iter()
152
- .find(|o| o.name == output_tensor)
219
+ .find(|o| o.name() == output_tensor)
153
220
  .ok_or_else(|| {
154
221
  GteError::Inference(format!(
155
222
  "output tensor '{}' not found in model outputs",
@@ -157,8 +224,8 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
157
224
  ))
158
225
  })?;
159
226
 
160
- let ndims = match &output.output_type {
161
- ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
227
+ let ndims = match output.dtype() {
228
+ ort::value::ValueType::Tensor { shape, .. } => shape.len(),
162
229
  other => {
163
230
  return Err(GteError::Inference(format!(
164
231
  "output is not a tensor: {:?}",
@@ -177,3 +244,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
177
244
  ))),
178
245
  }
179
246
  }
247
+
248
+ #[cfg(test)]
249
+ mod tests {
250
+ use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
251
+ use serde_json::json;
252
+
253
+ #[test]
254
+ fn parse_positive_usize_handles_integer_float_and_string() {
255
+ assert_eq!(parse_positive_usize(&json!(64)), Some(64));
256
+ assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
257
+ assert_eq!(parse_positive_usize(&json!("64")), Some(64));
258
+ assert_eq!(parse_positive_usize(&json!(0)), None);
259
+ }
260
+
261
+ #[test]
262
+ fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
263
+ let tokenizer_json = json!({
264
+ "padding": {
265
+ "strategy": {
266
+ "Fixed": 64
267
+ }
268
+ }
269
+ });
270
+ assert_eq!(
271
+ parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
272
+ Some(64)
273
+ );
274
+ }
275
+ }
@@ -1,8 +1,8 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::tokenizer::Tokenized;
3
- use ndarray::ArrayView2;
3
+ use ndarray::{ArrayView2, ArrayViewD};
4
4
  use ort::session::SessionInputValue;
5
- use ort::value::Value;
5
+ use ort::value::TensorRef;
6
6
 
7
7
  pub struct InputTensors<'a> {
8
8
  pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
@@ -23,13 +23,13 @@ impl<'a> InputTensors<'a> {
23
23
  let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
24
24
  inputs.push((
25
25
  "input_ids",
26
- SessionInputValue::from(Value::from_array(input_ids_view)?),
26
+ SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
27
27
  ));
28
28
 
29
29
  if with_attention_mask {
30
30
  inputs.push((
31
31
  "attention_mask",
32
- SessionInputValue::from(Value::from_array(attention_mask)?),
32
+ SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
33
33
  ));
34
34
  }
35
35
 
@@ -38,7 +38,7 @@ impl<'a> InputTensors<'a> {
38
38
  ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
39
39
  inputs.push((
40
40
  "token_type_ids",
41
- SessionInputValue::from(Value::from_array(type_ids_view)?),
41
+ SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
42
42
  ));
43
43
  }
44
44
 
@@ -50,11 +50,14 @@ impl<'a> InputTensors<'a> {
50
50
  }
51
51
 
52
52
  pub fn extract_output_tensor<'a>(
53
- outputs: &'a ort::session::SessionOutputs<'a, 'a>,
53
+ outputs: &'a ort::session::SessionOutputs<'_>,
54
54
  output_name: &str,
55
- ) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
55
+ ) -> Result<ArrayViewD<'a, f32>> {
56
56
  let tensor_value = outputs.get(output_name).ok_or_else(|| {
57
- GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
57
+ GteError::Inference(format!(
58
+ "output tensor '{}' not found in model outputs",
59
+ output_name
60
+ ))
58
61
  })?;
59
- Ok(tensor_value.try_extract_tensor::<f32>()?.into())
62
+ Ok(tensor_value.try_extract_array::<f32>()?)
60
63
  }