gte 0.0.4 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,179 @@
1
+ use crate::error::{GteError, Result};
2
+ use crate::model_config::ExtractorMode;
3
+ use ort::session::Session;
4
+ use std::path::{Path, PathBuf};
5
+
6
/// Graph input names this crate knows how to feed; any other declared model
/// input is rejected by `validate_supported_text_inputs`.
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
7
+
8
+ pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
9
+ let tokenizer_path = dir.join("tokenizer.json");
10
+ if !tokenizer_path.exists() {
11
+ return Err(GteError::Tokenizer(format!(
12
+ "tokenizer.json not found in {}",
13
+ dir.display()
14
+ )));
15
+ }
16
+ Ok(tokenizer_path)
17
+ }
18
+
19
+ pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
20
+ let candidates = [dir.join("onnx").join(name), dir.join(name)];
21
+ for path in &candidates {
22
+ if path.exists() {
23
+ return Ok(path.clone());
24
+ }
25
+ }
26
+ Err(GteError::Inference(format!(
27
+ "model '{}' not found in {} (checked onnx/{0} and {0})",
28
+ name,
29
+ dir.display()
30
+ )))
31
+ }
32
+
33
+ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
34
+ let candidates = [
35
+ dir.join("onnx").join("text_model.onnx"),
36
+ dir.join("text_model.onnx"),
37
+ dir.join("onnx").join("model.onnx"),
38
+ dir.join("model.onnx"),
39
+ ];
40
+ for path in &candidates {
41
+ if path.exists() {
42
+ return Ok(path.clone());
43
+ }
44
+ }
45
+ Err(GteError::Inference(format!(
46
+ "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
47
+ dir.display()
48
+ )))
49
+ }
50
+
51
/// Read `model_max_length` from `dir/tokenizer_config.json`.
///
/// Returns 512 when the file is missing, unparsable, or the value is absent
/// or unusable; any value that is read is capped at 8192 tokens. The float
/// branch rejects values >= 1e15 — presumably guarding against the huge
/// "effectively unlimited" sentinels some tokenizer configs carry (TODO
/// confirm against the configs this ships with).
pub fn read_max_length(dir: &Path) -> usize {
    // Immediately-invoked closure so `?` can short-circuit straight to the
    // 512 default below.
    (|| -> Option<usize> {
        let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
        let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
        let v = json.get("model_max_length")?;
        // Accept integers directly; accept floats only in a sane range.
        let n = v.as_u64().or_else(|| {
            v.as_f64()
                .filter(|&f| f > 0.0 && f < 1e15)
                .map(|f| f as u64)
        })?;
        Some((n as usize).min(8192))
    })()
    .unwrap_or(512)
}
65
+
66
+ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
67
+ let unsupported: Vec<String> = session
68
+ .inputs
69
+ .iter()
70
+ .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
71
+ .map(|i| i.name.clone())
72
+ .collect();
73
+
74
+ if unsupported.is_empty() {
75
+ return Ok(());
76
+ }
77
+
78
+ let mut message = format!(
79
+ "unsupported model inputs for {} API: {}",
80
+ api_label,
81
+ unsupported.join(", ")
82
+ );
83
+ if unsupported.iter().any(|n| n == "pixel_values") {
84
+ message.push_str(
85
+ ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
86
+ );
87
+ } else {
88
+ message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
89
+ }
90
+ Err(GteError::Inference(message))
91
+ }
92
+
93
+ pub fn has_input(session: &Session, name: &str) -> bool {
94
+ session.inputs.iter().any(|input| input.name == name)
95
+ }
96
+
97
/// Case-insensitive match of an ONNX output name against a preferred name.
///
/// Accepts an exact match (`"logits"`) or a scoped name whose last path
/// segment matches (`"model/Logits"`). `preferred` is expected to already be
/// lowercase, as the entries of the preferred-output lists are.
fn output_name_matches(name: &str, preferred: &str) -> bool {
    let lower = name.to_ascii_lowercase();
    // Avoid the per-call `format!("/{}", preferred)` allocation the old code
    // paid: strip the suffix and check the remainder ends at a separator.
    lower == preferred
        || lower
            .strip_suffix(preferred)
            .map_or(false, |prefix| prefix.ends_with('/'))
}
101
+
102
+ pub fn select_output_tensor(
103
+ session: &Session,
104
+ requested: Option<&str>,
105
+ preferred_outputs: &[&str],
106
+ ) -> Result<String> {
107
+ if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
108
+ if let Some(output) = session
109
+ .outputs
110
+ .iter()
111
+ .find(|o| output_name_matches(o.name.as_str(), requested_name))
112
+ {
113
+ return Ok(output.name.clone());
114
+ }
115
+ let available = session
116
+ .outputs
117
+ .iter()
118
+ .map(|o| o.name.as_str())
119
+ .collect::<Vec<_>>()
120
+ .join(", ");
121
+ return Err(GteError::Inference(format!(
122
+ "requested output tensor '{}' not found in model outputs: {}",
123
+ requested_name, available
124
+ )));
125
+ }
126
+
127
+ for preferred in preferred_outputs {
128
+ if let Some(output) = session
129
+ .outputs
130
+ .iter()
131
+ .find(|o| output_name_matches(o.name.as_str(), preferred))
132
+ {
133
+ return Ok(output.name.clone());
134
+ }
135
+ }
136
+
137
+ session
138
+ .outputs
139
+ .first()
140
+ .map(|o| o.name.clone())
141
+ .ok_or_else(|| GteError::Inference("model has no outputs".into()))
142
+ }
143
+
144
/// Last `/`-separated segment of an output name (the whole name when it
/// contains no separator).
fn output_basename(name: &str) -> &str {
    match name.rfind('/') {
        Some(idx) => &name[idx + 1..],
        None => name,
    }
}
147
+
148
/// Decide how embeddings are extracted from `output_tensor`.
///
/// Rank-3 tensor outputs (presumably batch x seq x dim — confirm against the
/// exported graphs) are mean-pooled; rank-2 outputs are taken raw. A missing
/// tensor, a non-tensor output, or any other rank is a `GteError::Inference`.
pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
    let output = session
        .outputs
        .iter()
        .find(|o| o.name == output_tensor)
        .ok_or_else(|| {
            GteError::Inference(format!(
                "output tensor '{}' not found in model outputs",
                output_tensor
            ))
        })?;

    // Only tensor-typed outputs carry a rank we can inspect.
    let ndims = match &output.output_type {
        ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
        other => {
            return Err(GteError::Inference(format!(
                "output is not a tensor: {:?}",
                other
            )))
        }
    };

    match (output_basename(output_tensor), ndims) {
        // NOTE(review): this arm is subsumed by `(_, 3)` below; it is kept as
        // explicit documentation of the common hidden-state case.
        ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
        (_, 2) => Ok(ExtractorMode::Raw),
        (_, 3) => Ok(ExtractorMode::MeanPool),
        (_, n) => Err(GteError::Inference(format!(
            "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
            n, output_tensor
        ))),
    }
}
@@ -0,0 +1,60 @@
1
+ use crate::error::{GteError, Result};
2
+ use crate::tokenizer::Tokenized;
3
+ use ndarray::ArrayView2;
4
+ use ort::session::SessionInputValue;
5
+ use ort::value::Value;
6
+
7
/// ONNX input feed built from a tokenized batch.
pub struct InputTensors<'a> {
    /// Name/value pairs to pass to `Session::run`; always carries
    /// `input_ids`, plus `attention_mask` and `token_type_ids` when enabled.
    pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
    /// Borrowed (rows x cols) view of the attention mask, exposed to callers
    /// even when the model itself takes no `attention_mask` input.
    pub attention_mask: ArrayView2<'a, i64>,
}
11
+
12
impl<'a> InputTensors<'a> {
    /// Build the ONNX feed from a tokenized batch, borrowing the token
    /// buffers (no copies are made).
    ///
    /// `input_ids` is always fed; `attention_mask` only when
    /// `with_attention_mask` is set (the view is returned either way);
    /// `token_type_ids` only when the batch carries them. Shape errors from
    /// `ArrayView2::from_shape` (buffer length != rows*cols) propagate via `?`.
    pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
        let input_ids_view: ArrayView2<'_, i64> = ArrayView2::from_shape(
            (tokenized.rows, tokenized.cols),
            tokenized.input_ids.as_slice(),
        )?;
        let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
            (tokenized.rows, tokenized.cols),
            tokenized.attn_masks.as_slice(),
        )?;

        // Two fixed slots plus an optional token_type_ids slot.
        let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
        inputs.push((
            "input_ids",
            SessionInputValue::from(Value::from_array(input_ids_view)?),
        ));

        if with_attention_mask {
            inputs.push((
                "attention_mask",
                SessionInputValue::from(Value::from_array(attention_mask)?),
            ));
        }

        if let Some(type_ids) = tokenized.type_ids.as_deref() {
            let type_ids_view: ArrayView2<'_, i64> =
                ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
            inputs.push((
                "token_type_ids",
                SessionInputValue::from(Value::from_array(type_ids_view)?),
            ));
        }

        Ok(Self {
            inputs,
            attention_mask,
        })
    }
}
51
+
52
/// Pull the named f32 tensor out of a finished session run.
///
/// Fails with `GteError::Inference` when `output_name` is absent from the
/// run's outputs; extraction errors from `ort` are converted via `?`.
pub fn extract_output_tensor<'a>(
    outputs: &'a ort::session::SessionOutputs<'a, 'a>,
    output_name: &str,
) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
    let tensor_value = outputs.get(output_name).ok_or_else(|| {
        GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
    })?;
    Ok(tensor_value.try_extract_tensor::<f32>()?.into())
}
@@ -75,6 +75,12 @@ pub fn normalize_l2(mut embeddings: Array2<f32>) -> Array2<f32> {
75
75
  embeddings
76
76
  }
