gte 0.0.6 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,19 @@
1
1
  use crate::error::{GteError, Result};
2
+ use crate::model_config::{ModelLoadOverrides, PaddingMode};
2
3
  use crate::model_profile::{
3
- has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
4
- select_output_tensor, validate_supported_text_inputs,
4
+ has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
5
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
5
6
  };
6
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
7
8
  use crate::postprocess::sigmoid_scores;
8
- use crate::session::build_session;
9
- use crate::tokenizer::Tokenizer;
10
- use ndarray::Array1;
11
- use ort::session::Session;
12
- use std::path::Path;
9
+ use crate::session::{build_session, SessionPool};
10
+ use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
11
+ use std::path::{Path, PathBuf};
13
12
 
14
13
  #[derive(Debug, Clone)]
15
14
  struct RerankerConfig {
16
15
  max_length: usize,
16
+ padding_mode: PaddingMode,
17
17
  output_tensor: String,
18
18
  with_type_ids: bool,
19
19
  with_attention_mask: bool,
@@ -21,7 +21,7 @@ struct RerankerConfig {
21
21
 
22
22
  pub struct Reranker {
23
23
  tokenizer: Tokenizer,
24
- session: Session,
24
+ pool: SessionPool,
25
25
  config: RerankerConfig,
26
26
  }
27
27
 
@@ -30,70 +30,89 @@ impl Reranker {
30
30
  dir: P,
31
31
  num_threads: usize,
32
32
  optimization_level: u8,
33
- model_name: Option<&str>,
34
- output_tensor_override: Option<&str>,
35
- max_length_override: Option<usize>,
36
- execution_providers_override: Option<&str>,
33
+ overrides: ModelLoadOverrides<'_>,
37
34
  ) -> Result<Self> {
38
35
  let dir = dir.as_ref();
39
36
  let tokenizer_path = resolve_tokenizer_path(dir)?;
40
- let model_path = match model_name.filter(|s| !s.is_empty()) {
37
+ let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
41
38
  Some(name) => resolve_named_model(dir, name)?,
42
39
  None => resolve_default_text_model(dir)?,
43
40
  };
44
41
 
45
- let max_length = if let Some(override_value) = max_length_override {
42
+ let tokenizer_profile = read_tokenizer_profile(dir);
43
+ let max_length = if let Some(override_value) = overrides.max_length {
46
44
  if override_value == 0 {
47
45
  return Err(GteError::Inference(
48
46
  "max_length override must be greater than 0".to_string(),
49
47
  ));
50
48
  }
51
- override_value
49
+ override_value.min(tokenizer_profile.safe_max_length)
52
50
  } else {
53
- read_max_length(dir)
51
+ tokenizer_profile.default_max_length
54
52
  };
53
+ let padding_mode =
54
+ parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
55
55
 
56
56
  let probe_config = crate::model_config::ModelConfig {
57
57
  max_length,
58
+ padding_mode,
58
59
  output_tensor: String::new(),
59
60
  mode: crate::model_config::ExtractorMode::Raw,
60
61
  with_type_ids: false,
61
62
  with_attention_mask: true,
62
63
  num_threads,
63
64
  optimization_level,
64
- execution_providers: execution_providers_override.map(str::to_string),
65
+ execution_providers: overrides.execution_providers.map(str::to_string),
65
66
  };
66
67
  let session = build_session(&model_path, &probe_config)?;
67
68
 
68
69
  validate_supported_text_inputs(&session, "text reranking")?;
69
70
  let with_type_ids = has_input(&session, "token_type_ids");
70
71
  let with_attention_mask = has_input(&session, "attention_mask");
71
- let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
72
+ let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
72
73
 
73
74
  let config = RerankerConfig {
74
75
  max_length,
76
+ padding_mode,
75
77
  output_tensor,
76
78
  with_type_ids,
77
79
  with_attention_mask,
78
80
  };
79
81
 
80
- let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
82
+ let tokenizer = Tokenizer::new(
83
+ &tokenizer_path,
84
+ config.max_length,
85
+ config.with_type_ids,
86
+ config.padding_mode,
87
+ tokenizer_profile.fixed_padding_length,
88
+ )?;
81
89
 
82
- Ok(Self {
83
- tokenizer,
84
- session,
85
- config,
86
- })
90
+ let pool = SessionPool::new(session, model_path, probe_config);
91
+ Ok(Self { tokenizer, pool, config })
87
92
  }
88
93
 
89
- pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
94
+ pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
90
95
  let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
91
- let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
92
- let outputs = self.session.run(input_tensors.inputs)?;
96
+ self.score_tokenized(&tokenized, apply_sigmoid)
97
+ }
98
+
99
+ pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
100
+ let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
101
+ self.score_tokenized(&tokenized, apply_sigmoid)
102
+ }
103
+
104
+ fn score_tokenized(
105
+ &self,
106
+ tokenized: &crate::tokenizer::Tokenized,
107
+ apply_sigmoid: bool,
108
+ ) -> Result<Vec<f32>> {
109
+ let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
110
+ let mut session = self.pool.acquire()?;
111
+ let outputs = session.run(input_tensors.inputs).map_err(|e| GteError::Ort(e.to_string()))?;
93
112
  let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
94
113
 
95
114
  let mut scores = match array.ndim() {
96
- 1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
115
+ 1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
97
116
  2 => {
98
117
  let shape = array.shape();
99
118
  if shape[1] == 0 {
@@ -102,7 +121,7 @@ impl Reranker {
102
121
  self.config.output_tensor, shape
103
122
  )));
104
123
  }
105
- array.slice(ndarray::s![.., 0]).into_owned()
124
+ array.slice(ndarray::s![.., 0]).to_vec()
106
125
  }
107
126
  n => {
108
127
  return Err(GteError::Inference(format!(
@@ -113,10 +132,9 @@ impl Reranker {
113
132
  };
114
133
 
115
134
  if apply_sigmoid {
116
- sigmoid_scores(scores.view_mut());
135
+ sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
117
136
  }
118
137
 
119
138
  Ok(scores)
120
139
  }
121
-
122
140
  }
@@ -2,6 +2,7 @@
2
2
 
3
3
  use crate::embedder::{normalize_l2, Embedder};
4
4
  use crate::error::GteError;
5
+ use crate::model_config::ModelLoadOverrides;
5
6
  use crate::reranker::Reranker;
6
7
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
7
8
  use std::os::raw::c_void;
@@ -27,11 +28,15 @@ pub struct RbTensor {
27
28
  data: Vec<f32>,
28
29
  }
29
30
 
31
+ // ---------------------------------------------------------------------------
32
+ // GVL-release helpers
33
+ // ---------------------------------------------------------------------------
34
+
30
35
  struct InferArgs {
31
36
  embedder: *const Embedder,
32
37
  texts: *const Vec<String>,
33
38
  normalize: bool,
34
- result: Option<Result<ndarray::Array2<f32>, GteError>>,
39
+ result: Option<crate::error::Result<ndarray::Array2<f32>>>,
35
40
  }
36
41
 
37
42
  unsafe impl Send for InferArgs {}
@@ -40,7 +45,7 @@ struct ScoreArgs {
40
45
  reranker: *const Reranker,
41
46
  pairs: *const Vec<(String, String)>,
42
47
  apply_sigmoid: bool,
43
- result: Option<Result<Vec<f32>, GteError>>,
48
+ result: Option<crate::error::Result<Vec<f32>>>,
44
49
  }
45
50
 
46
51
  unsafe impl Send for ScoreArgs {}
@@ -55,6 +60,38 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
55
60
  }
56
61
  }
57
62
 
63
+ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
64
+ let args = &mut *(ptr as *mut InferArgs);
65
+ let run_result = catch_unwind(AssertUnwindSafe(|| {
66
+ let tokenized = (*args.embedder).tokenize(&*args.texts)?;
67
+ let embeddings = (*args.embedder).run(&tokenized)?;
68
+ if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
69
+ }));
70
+ args.result = Some(match run_result {
71
+ Ok(result) => result,
72
+ Err(payload) => Err(GteError::Inference(format!(
73
+ "panic during inference: {}",
74
+ panic_payload_to_string(payload),
75
+ ))),
76
+ });
77
+ std::ptr::null_mut()
78
+ }
79
+
80
+ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
81
+ let args = &mut *(ptr as *mut ScoreArgs);
82
+ let run_result = catch_unwind(AssertUnwindSafe(|| {
83
+ (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)
84
+ }));
85
+ args.result = Some(match run_result {
86
+ Ok(result) => result,
87
+ Err(payload) => Err(GteError::Inference(format!(
88
+ "panic during reranking: {}",
89
+ panic_payload_to_string(payload),
90
+ ))),
91
+ });
92
+ std::ptr::null_mut()
93
+ }
94
+
58
95
  fn infer_without_gvl(
59
96
  embedder: &Arc<Embedder>,
60
97
  normalize: bool,
@@ -111,42 +148,7 @@ fn score_without_gvl(
111
148
  Ok(scores)
112
149
  }
113
150
 
114
- unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
115
- let args = &mut *(ptr as *mut InferArgs);
116
- let run_result = catch_unwind(AssertUnwindSafe(|| {
117
- let tokenized = (*args.embedder).tokenize(&*args.texts)?;
118
- let embeddings = (*args.embedder).run(&tokenized)?;
119
- if args.normalize {
120
- Ok(normalize_l2(embeddings))
121
- } else {
122
- Ok(embeddings)
123
- }
124
- }));
125
- args.result = Some(match run_result {
126
- Ok(result) => result,
127
- Err(payload) => Err(GteError::Inference(format!(
128
- "panic during inference: {}",
129
- panic_payload_to_string(payload),
130
- ))),
131
- });
132
- std::ptr::null_mut()
133
- }
134
-
135
- unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
136
- let args = &mut *(ptr as *mut ScoreArgs);
137
- let run_result = catch_unwind(AssertUnwindSafe(|| {
138
- let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
139
- Ok(scores.to_vec())
140
- }));
141
- args.result = Some(match run_result {
142
- Ok(result) => result,
143
- Err(payload) => Err(GteError::Inference(format!(
144
- "panic during reranking: {}",
145
- panic_payload_to_string(payload),
146
- ))),
147
- });
148
- std::ptr::null_mut()
149
- }
151
+ // ---------------------------------------------------------------------------
150
152
 
