gte 0.0.13 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,9 @@
1
1
  #![cfg(feature = "ruby-ffi")]
2
+ #![allow(unsafe_code)]
3
+ #![allow(unused_results)]
4
+ #![allow(unused_qualifications)]
2
5
 
3
- use crate::embedder::{normalize_l2, Embedder};
6
+ use crate::embedder::{normalize_l2, output_name_suggests_normalized, Embedder};
4
7
  use crate::error::GteError;
5
8
  use crate::model_config::ModelLoadOverrides;
6
9
  use crate::reranker::Reranker;
@@ -66,29 +69,30 @@ unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
66
69
  let run_result = catch_unwind(AssertUnwindSafe(|| {
67
70
  // Full embedding path (tokenization + inference) runs without the GVL.
68
71
  let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
69
- if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
72
+ if args.normalize {
73
+ Ok(normalize_l2(embeddings))
74
+ } else {
75
+ Ok(embeddings)
76
+ }
70
77
  }));
71
78
  args.result = Some(match run_result {
72
79
  Ok(result) => result,
73
- Err(payload) => Err(GteError::Inference(format!(
74
- "panic during inference: {}",
75
- panic_payload_to_string(payload),
76
- ))),
80
+ Err(payload) => {
81
+ Err(GteError::Inference(format!("panic during inference: {}", panic_payload_to_string(payload),)))
82
+ }
77
83
  });
78
84
  std::ptr::null_mut()
79
85
  }
80
86
 
81
87
  unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
82
88
  let args = &mut *(ptr as *mut ScoreArgs);
83
- let run_result = catch_unwind(AssertUnwindSafe(|| {
84
- (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
85
- }));
89
+ let run_result =
90
+ catch_unwind(AssertUnwindSafe(|| (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)));
86
91
  args.result = Some(match run_result {
87
92
  Ok(result) => result,
88
- Err(payload) => Err(GteError::Inference(format!(
89
- "panic during reranking: {}",
90
- panic_payload_to_string(payload),
91
- ))),
93
+ Err(payload) => {
94
+ Err(GteError::Inference(format!("panic during reranking: {}", panic_payload_to_string(payload),)))
95
+ }
92
96
  });
93
97
  std::ptr::null_mut()
94
98
  }
@@ -99,23 +103,18 @@ fn infer_without_gvl(
99
103
  texts: Vec<String>,
100
104
  ) -> Result<ndarray::Array2<f32>, Error> {
101
105
  let embeddings = unsafe {
102
- let mut args = InferArgs {
103
- embedder: Arc::as_ptr(embedder),
104
- texts: &texts as *const Vec<String>,
105
- normalize,
106
- result: None,
107
- };
106
+ let mut args =
107
+ InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
108
108
  rb_sys::rb_thread_call_without_gvl(
109
109
  Some(run_embed_without_gvl),
110
110
  &mut args as *mut InferArgs as *mut c_void,
111
111
  None,
112
112
  std::ptr::null_mut(),
113
113
  );
114
- let result = args.result.take().ok_or_else(|| {
115
- magnus::Error::from(GteError::Inference(
116
- "inference did not return a result".to_string(),
117
- ))
118
- })?;
114
+ let result = args
115
+ .result
116
+ .take()
117
+ .ok_or_else(|| magnus::Error::from(GteError::Inference("inference did not return a result".to_string())))?;
119
118
  result.map_err(magnus::Error::from)?
120
119
  };
121
120
  Ok(embeddings)
@@ -141,11 +140,10 @@ fn score_without_gvl(
141
140
  None,
142
141
  std::ptr::null_mut(),
143
142
  );
144
- let result = args.result.take().ok_or_else(|| {
145
- magnus::Error::from(GteError::Inference(
146
- "reranking did not return a result".to_string(),
147
- ))
148
- })?;
143
+ let result = args
144
+ .result
145
+ .take()
146
+ .ok_or_else(|| magnus::Error::from(GteError::Inference("reranking did not return a result".to_string())))?;
149
147
  result.map_err(magnus::Error::from)?
150
148
  };
151
149
  Ok(scores)
@@ -158,10 +156,7 @@ fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error
158
156
  let cols = embeddings.ncols();