77
77
 
78
+ pub fn sigmoid_scores(mut scores: ndarray::ArrayViewMut1<'_, f32>) {
79
+ scores.map_inplace(|value| {
80
+ *value = 1.0 / (1.0 + (-*value).exp());
81
+ });
82
+ }
83
+
78
84
  fn mean_pool_contiguous(
79
85
  hidden: &[f32],
80
86
  attention_mask: &[i64],
@@ -0,0 +1,122 @@
1
+ use crate::error::{GteError, Result};
2
+ use crate::model_profile::{
3
+ has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
4
+ select_output_tensor, validate_supported_text_inputs,
5
+ };
6
+ use crate::pipeline::{extract_output_tensor, InputTensors};
7
+ use crate::postprocess::sigmoid_scores;
8
+ use crate::session::build_session;
9
+ use crate::tokenizer::Tokenizer;
10
+ use ndarray::Array1;
11
+ use ort::session::Session;
12
+ use std::path::Path;
13
+
14
/// Settings resolved once at construction time and reused for every
/// scoring call.
#[derive(Debug, Clone)]
struct RerankerConfig {
    // Token budget for tokenization (override or tokenizer_config.json).
    max_length: usize,
    // Name of the model output holding the relevance logits.
    output_tensor: String,
    // Whether the graph declares a `token_type_ids` input.
    with_type_ids: bool,
    // Whether the graph declares an `attention_mask` input.
    with_attention_mask: bool,
}
21
+
22
/// Cross-encoder text reranker backed by a single ONNX session.
pub struct Reranker {
    // Tokenizes (query, candidate) pairs up to `config.max_length`.
    tokenizer: Tokenizer,
    session: Session,
    // Probed once in `from_dir`; see `RerankerConfig`.
    config: RerankerConfig,
}
27
+
28
impl Reranker {
    /// Build a reranker from a model directory containing `tokenizer.json`
    /// and an ONNX graph.
    ///
    /// * `model_name` — explicit ONNX file name; `None`/empty falls back to
    ///   the default text-model resolution.
    /// * `output_tensor_override` — forces the scored output; `None` prefers
    ///   an output matching "logits", else the model's first output.
    /// * `max_length_override` — must be non-zero when given; `None` reads
    ///   `tokenizer_config.json` (512 default).
    /// * `execution_providers_override` — forwarded verbatim to
    ///   `build_session` (NOTE(review): format defined there — confirm).
    pub fn from_dir<P: AsRef<Path>>(
        dir: P,
        num_threads: usize,
        optimization_level: u8,
        model_name: Option<&str>,
        output_tensor_override: Option<&str>,
        max_length_override: Option<usize>,
        execution_providers_override: Option<&str>,
    ) -> Result<Self> {
        let dir = dir.as_ref();
        let tokenizer_path = resolve_tokenizer_path(dir)?;
        // `Some("")` is treated the same as `None`.
        let model_path = match model_name.filter(|s| !s.is_empty()) {
            Some(name) => resolve_named_model(dir, name)?,
            None => resolve_default_text_model(dir)?,
        };

        let max_length = if let Some(override_value) = max_length_override {
            if override_value == 0 {
                return Err(GteError::Inference(
                    "max_length override must be greater than 0".to_string(),
                ));
            }
            override_value
        } else {
            read_max_length(dir)
        };

        // Placeholder config used only to build the session; the real
        // output tensor and input flags are probed from the session below.
        let probe_config = crate::model_config::ModelConfig {
            max_length,
            output_tensor: String::new(),
            mode: crate::model_config::ExtractorMode::Raw,
            with_type_ids: false,
            with_attention_mask: true,
            num_threads,
            optimization_level,
            execution_providers: execution_providers_override.map(str::to_string),
        };
        let session = build_session(&model_path, &probe_config)?;

        // Fail fast on graphs we cannot feed (e.g. multimodal exports).
        validate_supported_text_inputs(&session, "text reranking")?;
        let with_type_ids = has_input(&session, "token_type_ids");
        let with_attention_mask = has_input(&session, "attention_mask");
        let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;

        let config = RerankerConfig {
            max_length,
            output_tensor,
            with_type_ids,
            with_attention_mask,
        };

        let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;

        Ok(Self {
            tokenizer,
            session,
            config,
        })
    }

    /// Score `(query, candidate)` pairs, returning one relevance score per
    /// pair in input order. With `apply_sigmoid`, raw logits are squashed
    /// into (0, 1) in place.
    ///
    /// Accepts rank-1 outputs (one logit per pair) and rank-2 outputs,
    /// where column 0 is taken as the relevance logit.
    /// NOTE(review): for 2-logit classifier heads the relevant column may
    /// be 1 rather than 0 — confirm against the exported models.
    pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
        let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
        let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
        let outputs = self.session.run(input_tensors.inputs)?;
        let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;

        let mut scores = match array.ndim() {
            1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
            2 => {
                let shape = array.shape();
                if shape[1] == 0 {
                    return Err(GteError::Inference(format!(
                        "reranker output '{}' has invalid shape {:?}",
                        self.config.output_tensor, shape
                    )));
                }
                array.slice(ndarray::s![.., 0]).into_owned()
            }
            n => {
                return Err(GteError::Inference(format!(
                    "reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
                    self.config.output_tensor, n
                )))
            }
        };

        if apply_sigmoid {
            sigmoid_scores(scores.view_mut());
        }

        Ok(scores)
    }
}
@@ -2,6 +2,7 @@
2
2
 
