gte 0.0.13 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,10 +18,7 @@ pub struct TokenizerProfile {
18
18
  pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
19
19
  let tokenizer_path = dir.join("tokenizer.json");
20
20
  if !tokenizer_path.exists() {
21
- return Err(GteError::Tokenizer(format!(
22
- "tokenizer.json not found in {}",
23
- dir.display()
24
- )));
21
+ return Err(GteError::Tokenizer(format!("tokenizer.json not found in {}", dir.display())));
25
22
  }
26
23
  Ok(tokenizer_path)
27
24
  }
@@ -33,11 +30,7 @@ pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
33
30
  return Ok(path.clone());
34
31
  }
35
32
  }
36
- Err(GteError::Inference(format!(
37
- "model '{}' not found in {} (checked onnx/{0} and {0})",
38
- name,
39
- dir.display()
40
- )))
33
+ Err(GteError::Inference(format!("model '{}' not found in {} (checked onnx/{0} and {0})", name, dir.display())))
41
34
  }
42
35
 
43
36
  pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
@@ -62,9 +55,7 @@ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
62
55
  let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
63
56
  let tokenizer_json = read_json(dir.join("tokenizer.json"));
64
57
 
65
- let fixed_padding_length = tokenizer_json
66
- .as_ref()
67
- .and_then(parse_fixed_padding_length_from_tokenizer_json);
58
+ let fixed_padding_length = tokenizer_json.as_ref().and_then(parse_fixed_padding_length_from_tokenizer_json);
68
59
 
69
60
  let mut candidates = Vec::new();
70
61
  if let Some(config) = tokenizer_config.as_ref() {
@@ -90,19 +81,10 @@ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
90
81
  candidates.push(v.min(MAX_SUPPORTED_LENGTH));
91
82
  }
92
83
 
93
- let default_max_length = candidates
94
- .iter()
95
- .copied()
96
- .min()
97
- .unwrap_or(DEFAULT_MAX_LENGTH)
98
- .max(1);
84
+ let default_max_length = candidates.iter().copied().min().unwrap_or(DEFAULT_MAX_LENGTH).max(1);
99
85
  let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
100
86
 
101
- TokenizerProfile {
102
- default_max_length,
103
- safe_max_length,
104
- fixed_padding_length,
105
- }
87
+ TokenizerProfile { default_max_length, safe_max_length, fixed_padding_length }
106
88
  }
107
89
 