151
153
  fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
152
154
  let rows = embeddings.nrows();
@@ -171,42 +173,24 @@ impl RbEmbedder {
171
173
  normalize: bool,
172
174
  output_tensor: String,
173
175
  max_length: usize,
176
+ padding: String,
174
177
  execution_providers: String,
175
178
  ) -> Result<Self, Error> {
176
- let name = if model_name.is_empty() {
177
- None
178
- } else {
179
- Some(model_name.as_str())
180
- };
181
- let output_override = if output_tensor.is_empty() {
182
- None
183
- } else {
184
- Some(output_tensor.as_str())
185
- };
186
- let max_length_override = if max_length == 0 {
187
- None
188
- } else {
189
- Some(max_length)
179
+ let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
180
+ let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
181
+ let max_length_override = if max_length == 0 { None } else { Some(max_length) };
182
+ let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
183
+ let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
184
+ let overrides = ModelLoadOverrides {
185
+ model_name: name,
186
+ output_tensor: output_override,
187
+ max_length: max_length_override,
188
+ padding: padding_override,
189
+ execution_providers: execution_providers_override,
190
190
  };
191
- let execution_providers_override = if execution_providers.is_empty() {
192
- None
193
- } else {
194
- Some(execution_providers.as_str())
195
- };
196
- let embedder = Embedder::from_dir(
197
- &dir_path,
198
- num_threads,
199
- optimization_level,
200
- name,
201
- output_override,
202
- max_length_override,
203
- execution_providers_override,
204
- )
205
- .map_err(magnus::Error::from)?;
206
- Ok(RbEmbedder {
207
- inner: Arc::new(embedder),
208
- normalize,
209
- })
191
+ let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, overrides)
192
+ .map_err(magnus::Error::from)?;
193
+ Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
210
194
  }