3
3
  use crate::embedder::{normalize_l2, Embedder};
4
4
  use crate::error::GteError;
5
+ use crate::reranker::Reranker;
5
6
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
6
7
  use std::os::raw::c_void;
7
8
  use std::panic::{catch_unwind, AssertUnwindSafe};
@@ -10,6 +11,13 @@ use std::sync::Arc;
10
11
  #[wrap(class = "GTE::Embedder", free_immediately, size)]
11
12
  pub struct RbEmbedder {
12
13
  inner: Arc<Embedder>,
14
+ normalize: bool,
15
+ }
16
+
17
+ #[wrap(class = "GTE::Reranker", free_immediately, size)]
18
+ pub struct RbReranker {
19
+ inner: Arc<Reranker>,
20
+ sigmoid: bool,
13
21
  }
14
22
 
15
23
  #[wrap(class = "GTE::Tensor", free_immediately, size)]
@@ -22,11 +30,21 @@ pub struct RbTensor {
22
30
  struct InferArgs {
23
31
  embedder: *const Embedder,
24
32
  texts: *const Vec<String>,
33
+ normalize: bool,
25
34
  result: Option<Result<ndarray::Array2<f32>, GteError>>,
26
35
  }
27
36
 
28
37
  unsafe impl Send for InferArgs {}
29
38
 
39
/// Arguments passed through `rb_thread_call_without_gvl` for reranking.
/// Raw pointers (not references) because the payload crosses a C-ABI
/// callback boundary; they point into the caller's stack frame.
struct ScoreArgs {
    reranker: *const Reranker,
    pairs: *const Vec<(String, String)>,
    apply_sigmoid: bool,
    // Written by `run_score_without_gvl` before it returns.
    result: Option<Result<Vec<f32>, GteError>>,
}

// SAFETY: the pointed-to data lives in `score_without_gvl`'s frame, which
// blocks until the no-GVL callback finishes, so the pointers never outlive
// their targets and are never accessed concurrently.
unsafe impl Send for ScoreArgs {}
47
+
30
48
  fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
31
49
  if let Some(msg) = payload.downcast_ref::<&str>() {
32
50
  (*msg).to_string()
@@ -37,11 +55,16 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
37
55
  }
38
56
  }