159
157
  let (data, offset) = embeddings.into_raw_vec_and_offset();
160
158
  if let Some(off) = offset.filter(|&o| o != 0) {
161
- return Err(magnus::Error::from(GteError::Inference(format!(
162
- "unexpected non-zero tensor offset: {}",
163
- off
164
- ))));
159
+ return Err(magnus::Error::from(GteError::Inference(format!("unexpected non-zero tensor offset: {}", off))));
165
160
  }
166
161
  Ok(RbTensor { rows, cols, data })
167
162
  }
@@ -177,22 +172,28 @@ impl RbEmbedder {
177
172
  max_length: usize,
178
173
  padding: String,
179
174
  execution_providers: String,
175
+ lowercase_input: bool,
176
+ max_input_chars: usize,
180
177
  ) -> Result<Self, Error> {
181
178
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
182
179
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
183
180
  let max_length_override = if max_length == 0 { None } else { Some(max_length) };
184
- let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
181
+ let execution_providers_override =
182
+ if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
185
183
  let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
184
+ let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
186
185
  let overrides = ModelLoadOverrides {
187
186
  model_name: name,
188
187
  output_tensor: output_override,
189
188
  max_length: max_length_override,
190
189
  padding: padding_override,
191
190
  execution_providers: execution_providers_override,
191
+ lowercase_input: Some(lowercase_input),
192
+ max_input_chars: max_input_chars_override,
192
193
  };
193
- let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides)
194
- .map_err(magnus::Error::from)?;
195
- Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
194
+ let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
195
+ let skip_normalize = normalize && output_name_suggests_normalized(&embedder.config.output_tensor);
196
+ Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
196
197
  }
197
198
 
198
199
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
@@ -218,11 +219,14 @@ impl RbReranker {
218
219
  max_length: usize,
219
220
  padding: String,
220
221
  execution_providers: String,
222
+ _lowercase_input: bool,
223
+ _max_input_chars: usize,
221
224
  ) -> Result<Self, Error> {
222
225
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
223
226
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
224
227
  let max_length_override = if max_length == 0 { None } else { Some(max_length) };
225
- let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
228
+ let execution_providers_override =
229
+ if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
226
230
  let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
227
231
  let overrides = ModelLoadOverrides {
228
232
  model_name: name,
@@ -230,18 +234,13 @@ impl RbReranker {
230
234
  max_length: max_length_override,
231
235
  padding: padding_override,
232
236
  execution_providers: execution_providers_override,
237
+ ..ModelLoadOverrides::default()
233
238
  };
234
- let reranker = Reranker::from_dir(&dir_path, optimization_level, overrides)
235
- .map_err(magnus::Error::from)?;
239
+ let reranker = Reranker::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
236
240
  Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
237
241
  }
238
242
 
239
- pub fn rb_score(
240
- ruby: &Ruby,
241
- rb_self: &Self,
242
- query: String,
243
- candidates: RArray,
244
- ) -> Result<RArray, Error> {
243
+ pub fn rb_score(ruby: &Ruby, rb_self: &Self, query: String, candidates: RArray) -> Result<RArray, Error> {
245
244
  let candidates: Vec<String> = candidates.to_vec()?;
246
245
  let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
247
246
  let out = ruby.ary_new_capa(scores.len());
@@ -296,11 +295,7 @@ impl RbTensor {
296
295
  Self::row_binary_f32(ruby, rb_self, 0)
297
296
  }
298
297
 
299
- pub fn row_binary_f32(
300
- ruby: &Ruby,
301
- rb_self: &Self,
302
- index: usize,
303
- ) -> Result<magnus::RString, Error> {
298
+ pub fn row_binary_f32(ruby: &Ruby, rb_self: &Self, index: usize) -> Result<magnus::RString, Error> {
304
299
  if index >= rb_self.rows {
305
300
  return Err(magnus::Error::from(GteError::Inference(format!(
306
301
  "row index {} out of bounds for {} rows",
@@ -340,12 +335,12 @@ impl RbTensor {
340
335
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
341
336
  let module = ruby.define_module("GTE")?;
342
337
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
343
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
338
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 10))?;
344
339
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
345
340
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
346
341
 
347
342
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
348
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
343
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 10))?;
349
344
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
350
345
 
351
346
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;