211
195
 
212
196
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
@@ -231,42 +215,24 @@ impl RbReranker {
231
215
  sigmoid: bool,
232
216
  output_tensor: String,
233
217
  max_length: usize,
218
+ padding: String,
234
219
  execution_providers: String,
235
220
  ) -> Result<Self, Error> {
236
- let name = if model_name.is_empty() {
237
- None
238
- } else {
239
- Some(model_name.as_str())
240
- };
241
- let output_override = if output_tensor.is_empty() {
242
- None
243
- } else {
244
- Some(output_tensor.as_str())
245
- };
246
- let max_length_override = if max_length == 0 {
247
- None
248
- } else {
249
- Some(max_length)
221
+ let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
222
+ let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
223
+ let max_length_override = if max_length == 0 { None } else { Some(max_length) };
224
+ let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
225
+ let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
226
+ let overrides = ModelLoadOverrides {
227
+ model_name: name,
228
+ output_tensor: output_override,
229
+ max_length: max_length_override,
230
+ padding: padding_override,
231
+ execution_providers: execution_providers_override,
250
232
  };
251
- let execution_providers_override = if execution_providers.is_empty() {
252
- None
253
- } else {
254
- Some(execution_providers.as_str())
255
- };
256
- let reranker = Reranker::from_dir(
257
- &dir_path,
258
- num_threads,
259
- optimization_level,
260
- name,
261
- output_override,
262
- max_length_override,
263
- execution_providers_override,
264
- )
265
- .map_err(magnus::Error::from)?;
266
- Ok(RbReranker {
267
- inner: Arc::new(reranker),
268
- sigmoid,
269
- })
233
+ let reranker = Reranker::from_dir(&dir_path, num_threads, optimization_level, overrides)
234
+ .map_err(magnus::Error::from)?;
235
+ Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
270
236
  }
271
237
 
272
238
  pub fn rb_score(
@@ -276,12 +242,8 @@ impl RbReranker {
276
242
  candidates: RArray,
277
243
  ) -> Result<RArray, Error> {
278
244
  let candidates: Vec<String> = candidates.to_vec()?;
279
- let pairs: Vec<(String, String)> = candidates
280
- .into_iter()
281
- .map(|candidate| (query.clone(), candidate))
282
- .collect();
245
+ let pairs: Vec<(String, String)> = candidates.into_iter().map(|c| (query.clone(), c)).collect();
283
246
  let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
284
-
285
247
  let out = ruby.ary_new_capa(scores.len());
286
248
  for score in scores {
287
249
  out.push(score)?;
@@ -317,7 +279,6 @@ impl RbTensor {
317
279
  index, rb_self.rows
318
280
  ))));
319
281
  }
320
-
321
282
  let start = index * rb_self.cols;
322
283
  let end = start + rb_self.cols;
323
284
  let out = ruby.ary_new_capa(rb_self.cols);
@@ -342,7 +303,6 @@ impl RbTensor {
342
303
  index, rb_self.rows
343
304
  ))));
344
305
  }
345
-
346
306
  let start = index * rb_self.cols;
347
307
  let end = start + rb_self.cols;
348
308
  let bytes = unsafe {
@@ -376,12 +336,12 @@ impl RbTensor {
376
336
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
377
337
  let module = ruby.define_module("GTE")?;
378
338
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
379
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
339
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
380
340
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
381
341
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
382
342
 
383
343
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
384
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
344
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
385
345
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
386
346
 
387
347
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;