gte 0.0.3 → 0.0.5

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,179 @@
+ use crate::error::{GteError, Result};
+ use crate::model_config::ExtractorMode;
+ use ort::session::Session;
+ use std::path::{Path, PathBuf};
+
+ const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
+
+ pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
+     let tokenizer_path = dir.join("tokenizer.json");
+     if !tokenizer_path.exists() {
+         return Err(GteError::Tokenizer(format!(
+             "tokenizer.json not found in {}",
+             dir.display()
+         )));
+     }
+     Ok(tokenizer_path)
+ }
+
+ pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
+     let candidates = [dir.join("onnx").join(name), dir.join(name)];
+     for path in &candidates {
+         if path.exists() {
+             return Ok(path.clone());
+         }
+     }
+     Err(GteError::Inference(format!(
+         "model '{}' not found in {} (checked onnx/{0} and {0})",
+         name,
+         dir.display()
+     )))
+ }
+
+ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
+     let candidates = [
+         dir.join("onnx").join("text_model.onnx"),
+         dir.join("text_model.onnx"),
+         dir.join("onnx").join("model.onnx"),
+         dir.join("model.onnx"),
+     ];
+     for path in &candidates {
+         if path.exists() {
+             return Ok(path.clone());
+         }
+     }
+     Err(GteError::Inference(format!(
+         "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
+         dir.display()
+     )))
+ }
+
+ pub fn read_max_length(dir: &Path) -> usize {
+     (|| -> Option<usize> {
+         let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
+         let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
+         let v = json.get("model_max_length")?;
+         let n = v.as_u64().or_else(|| {
+             v.as_f64()
+                 .filter(|&f| f > 0.0 && f < 1e15)
+                 .map(|f| f as u64)
+         })?;
+         Some((n as usize).min(8192))
+     })()
+     .unwrap_or(512)
+ }
+
+ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
+     let unsupported: Vec<String> = session
+         .inputs
+         .iter()
+         .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
+         .map(|i| i.name.clone())
+         .collect();
+
+     if unsupported.is_empty() {
+         return Ok(());
+     }
+
+     let mut message = format!(
+         "unsupported model inputs for {} API: {}",
+         api_label,
+         unsupported.join(", ")
+     );
+     if unsupported.iter().any(|n| n == "pixel_values") {
+         message.push_str(
+             ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
+         );
+     } else {
+         message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
+     }
+     Err(GteError::Inference(message))
+ }
+
+ pub fn has_input(session: &Session, name: &str) -> bool {
+     session.inputs.iter().any(|input| input.name == name)
+ }
+
+ fn output_name_matches(name: &str, preferred: &str) -> bool {
+     let lower = name.to_ascii_lowercase();
+     lower == preferred || lower.ends_with(&format!("/{}", preferred))
+ }
+
+ pub fn select_output_tensor(
+     session: &Session,
+     requested: Option<&str>,
+     preferred_outputs: &[&str],
+ ) -> Result<String> {
+     if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
+         if let Some(output) = session
+             .outputs
+             .iter()
+             .find(|o| output_name_matches(o.name.as_str(), requested_name))
+         {
+             return Ok(output.name.clone());
+         }
+         let available = session
+             .outputs
+             .iter()
+             .map(|o| o.name.as_str())
+             .collect::<Vec<_>>()
+             .join(", ");
+         return Err(GteError::Inference(format!(
+             "requested output tensor '{}' not found in model outputs: {}",
+             requested_name, available
+         )));
+     }
+
+     for preferred in preferred_outputs {
+         if let Some(output) = session
+             .outputs
+             .iter()
+             .find(|o| output_name_matches(o.name.as_str(), preferred))
+         {
+             return Ok(output.name.clone());
+         }
+     }
+
+     session
+         .outputs
+         .first()
+         .map(|o| o.name.clone())
+         .ok_or_else(|| GteError::Inference("model has no outputs".into()))
+ }
+
+ fn output_basename(name: &str) -> &str {
+     name.rsplit('/').next().unwrap_or(name)
+ }
+
+ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
+     let output = session
+         .outputs
+         .iter()
+         .find(|o| o.name == output_tensor)
+         .ok_or_else(|| {
+             GteError::Inference(format!(
+                 "output tensor '{}' not found in model outputs",
+                 output_tensor
+             ))
+         })?;
+
+     let ndims = match &output.output_type {
+         ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
+         other => {
+             return Err(GteError::Inference(format!(
+                 "output is not a tensor: {:?}",
+                 other
+             )))
+         }
+     };
+
+     match (output_basename(output_tensor), ndims) {
+         ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
+         (_, 2) => Ok(ExtractorMode::Raw),
+         (_, 3) => Ok(ExtractorMode::MeanPool),
+         (_, n) => Err(GteError::Inference(format!(
+             "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
+             n, output_tensor
+         ))),
+     }
+ }
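Two conventions in the helpers above are easy to miss when skimming the diff: `select_output_tensor` matches a requested or preferred name against graph outputs case-insensitively, either exactly or as a `/`-separated basename, and `infer_extraction_mode` picks the pooling strategy from the rank of the chosen output. A small standalone sketch (not part of the package) that mirrors both rules:

```rust
// Standalone sketch mirroring output_name_matches and the rank rule in
// infer_extraction_mode; the function name here is illustrative, not the crate's API.
fn name_matches(output: &str, preferred: &str) -> bool {
    let lower = output.to_ascii_lowercase();
    lower == preferred || lower.ends_with(&format!("/{}", preferred))
}

fn main() {
    assert!(name_matches("logits", "logits"));
    assert!(name_matches("classifier/Logits", "logits")); // basename match, case-insensitive
    assert!(!name_matches("token_logits", "logits")); // no accidental suffix match
    // Rank rule from infer_extraction_mode:
    //   rank 3 (batch, seq, hidden) -> ExtractorMode::MeanPool
    //   rank 2 (batch, dim/scores)  -> ExtractorMode::Raw
}
```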
@@ -0,0 +1,60 @@
+ use crate::error::{GteError, Result};
+ use crate::tokenizer::Tokenized;
+ use ndarray::ArrayView2;
+ use ort::session::SessionInputValue;
+ use ort::value::Value;
+
+ pub struct InputTensors<'a> {
+     pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
+     pub attention_mask: ArrayView2<'a, i64>,
+ }
+
+ impl<'a> InputTensors<'a> {
+     pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
+         let input_ids_view: ArrayView2<'_, i64> = ArrayView2::from_shape(
+             (tokenized.rows, tokenized.cols),
+             tokenized.input_ids.as_slice(),
+         )?;
+         let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
+             (tokenized.rows, tokenized.cols),
+             tokenized.attn_masks.as_slice(),
+         )?;
+
+         let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
+         inputs.push((
+             "input_ids",
+             SessionInputValue::from(Value::from_array(input_ids_view)?),
+         ));
+
+         if with_attention_mask {
+             inputs.push((
+                 "attention_mask",
+                 SessionInputValue::from(Value::from_array(attention_mask)?),
+             ));
+         }
+
+         if let Some(type_ids) = tokenized.type_ids.as_deref() {
+             let type_ids_view: ArrayView2<'_, i64> =
+                 ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
+             inputs.push((
+                 "token_type_ids",
+                 SessionInputValue::from(Value::from_array(type_ids_view)?),
+             ));
+         }
+
+         Ok(Self {
+             inputs,
+             attention_mask,
+         })
+     }
+ }
+
+ pub fn extract_output_tensor<'a>(
+     outputs: &'a ort::session::SessionOutputs<'a, 'a>,
+     output_name: &str,
+ ) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
+     let tensor_value = outputs.get(output_name).ok_or_else(|| {
+         GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
+     })?;
+     Ok(tensor_value.try_extract_tensor::<f32>()?.into())
+ }
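`InputTensors::from_tokenized` assumes the tokenizer produces flat, row-major `i64` buffers and views them as `(rows, cols)` matrices without copying; the attention-mask view is also kept on the struct even when it is not fed to the model, since mean pooling needs it later. A standalone sketch (illustrative values only) of the underlying `ndarray` call:

```rust
// Standalone sketch of the ndarray call from_tokenized builds on: a flat,
// row-major buffer of token ids viewed as a (rows, cols) matrix without copying.
// The token values below are made up for illustration.
use ndarray::ArrayView2;

fn main() {
    let input_ids: Vec<i64> = vec![101, 7592, 102, 101, 2088, 102]; // 2 rows x 3 cols
    let view: ArrayView2<'_, i64> =
        ArrayView2::from_shape((2, 3), input_ids.as_slice()).expect("shape matches buffer");
    assert_eq!(view.shape(), &[2, 3]);
    assert_eq!(view[(1, 0)], 101); // first token of the second row
}
```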
@@ -75,6 +75,12 @@ pub fn normalize_l2(mut embeddings: Array2<f32>) -> Array2<f32> {
      embeddings
  }
 
+ pub fn sigmoid_scores(mut scores: ndarray::ArrayViewMut1<'_, f32>) {
+     scores.map_inplace(|value| {
+         *value = 1.0 / (1.0 + (-*value).exp());
+     });
+ }
+
  fn mean_pool_contiguous(
      hidden: &[f32],
      attention_mask: &[i64],
@@ -87,10 +93,27 @@ fn mean_pool_contiguous(
          let mask_base = batch_index * seq;
          let hidden_base = batch_index * seq * dim;
          let output_row = &mut output[batch_index * dim..(batch_index + 1) * dim];
+         let mask_row = &attention_mask[mask_base..mask_base + seq];
+
+         if mask_row.iter().all(|&weight| weight == 1) {
+             for token_index in 0..seq {
+                 let token_base = hidden_base + token_index * dim;
+                 for dim_index in 0..dim {
+                     output_row[dim_index] += hidden[token_base + dim_index];
+                 }
+             }
+
+             let inverse = (seq as f32).recip();
+             for value in output_row {
+                 *value *= inverse;
+             }
+             continue;
+         }
+
          let mut weight_sum = 0.0f32;
 
-         for token_index in 0..seq {
-             let weight = attention_mask[mask_base + token_index];
+         for (token_index, &weight_raw) in mask_row.iter().enumerate() {
+             let weight = weight_raw;
              if weight <= 0 {
                  continue;
              }
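This hunk adds a fast path to mean pooling for fully unpadded batches (every mask weight is 1, so all tokens are summed and scaled by 1/seq) and introduces `sigmoid_scores`, which applies the logistic function 1 / (1 + e^(-x)) in place so raw reranker logits land in (0, 1). A standalone sketch (not the package's code path) of the same transform:

```rust
// Standalone sketch of the in-place logistic transform used by sigmoid_scores.
use ndarray::array;

fn main() {
    let mut scores = array![-2.0_f32, 0.0, 2.0];
    scores.map_inplace(|v| *v = 1.0 / (1.0 + (-*v).exp()));
    // Approximately [0.119, 0.5, 0.881]
    println!("{:?}", scores);
}
```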
@@ -0,0 +1,120 @@
+ use crate::error::{GteError, Result};
+ use crate::model_profile::{
+     has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
+     select_output_tensor, validate_supported_text_inputs,
+ };
+ use crate::pipeline::{extract_output_tensor, InputTensors};
+ use crate::postprocess::sigmoid_scores;
+ use crate::session::build_session;
+ use crate::tokenizer::Tokenizer;
+ use ndarray::Array1;
+ use ort::session::Session;
+ use std::path::Path;
+
+ #[derive(Debug, Clone)]
+ struct RerankerConfig {
+     max_length: usize,
+     output_tensor: String,
+     with_type_ids: bool,
+     with_attention_mask: bool,
+ }
+
+ pub struct Reranker {
+     tokenizer: Tokenizer,
+     session: Session,
+     config: RerankerConfig,
+ }
+
+ impl Reranker {
+     pub fn from_dir<P: AsRef<Path>>(
+         dir: P,
+         num_threads: usize,
+         optimization_level: u8,
+         model_name: Option<&str>,
+         output_tensor_override: Option<&str>,
+         max_length_override: Option<usize>,
+     ) -> Result<Self> {
+         let dir = dir.as_ref();
+         let tokenizer_path = resolve_tokenizer_path(dir)?;
+         let model_path = match model_name.filter(|s| !s.is_empty()) {
+             Some(name) => resolve_named_model(dir, name)?,
+             None => resolve_default_text_model(dir)?,
+         };
+
+         let max_length = if let Some(override_value) = max_length_override {
+             if override_value == 0 {
+                 return Err(GteError::Inference(
+                     "max_length override must be greater than 0".to_string(),
+                 ));
+             }
+             override_value
+         } else {
+             read_max_length(dir)
+         };
+
+         let probe_config = crate::model_config::ModelConfig {
+             max_length,
+             output_tensor: String::new(),
+             mode: crate::model_config::ExtractorMode::Raw,
+             with_type_ids: false,
+             with_attention_mask: true,
+             num_threads,
+             optimization_level,
+         };
+         let session = build_session(&model_path, &probe_config)?;
+
+         validate_supported_text_inputs(&session, "text reranking")?;
+         let with_type_ids = has_input(&session, "token_type_ids");
+         let with_attention_mask = has_input(&session, "attention_mask");
+         let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
+
+         let config = RerankerConfig {
+             max_length,
+             output_tensor,
+             with_type_ids,
+             with_attention_mask,
+         };
+
+         let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
+
+         Ok(Self {
+             tokenizer,
+             session,
+             config,
+         })
+     }
+
+     pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
+         let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
+         let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
+         let outputs = self.session.run(input_tensors.inputs)?;
+         let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
+
+         let mut scores = match array.ndim() {
+             1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
+             2 => {
+                 let shape = array.shape();
+                 if shape[1] == 0 {
+                     return Err(GteError::Inference(format!(
+                         "reranker output '{}' has invalid shape {:?}",
+                         self.config.output_tensor, shape
+                     )));
+                 }
+                 array.slice(ndarray::s![.., 0]).into_owned()
+             }
+             n => {
+                 return Err(GteError::Inference(format!(
+                     "reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
+                     self.config.output_tensor, n
+                 )))
+             }
+         };
+
+         if apply_sigmoid {
+             sigmoid_scores(scores.view_mut());
+         }
+
+         Ok(scores)
+     }
+
+ }
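Taken together, the new `Reranker` resolves `tokenizer.json` and an ONNX graph from a model directory, probes the session for `token_type_ids`/`attention_mask` inputs, prefers a `logits` output, and scores (query, candidate) pairs. A crate-internal sketch (not shipped in the package; the directory, thread count, and optimization level are hypothetical) of how it might be driven:

```rust
// Crate-internal sketch: driving the new Reranker end to end. The trailing
// None arguments to from_dir skip the model-name, output-tensor, and
// max-length overrides and fall back to the defaults shown in the diff.
use crate::error::Result;
use crate::reranker::Reranker;

fn rank(dir: &std::path::Path, query: &str, docs: &[&str]) -> Result<Vec<f32>> {
    let reranker = Reranker::from_dir(dir, 1, 3, None, None, None)?;
    let pairs: Vec<(String, String)> = docs
        .iter()
        .map(|doc| (query.to_string(), doc.to_string()))
        .collect();
    // `true` applies the sigmoid so scores land in (0, 1); pass `false` for raw logits.
    let scores = reranker.score_pairs(&pairs, true)?;
    Ok(scores.to_vec())
}
```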
@@ -2,6 +2,7 @@
 
  use crate::embedder::{normalize_l2, Embedder};
  use crate::error::GteError;
+ use crate::reranker::Reranker;
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
  use std::os::raw::c_void;
  use std::panic::{catch_unwind, AssertUnwindSafe};
@@ -10,6 +11,13 @@ use std::sync::Arc;
  #[wrap(class = "GTE::Embedder", free_immediately, size)]
  pub struct RbEmbedder {
      inner: Arc<Embedder>,
+     normalize: bool,
+ }
+
+ #[wrap(class = "GTE::Reranker", free_immediately, size)]
+ pub struct RbReranker {
+     inner: Arc<Reranker>,
+     sigmoid: bool,
  }
 
  #[wrap(class = "GTE::Tensor", free_immediately, size)]
@@ -22,11 +30,21 @@ pub struct RbTensor {
  struct InferArgs {
      embedder: *const Embedder,
      texts: *const Vec<String>,
+     normalize: bool,
      result: Option<Result<ndarray::Array2<f32>, GteError>>,
  }
 
  unsafe impl Send for InferArgs {}
 
+ struct ScoreArgs {
+     reranker: *const Reranker,
+     pairs: *const Vec<(String, String)>,
+     apply_sigmoid: bool,
+     result: Option<Result<Vec<f32>, GteError>>,
+ }
+
+ unsafe impl Send for ScoreArgs {}
+
  fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
      if let Some(msg) = payload.downcast_ref::<&str>() {
          (*msg).to_string()
@@ -37,11 +55,16 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
      }
  }
 
- fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
+ fn infer_without_gvl(
+     embedder: &Arc<Embedder>,
+     normalize: bool,
+     texts: Vec<String>,
+ ) -> Result<ndarray::Array2<f32>, Error> {
      let embeddings = unsafe {
          let mut args = InferArgs {
              embedder: Arc::as_ptr(embedder),
              texts: &texts as *const Vec<String>,
+             normalize,
              result: None,
          };
          rb_sys::rb_thread_call_without_gvl(
@@ -60,12 +83,44 @@ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<nda
      Ok(embeddings)
  }
 
+ fn score_without_gvl(
+     reranker: &Arc<Reranker>,
+     pairs: Vec<(String, String)>,
+     apply_sigmoid: bool,
+ ) -> Result<Vec<f32>, Error> {
+     let scores = unsafe {
+         let mut args = ScoreArgs {
+             reranker: Arc::as_ptr(reranker),
+             pairs: &pairs as *const Vec<(String, String)>,
+             apply_sigmoid,
+             result: None,
+         };
+         rb_sys::rb_thread_call_without_gvl(
+             Some(run_score_without_gvl),
+             &mut args as *mut ScoreArgs as *mut c_void,
+             None,
+             std::ptr::null_mut(),
+         );
+         let result = args.result.take().ok_or_else(|| {
+             magnus::Error::from(GteError::Inference(
+                 "reranking did not return a result".to_string(),
+             ))
+         })?;
+         result.map_err(magnus::Error::from)?
+     };
+     Ok(scores)
+ }
+
  unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
      let args = &mut *(ptr as *mut InferArgs);
      let run_result = catch_unwind(AssertUnwindSafe(|| {
          let tokenized = (*args.embedder).tokenize(&*args.texts)?;
          let embeddings = (*args.embedder).run(&tokenized)?;
-         Ok(normalize_l2(embeddings))
+         if args.normalize {
+             Ok(normalize_l2(embeddings))
+         } else {
+             Ok(embeddings)
+         }
      }));
      args.result = Some(match run_result {
          Ok(result) => result,
@@ -77,6 +132,22 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
      std::ptr::null_mut()
  }
 
+ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
+     let args = &mut *(ptr as *mut ScoreArgs);
+     let run_result = catch_unwind(AssertUnwindSafe(|| {
+         let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
+         Ok(scores.to_vec())
+     }));
+     args.result = Some(match run_result {
+         Ok(result) => result,
+         Err(payload) => Err(GteError::Inference(format!(
+             "panic during reranking: {}",
+             panic_payload_to_string(payload),
+         ))),
+     });
+     std::ptr::null_mut()
+ }
+
  fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
      let rows = embeddings.nrows();
      let cols = embeddings.ncols();
@@ -97,31 +168,114 @@ impl RbEmbedder {
          num_threads: usize,
          optimization_level: u8,
          model_name: String,
+         normalize: bool,
+         output_tensor: String,
+         max_length: usize,
      ) -> Result<Self, Error> {
          let name = if model_name.is_empty() {
              None
          } else {
              Some(model_name.as_str())
          };
-         let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, name)
-             .map_err(magnus::Error::from)?;
+         let output_override = if output_tensor.is_empty() {
+             None
+         } else {
+             Some(output_tensor.as_str())
+         };
+         let max_length_override = if max_length == 0 {
+             None
+         } else {
+             Some(max_length)
+         };
+         let embedder = Embedder::from_dir(
+             &dir_path,
+             num_threads,
+             optimization_level,
+             name,
+             output_override,
+             max_length_override,
+         )
+         .map_err(magnus::Error::from)?;
          Ok(RbEmbedder {
              inner: Arc::new(embedder),
+             normalize,
          })
      }
 
      pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
          let texts: Vec<String> = texts.to_vec()?;
-         let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
+         let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
          tensor_from_array(embeddings)
      }
 
      pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
-         let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
+         let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
          tensor_from_array(embeddings)
      }
  }
 
+ impl RbReranker {
+     pub fn rb_new(
+         _ruby: &Ruby,
+         dir_path: String,
+         num_threads: usize,
+         optimization_level: u8,
+         model_name: String,
+         sigmoid: bool,
+         output_tensor: String,
+         max_length: usize,
+     ) -> Result<Self, Error> {
+         let name = if model_name.is_empty() {
+             None
+         } else {
+             Some(model_name.as_str())
+         };
+         let output_override = if output_tensor.is_empty() {
+             None
+         } else {
+             Some(output_tensor.as_str())
+         };
+         let max_length_override = if max_length == 0 {
+             None
+         } else {
+             Some(max_length)
+         };
+         let reranker = Reranker::from_dir(
+             &dir_path,
+             num_threads,
+             optimization_level,
+             name,
+             output_override,
+             max_length_override,
+         )
+         .map_err(magnus::Error::from)?;
+         Ok(RbReranker {
+             inner: Arc::new(reranker),
+             sigmoid,
+         })
+     }
+
+     pub fn rb_score(
+         ruby: &Ruby,
+         rb_self: &Self,
+         query: String,
+         candidates: RArray,
+     ) -> Result<RArray, Error> {
+         let candidates: Vec<String> = candidates.to_vec()?;
+         let pairs: Vec<(String, String)> = candidates
+             .into_iter()
+             .map(|candidate| (query.clone(), candidate))
+             .collect();
+         let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
+
+         let out = ruby.ary_new_capa(scores.len());
+         for score in scores {
+             out.push(score)?;
+         }
+         Ok(out)
+     }
+ }
+
  impl RbTensor {
      pub fn len(&self) -> usize {
          self.rows
@@ -208,10 +362,14 @@ impl RbTensor {
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
      let module = ruby.define_module("GTE")?;
      let embedder_class = module.define_class("Embedder", ruby.class_object())?;
-     embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 4))?;
+     embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
      embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
      embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
 
+     let reranker_class = module.define_class("Reranker", ruby.class_object())?;
+     reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 7))?;
+     reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
+
      let tensor_class = module.define_class("Tensor", ruby.class_object())?;
      tensor_class.define_method("rows", method!(RbTensor::rows, 0))?;
      tensor_class.define_method("size", method!(RbTensor::len, 0))?;