108
90
  fn read_json(path: PathBuf) -> Option<Value> {
@@ -113,12 +95,7 @@ fn read_json(path: PathBuf) -> Option<Value> {
113
95
  fn parse_positive_usize(value: &Value) -> Option<usize> {
114
96
  let raw = value
115
97
  .as_u64()
116
- .or_else(|| {
117
- value
118
- .as_f64()
119
- .filter(|&v| v.is_finite() && v > 0.0)
120
- .map(|v| v as u64)
121
- })
98
+ .or_else(|| value.as_f64().filter(|&v| v.is_finite() && v > 0.0).map(|v| v as u64))
122
99
  .or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
123
100
  let parsed = usize::try_from(raw).ok()?;
124
101
  (parsed > 0).then_some(parsed)
@@ -133,7 +110,9 @@ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Opt
133
110
  }
134
111
 
135
112
  pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
136
- let unsupported: Vec<String> = session.inputs().iter()
113
+ let unsupported: Vec<String> = session
114
+ .inputs()
115
+ .iter()
137
116
  .filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
138
117
  .map(|i| i.name().to_owned())
139
118
  .collect();
@@ -142,11 +121,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Res
142
121
  return Ok(());
143
122
  }
144
123
 
145
- let mut message = format!(
146
- "unsupported model inputs for {} API: {}",
147
- api_label,
148
- unsupported.join(", ")
149
- );
124
+ let mut message = format!("unsupported model inputs for {} API: {}", api_label, unsupported.join(", "));
150
125
  if unsupported.iter().any(|n| n == "pixel_values") {
151
126
  message.push_str(
152
127
  ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
@@ -163,40 +138,23 @@ pub fn has_input(session: &Session, name: &str) -> bool {
163
138
 
164
139
  fn output_name_matches(name: &str, preferred: &str) -> bool {
165
140
  let lower = name.to_ascii_lowercase();
166
- lower == preferred || lower.ends_with(&format!("/{}", preferred))
141
+ lower == preferred || lower.ends_with(&format!("/{preferred}"))
167
142
  }
168
143
 
169
- pub fn select_output_tensor(
170
- session: &Session,
171
- requested: Option<&str>,
172
- preferred_outputs: &[&str],
173
- ) -> Result<String> {
144
+ #[allow(clippy::redundant_closure_for_method_calls)]
145
+ pub fn select_output_tensor(session: &Session, requested: Option<&str>, preferred_outputs: &[&str]) -> Result<String> {
174
146
  if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
175
- if let Some(output) = session
176
- .outputs()
177
- .iter()
178
- .find(|o| output_name_matches(o.name(), requested_name))
179
- {
147
+ if let Some(output) = session.outputs().iter().find(|o| output_name_matches(o.name(), requested_name)) {
180
148
  return Ok(output.name().to_owned());
181
149
  }
182
- let available = session
183
- .outputs()
184
- .iter()
185
- .map(|o| o.name())
186
- .collect::<Vec<_>>()
187
- .join(", ");
150
+ let available = session.outputs().iter().map(|o| o.name()).collect::<Vec<_>>().join(", ");
188
151
  return Err(GteError::Inference(format!(
189
- "requested output tensor '{}' not found in model outputs: {}",
190
- requested_name, available
152
+ "requested output tensor '{requested_name}' not found in model outputs: {available}"
191
153
  )));
192
154
  }
193
155
 
194
156
  for preferred in preferred_outputs {
195
- if let Some(output) = session
196
- .outputs()
197
- .iter()
198
- .find(|o| output_name_matches(o.name(), preferred))
199
- {
157
+ if let Some(output) = session.outputs().iter().find(|o| output_name_matches(o.name(), preferred)) {
200
158
  return Ok(output.name().to_owned());
201
159
  }
202
160
  }
@@ -204,12 +162,9 @@ pub fn select_output_tensor(
204
162
  let outputs = session.outputs();
205
163
  let best = outputs
206
164
  .iter()
207
- .find(|o| {
208
- matches!(o.dtype(), ort::value::ValueType::Tensor { shape, .. } if shape.len() == 2)
209
- })
165
+ .find(|o| matches!(o.dtype(), ort::value::ValueType::Tensor { shape, .. } if shape.len() == 2))
210
166
  .or_else(|| outputs.first());
211
- best.map(|o| o.name().to_owned())
212
- .ok_or_else(|| GteError::Inference("model has no outputs".into()))
167
+ best.map(|o| o.name().to_owned()).ok_or_else(|| GteError::Inference("model has no outputs".into()))
213
168
  }
214
169
 
215
170
  fn output_basename(name: &str) -> &str {
@@ -217,34 +172,21 @@ fn output_basename(name: &str) -> &str {
217
172
  }
218
173
 
219
174
  pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
220
- let output = session
221
- .outputs()
222
- .iter()
223
- .find(|o| o.name() == output_tensor)
224
- .ok_or_else(|| {
225
- GteError::Inference(format!(
226
- "output tensor '{}' not found in model outputs",
227
- output_tensor
228
- ))
175
+ let output =
176
+ session.outputs().iter().find(|o| o.name() == output_tensor).ok_or_else(|| {
177
+ GteError::Inference(format!("output tensor '{output_tensor}' not found in model outputs"))
229
178
  })?;
230
179
 
231
180
  let ndims = match output.dtype() {
232
181
  ort::value::ValueType::Tensor { shape, .. } => shape.len(),
233
- other => {
234
- return Err(GteError::Inference(format!(
235
- "output is not a tensor: {:?}",
236
- other
237
- )))
238
- }
182
+ other => return Err(GteError::Inference(format!("output is not a tensor: {other:?}"))),
239
183
  };
240
184
 
241
185
  match (output_basename(output_tensor), ndims) {
242
- ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
186
+ ("last_hidden_state" | _, 3) => Ok(ExtractorMode::MeanPool),
243
187
  (_, 2) => Ok(ExtractorMode::Raw),
244
- (_, 3) => Ok(ExtractorMode::MeanPool),
245
188
  (_, n) => Err(GteError::Inference(format!(
246
- "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
247
- n, output_tensor
189
+ "unexpected output tensor rank {n} for '{output_tensor}': expected 2 (Raw) or 3 (MeanPool)"
248
190
  ))),
249
191
  }
250
192
  }
@@ -271,9 +213,6 @@ mod tests {
271
213
  }
272
214
  }
273
215
  });
274
- assert_eq!(
275
- parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
276
- Some(64)
277
- );
216
+ assert_eq!(parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json), Some(64));
278
217
  }
279
218
  }
@@ -11,41 +11,25 @@ pub struct InputTensors<'a> {
11
11
 
12
12
  impl<'a> InputTensors<'a> {
13
13
  pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
14
- let input_ids_view: ArrayView2<'_, i64> = ArrayView2::from_shape(
15
- (tokenized.rows, tokenized.cols),
16
- tokenized.input_ids.as_slice(),
17
- )?;
18
- let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
19
- (tokenized.rows, tokenized.cols),
20
- tokenized.attn_masks.as_slice(),
21
- )?;
14
+ let input_ids_view: ArrayView2<'_, i64> =
15
+ ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
16
+ let attention_mask: ArrayView2<'_, i64> =
17
+ ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
22
18
 
23
19
  let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
24
- inputs.push((
25
- "input_ids",
26
- SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
27
- ));
20
+ inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
28
21
 
29
22
  if with_attention_mask {
30
- inputs.push((
31
- "attention_mask",
32
- SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
33
- ));
23
+ inputs.push(("attention_mask", SessionInputValue::from(TensorRef::from_array_view(attention_mask)?)));
34
24
  }
35
25
 
36
26
  if let Some(type_ids) = tokenized.type_ids.as_deref() {
37
27
  let type_ids_view: ArrayView2<'_, i64> =
38
28
  ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
39
- inputs.push((
40
- "token_type_ids",
41
- SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
42
- ));
29
+ inputs.push(("token_type_ids", SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?)));
43
30
  }
44
31
 
45
- Ok(Self {
46
- inputs,
47
- attention_mask,
48
- })
32
+ Ok(Self { inputs, attention_mask })
49
33
  }