39
57
 
40
- fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
58
+ fn infer_without_gvl(
59
+ embedder: &Arc<Embedder>,
60
+ normalize: bool,
61
+ texts: Vec<String>,
62
+ ) -> Result<ndarray::Array2<f32>, Error> {
41
63
  let embeddings = unsafe {
42
64
  let mut args = InferArgs {
43
65
  embedder: Arc::as_ptr(embedder),
44
66
  texts: &texts as *const Vec<String>,
67
+ normalize,
45
68
  result: None,
46
69
  };
47
70
  rb_sys::rb_thread_call_without_gvl(
@@ -60,12 +83,44 @@ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<nda
60
83
  Ok(embeddings)
61
84
  }
62
85
 
86
/// Run `Reranker::score_pairs` with the Ruby GVL released, so other Ruby
/// threads can make progress during inference.
fn score_without_gvl(
    reranker: &Arc<Reranker>,
    pairs: Vec<(String, String)>,
    apply_sigmoid: bool,
) -> Result<Vec<f32>, Error> {
    let scores = unsafe {
        // SAFETY: `args` borrows `reranker` and `pairs`, both of which
        // outlive the synchronous `rb_thread_call_without_gvl` call; the
        // raw pointers are only dereferenced inside `run_score_without_gvl`
        // while this frame is still alive.
        let mut args = ScoreArgs {
            reranker: Arc::as_ptr(reranker),
            pairs: &pairs as *const Vec<(String, String)>,
            apply_sigmoid,
            result: None,
        };
        rb_sys::rb_thread_call_without_gvl(
            Some(run_score_without_gvl),
            &mut args as *mut ScoreArgs as *mut c_void,
            None,
            std::ptr::null_mut(),
        );
        // A missing result means the callback never ran to completion.
        let result = args.result.take().ok_or_else(|| {
            magnus::Error::from(GteError::Inference(
                "reranking did not return a result".to_string(),
            ))
        })?;
        result.map_err(magnus::Error::from)?
    };
    Ok(scores)
}
113
+
63
114
  unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
64
115
  let args = &mut *(ptr as *mut InferArgs);
65
116
  let run_result = catch_unwind(AssertUnwindSafe(|| {
66
117
  let tokenized = (*args.embedder).tokenize(&*args.texts)?;
67
118
  let embeddings = (*args.embedder).run(&tokenized)?;
68
- Ok(normalize_l2(embeddings))
119
+ if args.normalize {
120
+ Ok(normalize_l2(embeddings))
121
+ } else {
122
+ Ok(embeddings)
123
+ }
69
124
  }));
70
125
  args.result = Some(match run_result {
71
126
  Ok(result) => result,
@@ -77,6 +132,22 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
77
132
  std::ptr::null_mut()
78
133
  }
79
134
 
135
/// No-GVL trampoline: scores the pairs described by `ptr` (a
/// `*mut ScoreArgs`) and stores the outcome in `args.result`. Panics are
/// caught and converted to errors so they cannot unwind across the C
/// boundary.
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
    // SAFETY: `ptr` was created from a live `&mut ScoreArgs` in
    // `score_without_gvl`, which blocks until this function returns.
    let args = &mut *(ptr as *mut ScoreArgs);
    let run_result = catch_unwind(AssertUnwindSafe(|| {
        let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
        Ok(scores.to_vec())
    }));
    args.result = Some(match run_result {
        Ok(result) => result,
        Err(payload) => Err(GteError::Inference(format!(
            "panic during reranking: {}",
            panic_payload_to_string(payload),
        ))),
    });
    std::ptr::null_mut()
}
150
+
80
151
  fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
81
152
  let rows = embeddings.nrows();
82
153
  let cols = embeddings.ncols();
@@ -97,31 +168,128 @@ impl RbEmbedder {
97
168
  num_threads: usize,
98
169
  optimization_level: u8,
99
170
  model_name: String,
171
+ normalize: bool,
172
+ output_tensor: String,
173
+ max_length: usize,
174
+ execution_providers: String,
100
175
  ) -> Result<Self, Error> {
101
176
  let name = if model_name.is_empty() {
102
177
  None
103
178
  } else {
104
179
  Some(model_name.as_str())
105
180
  };
106
- let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, name)
107
- .map_err(magnus::Error::from)?;
181
+ let output_override = if output_tensor.is_empty() {
182
+ None
183
+ } else {
184
+ Some(output_tensor.as_str())
185
+ };
186
+ let max_length_override = if max_length == 0 {
187
+ None
188
+ } else {
189
+ Some(max_length)
190
+ };
191
+ let execution_providers_override = if execution_providers.is_empty() {
192
+ None
193
+ } else {
194
+ Some(execution_providers.as_str())
195
+ };
196
+ let embedder = Embedder::from_dir(
197
+ &dir_path,
198
+ num_threads,
199
+ optimization_level,
200
+ name,
201
+ output_override,
202
+ max_length_override,
203
+ execution_providers_override,
204
+ )
205
+ .map_err(magnus::Error::from)?;
108
206
  Ok(RbEmbedder {
109
207
  inner: Arc::new(embedder),
208
+ normalize,
110
209
  })
111
210
  }
112
211
 
113
212
/// `GTE::Embedder#embed`: embed a Ruby array of strings and return a
/// `GTE::Tensor`. Inference runs with the GVL released; rows are
/// L2-normalized when the embedder was built with `normalize = true`.
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
    let texts: Vec<String> = texts.to_vec()?;
    let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
    tensor_from_array(embeddings)
}
118
217
 
