gte 0.0.13 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +136 -248
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
|
@@ -18,10 +18,7 @@ pub struct TokenizerProfile {
|
|
|
18
18
|
pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
|
|
19
19
|
let tokenizer_path = dir.join("tokenizer.json");
|
|
20
20
|
if !tokenizer_path.exists() {
|
|
21
|
-
return Err(GteError::Tokenizer(format!(
|
|
22
|
-
"tokenizer.json not found in {}",
|
|
23
|
-
dir.display()
|
|
24
|
-
)));
|
|
21
|
+
return Err(GteError::Tokenizer(format!("tokenizer.json not found in {}", dir.display())));
|
|
25
22
|
}
|
|
26
23
|
Ok(tokenizer_path)
|
|
27
24
|
}
|
|
@@ -33,11 +30,7 @@ pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
|
|
|
33
30
|
return Ok(path.clone());
|
|
34
31
|
}
|
|
35
32
|
}
|
|
36
|
-
Err(GteError::Inference(format!(
|
|
37
|
-
"model '{}' not found in {} (checked onnx/{0} and {0})",
|
|
38
|
-
name,
|
|
39
|
-
dir.display()
|
|
40
|
-
)))
|
|
33
|
+
Err(GteError::Inference(format!("model '{}' not found in {} (checked onnx/{0} and {0})", name, dir.display())))
|
|
41
34
|
}
|
|
42
35
|
|
|
43
36
|
pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
|
|
@@ -62,9 +55,7 @@ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
|
|
|
62
55
|
let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
|
|
63
56
|
let tokenizer_json = read_json(dir.join("tokenizer.json"));
|
|
64
57
|
|
|
65
|
-
let fixed_padding_length = tokenizer_json
|
|
66
|
-
.as_ref()
|
|
67
|
-
.and_then(parse_fixed_padding_length_from_tokenizer_json);
|
|
58
|
+
let fixed_padding_length = tokenizer_json.as_ref().and_then(parse_fixed_padding_length_from_tokenizer_json);
|
|
68
59
|
|
|
69
60
|
let mut candidates = Vec::new();
|
|
70
61
|
if let Some(config) = tokenizer_config.as_ref() {
|
|
@@ -90,19 +81,10 @@ pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
|
|
|
90
81
|
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
91
82
|
}
|
|
92
83
|
|
|
93
|
-
let default_max_length = candidates
|
|
94
|
-
.iter()
|
|
95
|
-
.copied()
|
|
96
|
-
.min()
|
|
97
|
-
.unwrap_or(DEFAULT_MAX_LENGTH)
|
|
98
|
-
.max(1);
|
|
84
|
+
let default_max_length = candidates.iter().copied().min().unwrap_or(DEFAULT_MAX_LENGTH).max(1);
|
|
99
85
|
let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
|
|
100
86
|
|
|
101
|
-
TokenizerProfile {
|
|
102
|
-
default_max_length,
|
|
103
|
-
safe_max_length,
|
|
104
|
-
fixed_padding_length,
|
|
105
|
-
}
|
|
87
|
+
TokenizerProfile { default_max_length, safe_max_length, fixed_padding_length }
|
|
106
88
|
}
|
|
107
89
|
|
|
108
90
|
fn read_json(path: PathBuf) -> Option<Value> {
|
|
@@ -113,12 +95,7 @@ fn read_json(path: PathBuf) -> Option<Value> {
|
|
|
113
95
|
fn parse_positive_usize(value: &Value) -> Option<usize> {
|
|
114
96
|
let raw = value
|
|
115
97
|
.as_u64()
|
|
116
|
-
.or_else(||
|
|
117
|
-
value
|
|
118
|
-
.as_f64()
|
|
119
|
-
.filter(|&v| v.is_finite() && v > 0.0)
|
|
120
|
-
.map(|v| v as u64)
|
|
121
|
-
})
|
|
98
|
+
.or_else(|| value.as_f64().filter(|&v| v.is_finite() && v > 0.0).map(|v| v as u64))
|
|
122
99
|
.or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
|
|
123
100
|
let parsed = usize::try_from(raw).ok()?;
|
|
124
101
|
(parsed > 0).then_some(parsed)
|
|
@@ -133,7 +110,9 @@ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Opt
|
|
|
133
110
|
}
|
|
134
111
|
|
|
135
112
|
pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
|
|
136
|
-
let unsupported: Vec<String> = session
|
|
113
|
+
let unsupported: Vec<String> = session
|
|
114
|
+
.inputs()
|
|
115
|
+
.iter()
|
|
137
116
|
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
|
|
138
117
|
.map(|i| i.name().to_owned())
|
|
139
118
|
.collect();
|
|
@@ -142,11 +121,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Res
|
|
|
142
121
|
return Ok(());
|
|
143
122
|
}
|
|
144
123
|
|
|
145
|
-
let mut message = format!(
|
|
146
|
-
"unsupported model inputs for {} API: {}",
|
|
147
|
-
api_label,
|
|
148
|
-
unsupported.join(", ")
|
|
149
|
-
);
|
|
124
|
+
let mut message = format!("unsupported model inputs for {} API: {}", api_label, unsupported.join(", "));
|
|
150
125
|
if unsupported.iter().any(|n| n == "pixel_values") {
|
|
151
126
|
message.push_str(
|
|
152
127
|
". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
|
|
@@ -163,40 +138,23 @@ pub fn has_input(session: &Session, name: &str) -> bool {
|
|
|
163
138
|
|
|
164
139
|
fn output_name_matches(name: &str, preferred: &str) -> bool {
|
|
165
140
|
let lower = name.to_ascii_lowercase();
|
|
166
|
-
lower == preferred || lower.ends_with(&format!("/{}", preferred))
|
|
141
|
+
lower == preferred || lower.ends_with(&format!("/{preferred}"))
|
|
167
142
|
}
|
|
168
143
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
requested: Option<&str>,
|
|
172
|
-
preferred_outputs: &[&str],
|
|
173
|
-
) -> Result<String> {
|
|
144
|
+
#[allow(clippy::redundant_closure_for_method_calls)]
|
|
145
|
+
pub fn select_output_tensor(session: &Session, requested: Option<&str>, preferred_outputs: &[&str]) -> Result<String> {
|
|
174
146
|
if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
|
|
175
|
-
if let Some(output) = session
|
|
176
|
-
.outputs()
|
|
177
|
-
.iter()
|
|
178
|
-
.find(|o| output_name_matches(o.name(), requested_name))
|
|
179
|
-
{
|
|
147
|
+
if let Some(output) = session.outputs().iter().find(|o| output_name_matches(o.name(), requested_name)) {
|
|
180
148
|
return Ok(output.name().to_owned());
|
|
181
149
|
}
|
|
182
|
-
let available = session
|
|
183
|
-
.outputs()
|
|
184
|
-
.iter()
|
|
185
|
-
.map(|o| o.name())
|
|
186
|
-
.collect::<Vec<_>>()
|
|
187
|
-
.join(", ");
|
|
150
|
+
let available = session.outputs().iter().map(|o| o.name()).collect::<Vec<_>>().join(", ");
|
|
188
151
|
return Err(GteError::Inference(format!(
|
|
189
|
-
"requested output tensor '{}' not found in model outputs: {}",
|
|
190
|
-
requested_name, available
|
|
152
|
+
"requested output tensor '{requested_name}' not found in model outputs: {available}"
|
|
191
153
|
)));
|
|
192
154
|
}
|
|
193
155
|
|
|
194
156
|
for preferred in preferred_outputs {
|
|
195
|
-
if let Some(output) = session
|
|
196
|
-
.outputs()
|
|
197
|
-
.iter()
|
|
198
|
-
.find(|o| output_name_matches(o.name(), preferred))
|
|
199
|
-
{
|
|
157
|
+
if let Some(output) = session.outputs().iter().find(|o| output_name_matches(o.name(), preferred)) {
|
|
200
158
|
return Ok(output.name().to_owned());
|
|
201
159
|
}
|
|
202
160
|
}
|
|
@@ -204,12 +162,9 @@ pub fn select_output_tensor(
|
|
|
204
162
|
let outputs = session.outputs();
|
|
205
163
|
let best = outputs
|
|
206
164
|
.iter()
|
|
207
|
-
.find(|o| {
|
|
208
|
-
matches!(o.dtype(), ort::value::ValueType::Tensor { shape, .. } if shape.len() == 2)
|
|
209
|
-
})
|
|
165
|
+
.find(|o| matches!(o.dtype(), ort::value::ValueType::Tensor { shape, .. } if shape.len() == 2))
|
|
210
166
|
.or_else(|| outputs.first());
|
|
211
|
-
best.map(|o| o.name().to_owned())
|
|
212
|
-
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
167
|
+
best.map(|o| o.name().to_owned()).ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
213
168
|
}
|
|
214
169
|
|
|
215
170
|
fn output_basename(name: &str) -> &str {
|
|
@@ -217,34 +172,21 @@ fn output_basename(name: &str) -> &str {
|
|
|
217
172
|
}
|
|
218
173
|
|
|
219
174
|
pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
220
|
-
let output =
|
|
221
|
-
.outputs()
|
|
222
|
-
|
|
223
|
-
.find(|o| o.name() == output_tensor)
|
|
224
|
-
.ok_or_else(|| {
|
|
225
|
-
GteError::Inference(format!(
|
|
226
|
-
"output tensor '{}' not found in model outputs",
|
|
227
|
-
output_tensor
|
|
228
|
-
))
|
|
175
|
+
let output =
|
|
176
|
+
session.outputs().iter().find(|o| o.name() == output_tensor).ok_or_else(|| {
|
|
177
|
+
GteError::Inference(format!("output tensor '{output_tensor}' not found in model outputs"))
|
|
229
178
|
})?;
|
|
230
179
|
|
|
231
180
|
let ndims = match output.dtype() {
|
|
232
181
|
ort::value::ValueType::Tensor { shape, .. } => shape.len(),
|
|
233
|
-
other => {
|
|
234
|
-
return Err(GteError::Inference(format!(
|
|
235
|
-
"output is not a tensor: {:?}",
|
|
236
|
-
other
|
|
237
|
-
)))
|
|
238
|
-
}
|
|
182
|
+
other => return Err(GteError::Inference(format!("output is not a tensor: {other:?}"))),
|
|
239
183
|
};
|
|
240
184
|
|
|
241
185
|
match (output_basename(output_tensor), ndims) {
|
|
242
|
-
("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
|
|
186
|
+
("last_hidden_state" | _, 3) => Ok(ExtractorMode::MeanPool),
|
|
243
187
|
(_, 2) => Ok(ExtractorMode::Raw),
|
|
244
|
-
(_, 3) => Ok(ExtractorMode::MeanPool),
|
|
245
188
|
(_, n) => Err(GteError::Inference(format!(
|
|
246
|
-
"unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
|
|
247
|
-
n, output_tensor
|
|
189
|
+
"unexpected output tensor rank {n} for '{output_tensor}': expected 2 (Raw) or 3 (MeanPool)"
|
|
248
190
|
))),
|
|
249
191
|
}
|
|
250
192
|
}
|
|
@@ -271,9 +213,6 @@ mod tests {
|
|
|
271
213
|
}
|
|
272
214
|
}
|
|
273
215
|
});
|
|
274
|
-
assert_eq!(
|
|
275
|
-
parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
|
|
276
|
-
Some(64)
|
|
277
|
-
);
|
|
216
|
+
assert_eq!(parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json), Some(64));
|
|
278
217
|
}
|
|
279
218
|
}
|
data/ext/gte/src/pipeline.rs
CHANGED
|
@@ -11,41 +11,25 @@ pub struct InputTensors<'a> {
|
|
|
11
11
|
|
|
12
12
|
impl<'a> InputTensors<'a> {
|
|
13
13
|
pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
|
|
14
|
-
let input_ids_view: ArrayView2<'_, i64> =
|
|
15
|
-
(tokenized.rows, tokenized.cols),
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
|
|
19
|
-
(tokenized.rows, tokenized.cols),
|
|
20
|
-
tokenized.attn_masks.as_slice(),
|
|
21
|
-
)?;
|
|
14
|
+
let input_ids_view: ArrayView2<'_, i64> =
|
|
15
|
+
ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
|
|
16
|
+
let attention_mask: ArrayView2<'_, i64> =
|
|
17
|
+
ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
|
|
22
18
|
|
|
23
19
|
let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
|
|
24
|
-
inputs.push((
|
|
25
|
-
"input_ids",
|
|
26
|
-
SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
|
|
27
|
-
));
|
|
20
|
+
inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
|
|
28
21
|
|
|
29
22
|
if with_attention_mask {
|
|
30
|
-
inputs.push((
|
|
31
|
-
"attention_mask",
|
|
32
|
-
SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
|
|
33
|
-
));
|
|
23
|
+
inputs.push(("attention_mask", SessionInputValue::from(TensorRef::from_array_view(attention_mask)?)));
|
|
34
24
|
}
|
|
35
25
|
|
|
36
26
|
if let Some(type_ids) = tokenized.type_ids.as_deref() {
|
|
37
27
|
let type_ids_view: ArrayView2<'_, i64> =
|
|
38
28
|
ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
|
|
39
|
-
inputs.push((
|
|
40
|
-
"token_type_ids",
|
|
41
|
-
SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
|
|
42
|
-
));
|
|
29
|
+
inputs.push(("token_type_ids", SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?)));
|
|
43
30
|
}
|
|
44
31
|
|
|
45
|
-
Ok(Self {
|
|
46
|
-
inputs,
|
|
47
|
-
attention_mask,
|
|
48
|
-
})
|
|
32
|
+
Ok(Self { inputs, attention_mask })
|
|
49
33
|
}
|
|
50
34
|
}
|
|
51
35
|
|
|
@@ -53,11 +37,8 @@ pub fn extract_output_tensor<'a>(
|
|
|
53
37
|
outputs: &'a ort::session::SessionOutputs<'_>,
|
|
54
38
|
output_name: &str,
|
|
55
39
|
) -> Result<ArrayViewD<'a, f32>> {
|
|
56
|
-
let tensor_value = outputs
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
output_name
|
|
60
|
-
))
|
|
61
|
-
})?;
|
|
40
|
+
let tensor_value = outputs
|
|
41
|
+
.get(output_name)
|
|
42
|
+
.ok_or_else(|| GteError::Inference(format!("output tensor '{output_name}' not found in model outputs")))?;
|
|
62
43
|
Ok(tensor_value.try_extract_array::<f32>()?)
|
|
63
44
|
}
|
data/ext/gte/src/postprocess.rs
CHANGED
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use ndarray::{Array2, ArrayView2, ArrayView3};
|
|
3
3
|
|
|
4
|
-
pub fn mean_pool(
|
|
5
|
-
hidden_states: ArrayView3<'_, f32>,
|
|
6
|
-
attention_mask: ArrayView2<'_, i64>,
|
|
7
|
-
) -> Result<Array2<f32>> {
|
|
4
|
+
pub fn mean_pool(hidden_states: ArrayView3<'_, f32>, attention_mask: ArrayView2<'_, i64>) -> Result<Array2<f32>> {
|
|
8
5
|
let (batch, seq, dim) = hidden_states.dim();
|
|
9
6
|
if attention_mask.dim() != (batch, seq) {
|
|
10
7
|
return Err(GteError::Inference(format!(
|
|
@@ -34,17 +31,14 @@ pub fn mean_pool(
|
|
|
34
31
|
|
|
35
32
|
let weight = weight as f32;
|
|
36
33
|
for dim_index in 0..dim {
|
|
37
|
-
pooled[[batch_index, dim_index]] +=
|
|
38
|
-
hidden_states[[batch_index, token_index, dim_index]] * weight;
|
|
34
|
+
pooled[[batch_index, dim_index]] += hidden_states[[batch_index, token_index, dim_index]] * weight;
|
|
39
35
|
}
|
|
40
36
|
weight_sum += weight;
|
|
41
37
|
}
|
|
42
38
|
|
|
43
39
|
if weight_sum > 0.0 {
|
|
44
40
|
let inverse = weight_sum.recip();
|
|
45
|
-
pooled
|
|
46
|
-
.row_mut(batch_index)
|
|
47
|
-
.map_inplace(|value| *value *= inverse);
|
|
41
|
+
pooled.row_mut(batch_index).map_inplace(|value| *value *= inverse);
|
|
48
42
|
}
|
|
49
43
|
}
|
|
50
44
|
|
|
@@ -89,6 +83,8 @@ fn mean_pool_contiguous(
|
|
|
89
83
|
seq: usize,
|
|
90
84
|
dim: usize,
|
|
91
85
|
) {
|
|
86
|
+
let seq_inverse = (seq as f32).recip();
|
|
87
|
+
|
|
92
88
|
for batch_index in 0..batch {
|
|
93
89
|
let mask_base = batch_index * seq;
|
|
94
90
|
let hidden_base = batch_index * seq * dim;
|
|
@@ -103,9 +99,8 @@ fn mean_pool_contiguous(
|
|
|
103
99
|
}
|
|
104
100
|
}
|
|
105
101
|
|
|
106
|
-
let inverse = (seq as f32).recip();
|
|
107
102
|
for value in output_row {
|
|
108
|
-
*value *= inverse;
|
|
103
|
+
*value *= seq_inverse;
|
|
109
104
|
}
|
|
110
105
|
continue;
|
|
111
106
|
}
|
|
@@ -113,12 +108,11 @@ fn mean_pool_contiguous(
|
|
|
113
108
|
let mut weight_sum = 0.0f32;
|
|
114
109
|
|
|
115
110
|
for (token_index, &weight_raw) in mask_row.iter().enumerate() {
|
|
116
|
-
|
|
117
|
-
if weight <= 0 {
|
|
111
|
+
if weight_raw <= 0 {
|
|
118
112
|
continue;
|
|
119
113
|
}
|
|
120
114
|
|
|
121
|
-
let weight =
|
|
115
|
+
let weight = weight_raw as f32;
|
|
122
116
|
let token_base = hidden_base + token_index * dim;
|
|
123
117
|
for dim_index in 0..dim {
|
|
124
118
|
output_row[dim_index] += hidden[token_base + dim_index] * weight;
|
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::{ModelLoadOverrides, PaddingMode};
|
|
3
3
|
use crate::model_profile::{
|
|
4
|
-
has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
|
|
5
|
-
|
|
4
|
+
has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
|
|
5
|
+
select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
8
8
|
use crate::postprocess::sigmoid_scores;
|
|
@@ -26,11 +26,7 @@ pub struct Reranker {
|
|
|
26
26
|
}
|
|
27
27
|
|
|
28
28
|
impl Reranker {
|
|
29
|
-
pub fn from_dir<P: AsRef<Path>>(
|
|
30
|
-
dir: P,
|
|
31
|
-
optimization_level: u8,
|
|
32
|
-
overrides: ModelLoadOverrides<'_>,
|
|
33
|
-
) -> Result<Self> {
|
|
29
|
+
pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
|
|
34
30
|
let dir = dir.as_ref();
|
|
35
31
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
36
32
|
let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
@@ -41,16 +37,13 @@ impl Reranker {
|
|
|
41
37
|
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
42
38
|
let max_length = if let Some(override_value) = overrides.max_length {
|
|
43
39
|
if override_value == 0 {
|
|
44
|
-
return Err(GteError::Inference(
|
|
45
|
-
"max_length override must be greater than 0".to_string(),
|
|
46
|
-
));
|
|
40
|
+
return Err(GteError::Inference("max_length override must be greater than 0".to_string()));
|
|
47
41
|
}
|
|
48
42
|
override_value.min(tokenizer_profile.safe_max_length)
|
|
49
43
|
} else {
|
|
50
44
|
tokenizer_profile.default_max_length
|
|
51
45
|
};
|
|
52
|
-
let padding_mode =
|
|
53
|
-
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
46
|
+
let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
54
47
|
|
|
55
48
|
let probe_config = crate::model_config::ModelConfig {
|
|
56
49
|
max_length,
|
|
@@ -61,6 +54,8 @@ impl Reranker {
|
|
|
61
54
|
with_attention_mask: true,
|
|
62
55
|
optimization_level,
|
|
63
56
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
57
|
+
lowercase_input: false,
|
|
58
|
+
max_input_chars: None,
|
|
64
59
|
};
|
|
65
60
|
let session = build_session(&model_path, &probe_config)?;
|
|
66
61
|
|
|
@@ -69,13 +64,7 @@ impl Reranker {
|
|
|
69
64
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
70
65
|
let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
|
|
71
66
|
|
|
72
|
-
let config = RerankerConfig {
|
|
73
|
-
max_length,
|
|
74
|
-
padding_mode,
|
|
75
|
-
output_tensor,
|
|
76
|
-
with_type_ids,
|
|
77
|
-
with_attention_mask,
|
|
78
|
-
};
|
|
67
|
+
let config = RerankerConfig { max_length, padding_mode, output_tensor, with_type_ids, with_attention_mask };
|
|
79
68
|
|
|
80
69
|
let tokenizer = Tokenizer::new(
|
|
81
70
|
&tokenizer_path,
|
|
@@ -85,7 +74,19 @@ impl Reranker {
|
|
|
85
74
|
tokenizer_profile.fixed_padding_length,
|
|
86
75
|
)?;
|
|
87
76
|
|
|
88
|
-
let
|
|
77
|
+
let model_config = crate::model_config::ModelConfig {
|
|
78
|
+
max_length,
|
|
79
|
+
padding_mode,
|
|
80
|
+
output_tensor: config.output_tensor.clone(),
|
|
81
|
+
mode: crate::model_config::ExtractorMode::Raw,
|
|
82
|
+
with_type_ids: config.with_type_ids,
|
|
83
|
+
with_attention_mask: config.with_attention_mask,
|
|
84
|
+
optimization_level,
|
|
85
|
+
execution_providers: None,
|
|
86
|
+
lowercase_input: false,
|
|
87
|
+
max_input_chars: None,
|
|
88
|
+
};
|
|
89
|
+
let pool = SessionPool::new(session, &model_path, &model_config)?;
|
|
89
90
|
Ok(Self { tokenizer, pool, config })
|
|
90
91
|
}
|
|
91
92
|
|
|
@@ -99,40 +100,39 @@ impl Reranker {
|
|
|
99
100
|
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
100
101
|
}
|
|
101
102
|
|
|
102
|
-
fn score_tokenized(
|
|
103
|
-
&self,
|
|
104
|
-
tokenized: &crate::tokenizer::Tokenized,
|
|
105
|
-
apply_sigmoid: bool,
|
|
106
|
-
) -> Result<Vec<f32>> {
|
|
103
|
+
fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
107
104
|
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
108
|
-
let
|
|
109
|
-
let
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
105
|
+
let output_name = self.config.output_tensor.clone();
|
|
106
|
+
let inputs = input_tensors.inputs;
|
|
107
|
+
|
|
108
|
+
self.pool.with_session(|session| {
|
|
109
|
+
let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
110
|
+
|
|
111
|
+
let array = extract_output_tensor(&outputs, output_name.as_str())?;
|
|
112
|
+
|
|
113
|
+
let mut scores = match array.ndim() {
|
|
114
|
+
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
115
|
+
2 => {
|
|
116
|
+
let shape = array.shape();
|
|
117
|
+
if shape[1] == 0 {
|
|
118
|
+
return Err(GteError::Inference(format!(
|
|
119
|
+
"reranker output '{output_name}' has invalid shape {shape:?}"
|
|
120
|
+
)));
|
|
121
|
+
}
|
|
122
|
+
array.slice(ndarray::s![.., 0]).to_vec()
|
|
123
|
+
}
|
|
124
|
+
n => {
|
|
117
125
|
return Err(GteError::Inference(format!(
|
|
118
|
-
"reranker output '{}'
|
|
119
|
-
|
|
120
|
-
)));
|
|
126
|
+
"reranker output '{output_name}' rank {n} is unsupported; expected rank 1 or 2"
|
|
127
|
+
)))
|
|
121
128
|
}
|
|
122
|
-
|
|
123
|
-
}
|
|
124
|
-
n => {
|
|
125
|
-
return Err(GteError::Inference(format!(
|
|
126
|
-
"reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
|
|
127
|
-
self.config.output_tensor, n
|
|
128
|
-
)))
|
|
129
|
-
}
|
|
130
|
-
};
|
|
129
|
+
};
|
|
131
130
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
131
|
+
if apply_sigmoid {
|
|
132
|
+
sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
|
|
133
|
+
}
|
|
135
134
|
|
|
136
|
-
|
|
135
|
+
Ok(scores)
|
|
136
|
+
})
|
|
137
137
|
}
|
|
138
138
|
}
|