50
34
  }
51
35
 
@@ -53,11 +37,8 @@ pub fn extract_output_tensor<'a>(
53
37
  outputs: &'a ort::session::SessionOutputs<'_>,
54
38
  output_name: &str,
55
39
  ) -> Result<ArrayViewD<'a, f32>> {
56
- let tensor_value = outputs.get(output_name).ok_or_else(|| {
57
- GteError::Inference(format!(
58
- "output tensor '{}' not found in model outputs",
59
- output_name
60
- ))
61
- })?;
40
+ let tensor_value = outputs
41
+ .get(output_name)
42
+ .ok_or_else(|| GteError::Inference(format!("output tensor '{output_name}' not found in model outputs")))?;
62
43
  Ok(tensor_value.try_extract_array::<f32>()?)
63
44
  }
@@ -1,10 +1,7 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use ndarray::{Array2, ArrayView2, ArrayView3};
3
3
 
4
- pub fn mean_pool(
5
- hidden_states: ArrayView3<'_, f32>,
6
- attention_mask: ArrayView2<'_, i64>,
7
- ) -> Result<Array2<f32>> {
4
+ pub fn mean_pool(hidden_states: ArrayView3<'_, f32>, attention_mask: ArrayView2<'_, i64>) -> Result<Array2<f32>> {
8
5
  let (batch, seq, dim) = hidden_states.dim();
9
6
  if attention_mask.dim() != (batch, seq) {
10
7
  return Err(GteError::Inference(format!(
@@ -34,17 +31,14 @@ pub fn mean_pool(
34
31
 
35
32
  let weight = weight as f32;
36
33
  for dim_index in 0..dim {
37
- pooled[[batch_index, dim_index]] +=
38
- hidden_states[[batch_index, token_index, dim_index]] * weight;
34
+ pooled[[batch_index, dim_index]] += hidden_states[[batch_index, token_index, dim_index]] * weight;
39
35
  }
40
36
  weight_sum += weight;
41
37
  }
42
38
 
43
39
  if weight_sum > 0.0 {
44
40
  let inverse = weight_sum.recip();
45
- pooled
46
- .row_mut(batch_index)
47
- .map_inplace(|value| *value *= inverse);
41
+ pooled.row_mut(batch_index).map_inplace(|value| *value *= inverse);
48
42
  }
49
43
  }
50
44
 
@@ -89,6 +83,8 @@ fn mean_pool_contiguous(
89
83
  seq: usize,
90
84
  dim: usize,
91
85
  ) {
86
+ let seq_inverse = (seq as f32).recip();
87
+
92
88
  for batch_index in 0..batch {
93
89
  let mask_base = batch_index * seq;
94
90
  let hidden_base = batch_index * seq * dim;
@@ -103,9 +99,8 @@ fn mean_pool_contiguous(
103
99
  }
104
100
  }
105
101
 
106
- let inverse = (seq as f32).recip();
107
102
  for value in output_row {
108
- *value *= inverse;
103
+ *value *= seq_inverse;
109
104
  }
110
105
  continue;
111
106
  }
@@ -113,12 +108,11 @@ fn mean_pool_contiguous(
113
108
  let mut weight_sum = 0.0f32;
114
109
 
115
110
  for (token_index, &weight_raw) in mask_row.iter().enumerate() {
116
- let weight = weight_raw;
117
- if weight <= 0 {
111
+ if weight_raw <= 0 {
118
112
  continue;
119
113
  }
120
114
 
121
- let weight = weight as f32;
115
+ let weight = weight_raw as f32;
122
116
  let token_base = hidden_base + token_index * dim;
123
117
  for dim_index in 0..dim {
124
118
  output_row[dim_index] += hidden[token_base + dim_index] * weight;
@@ -1,8 +1,8 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ModelLoadOverrides, PaddingMode};
3
3
  use crate::model_profile::{
4
- has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
5
- resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
4
+ has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
5
+ select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
8
8
  use crate::postprocess::sigmoid_scores;
@@ -26,11 +26,7 @@ pub struct Reranker {
26
26
  }
27
27
 
28
28
  impl Reranker {
29
- pub fn from_dir<P: AsRef<Path>>(
30
- dir: P,
31
- optimization_level: u8,
32
- overrides: ModelLoadOverrides<'_>,
33
- ) -> Result<Self> {
29
+ pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
34
30
  let dir = dir.as_ref();
35
31
  let tokenizer_path = resolve_tokenizer_path(dir)?;
36
32
  let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
@@ -41,16 +37,13 @@ impl Reranker {
41
37
  let tokenizer_profile = read_tokenizer_profile(dir);
42
38
  let max_length = if let Some(override_value) = overrides.max_length {
43
39
  if override_value == 0 {
44
- return Err(GteError::Inference(
45
- "max_length override must be greater than 0".to_string(),
46
- ));
40
+ return Err(GteError::Inference("max_length override must be greater than 0".to_string()));
47
41
  }
48
42
  override_value.min(tokenizer_profile.safe_max_length)
49
43
  } else {
50
44
  tokenizer_profile.default_max_length
51
45
  };
52
- let padding_mode =
53
- parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
46
+ let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
54
47
 
55
48
  let probe_config = crate::model_config::ModelConfig {
56
49
  max_length,
@@ -61,6 +54,8 @@ impl Reranker {
61
54
  with_attention_mask: true,
62
55
  optimization_level,
63
56
  execution_providers: overrides.execution_providers.map(str::to_string),
57
+ lowercase_input: false,
58
+ max_input_chars: None,
64
59
  };
65
60
  let session = build_session(&model_path, &probe_config)?;
66
61
 
@@ -69,13 +64,7 @@ impl Reranker {
69
64
  let with_attention_mask = has_input(&session, "attention_mask");
70
65
  let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
71
66
 
72
- let config = RerankerConfig {
73
- max_length,
74
- padding_mode,
75
- output_tensor,
76
- with_type_ids,
77
- with_attention_mask,
78
- };
67
+ let config = RerankerConfig { max_length, padding_mode, output_tensor, with_type_ids, with_attention_mask };
79
68
 
80
69
  let tokenizer = Tokenizer::new(
81
70
  &tokenizer_path,
@@ -85,7 +74,19 @@ impl Reranker {
85
74
  tokenizer_profile.fixed_padding_length,
86
75
  )?;
87
76
 
88
- let pool = SessionPool::new(session, model_path, probe_config);
77
+ let model_config = crate::model_config::ModelConfig {
78
+ max_length,
79
+ padding_mode,
80
+ output_tensor: config.output_tensor.clone(),
81
+ mode: crate::model_config::ExtractorMode::Raw,
82
+ with_type_ids: config.with_type_ids,
83
+ with_attention_mask: config.with_attention_mask,
84
+ optimization_level,
85
+ execution_providers: None,
86
+ lowercase_input: false,
87
+ max_input_chars: None,
88
+ };
89
+ let pool = SessionPool::new(session, &model_path, &model_config)?;
89
90
  Ok(Self { tokenizer, pool, config })
90
91
  }
91
92
 
@@ -99,40 +100,39 @@ impl Reranker {
99
100
  self.score_tokenized(&tokenized, apply_sigmoid)
100
101
  }
101
102
 
102
- fn score_tokenized(
103
- &self,
104
- tokenized: &crate::tokenizer::Tokenized,
105
- apply_sigmoid: bool,
106
- ) -> Result<Vec<f32>> {
103
+ fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
107
104
  let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
108
- let mut session = self.pool.acquire()?;
109
- let outputs = session.run(input_tensors.inputs).map_err(|e| GteError::Ort(e.to_string()))?;
110
- let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
111
-
112
- let mut scores = match array.ndim() {
113
- 1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
114
- 2 => {
115
- let shape = array.shape();
116
- if shape[1] == 0 {
105
+ let output_name = self.config.output_tensor.clone();
106
+ let inputs = input_tensors.inputs;
107
+
108
+ self.pool.with_session(|session| {
109
+ let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
110
+
111
+ let array = extract_output_tensor(&outputs, output_name.as_str())?;
112
+
113
+ let mut scores = match array.ndim() {
114
+ 1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
115
+ 2 => {
116
+ let shape = array.shape();
117
+ if shape[1] == 0 {
118
+ return Err(GteError::Inference(format!(
119
+ "reranker output '{output_name}' has invalid shape {shape:?}"
120
+ )));
121
+ }
122
+ array.slice(ndarray::s![.., 0]).to_vec()
123
+ }
124
+ n => {
117
125
  return Err(GteError::Inference(format!(
118
- "reranker output '{}' has invalid shape {:?}",
119
- self.config.output_tensor, shape
120
- )));
126
+ "reranker output '{output_name}' rank {n} is unsupported; expected rank 1 or 2"
127
+ )))
121
128
  }
122
- array.slice(ndarray::s![.., 0]).to_vec()
123
- }
124
- n => {
125
- return Err(GteError::Inference(format!(
126
- "reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
127
- self.config.output_tensor, n
128
- )))
129
- }
130
- };
129
+ };
131
130
 
132
- if apply_sigmoid {
133
- sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
134
- }
131
+ if apply_sigmoid {
132
+ sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
133
+ }
135
134
 
136
- Ok(scores)
135
+ Ok(scores)
136
+ })
137
137
  }
138
138
  }