119
218
/// `GTE::Embedder#embed_one`: single-string convenience wrapper around the
/// same GVL-released inference path as `rb_embed`.
pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
    let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
    tensor_from_array(embeddings)
}
123
222
  }
124
223
 
224
impl RbReranker {
    /// `GTE::Reranker.new` backend. Empty strings and `0` act as "not set"
    /// for the optional arguments (`model_name`, `output_tensor`,
    /// `max_length`, `execution_providers`); `sigmoid` fixes once whether
    /// `score` returns sigmoid-squashed scores or raw logits.
    pub fn rb_new(
        _ruby: &Ruby,
        dir_path: String,
        num_threads: usize,
        optimization_level: u8,
        model_name: String,
        sigmoid: bool,
        output_tensor: String,
        max_length: usize,
        execution_providers: String,
    ) -> Result<Self, Error> {
        // Translate the Ruby-side sentinel values into Options.
        let name = if model_name.is_empty() {
            None
        } else {
            Some(model_name.as_str())
        };
        let output_override = if output_tensor.is_empty() {
            None
        } else {
            Some(output_tensor.as_str())
        };
        let max_length_override = if max_length == 0 {
            None
        } else {
            Some(max_length)
        };
        let execution_providers_override = if execution_providers.is_empty() {
            None
        } else {
            Some(execution_providers.as_str())
        };
        let reranker = Reranker::from_dir(
            &dir_path,
            num_threads,
            optimization_level,
            name,
            output_override,
            max_length_override,
            execution_providers_override,
        )
        .map_err(magnus::Error::from)?;
        Ok(RbReranker {
            inner: Arc::new(reranker),
            sigmoid,
        })
    }

    /// `GTE::Reranker#score`: score each candidate against `query` and
    /// return a Ruby array of scores in candidate order. The query is
    /// cloned per candidate to form (query, candidate) pairs; inference
    /// runs with the GVL released via `score_without_gvl`.
    pub fn rb_score(
        ruby: &Ruby,
        rb_self: &Self,
        query: String,
        candidates: RArray,
    ) -> Result<RArray, Error> {
        let candidates: Vec<String> = candidates.to_vec()?;
        let pairs: Vec<(String, String)> = candidates
            .into_iter()
            .map(|candidate| (query.clone(), candidate))
            .collect();
        let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;

        let out = ruby.ary_new_capa(scores.len());
        for score in scores {
            out.push(score)?;
        }
        Ok(out)
    }
}
292
+
125
293
  impl RbTensor {
126
294
  pub fn len(&self) -> usize {
127
295
  self.rows
@@ -208,10 +376,14 @@ impl RbTensor {
208
376
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
209
377
  let module = ruby.define_module("GTE")?;
210
378
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
211
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 4))?;
379
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
212
380
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
213
381
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
214
382
 
383
+ let reranker_class = module.define_class("Reranker", ruby.class_object())?;
384
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
385
+ reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
386
+
215
387
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
216
388
  tensor_class.define_method("rows", method!(RbTensor::rows, 0))?;
217
389
  tensor_class.define_method("size", method!(RbTensor::len, 0))?;