gte 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,31 @@
1
+ pub mod embedder;
2
+ pub mod error;
3
+ pub mod model_config;
4
+ pub mod postprocess;
5
+ pub mod session;
6
+ pub mod tokenizer;
7
+
8
+ #[cfg(feature = "ruby-ffi")]
9
+ mod ruby_embedder;
10
+
11
+ #[cfg(feature = "ruby-ffi")]
12
+ use magnus::{prelude::*, Error, Ruby};
13
+
14
#[cfg(feature = "ruby-ffi")]
#[magnus::init]
fn init(ruby: &Ruby) -> Result<(), Error> {
    // Extension entry point: define the root GTE module, a GTE::Error class
    // (subclass of StandardError) that Rust-side failures surface through,
    // and register the Embedder/Tensor classes.
    let module = ruby.define_module("GTE")?;
    module.define_error("Error", ruby.exception_standard_error())?;
    crate::ruby_embedder::register(ruby)?;
    // Log panics that escape Rust code before they hit the VM boundary.
    // NOTE(review): this hook is process-global and replaces any hook set by
    // other native extensions in the same process — confirm acceptable.
    std::panic::set_hook(Box::new(|info| {
        let msg = info
            .payload()
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| info.payload().downcast_ref::<String>().map(|s| s.as_str()))
            .unwrap_or("unknown panic");
        eprintln!("GTE Rust panic: {msg}");
    }));

    Ok(())
}
@@ -0,0 +1,17 @@
1
/// How the final embedding is extracted from the model's output tensor.
#[derive(Debug, Clone, Copy)]
pub enum ExtractorMode {
    /// Keep the hidden state of a single token at the given sequence index
    /// (e.g. 0 for a CLS-style embedding); requires rank-3 output.
    Token(usize),
    /// Attention-mask-weighted average over the sequence axis; requires
    /// rank-3 output.
    MeanPool,
    /// Use the model output as-is; requires rank-2 output.
    Raw,
}
7
+
8
/// Static configuration describing how a model is loaded and queried.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Maximum token length; longer inputs are truncated by the tokenizer.
    pub max_length: usize,
    /// Name of the model output tensor holding the embeddings/hidden states.
    pub output_tensor: String,
    /// Strategy used to turn the output tensor into per-text embeddings.
    pub mode: ExtractorMode,
    /// Whether the tokenizer should emit `token_type_ids` for the model.
    pub with_type_ids: bool,
    /// Whether to feed `attention_mask` as a model input.
    pub with_attention_mask: bool,
    /// Intra-op thread count for the session; 0 keeps the runtime default.
    pub num_threads: usize,
    /// Graph optimization level: 0 = disabled, 1/2 = partial, 3+ = full.
    pub optimization_level: u8,
}
@@ -0,0 +1,113 @@
1
+ use crate::error::{GteError, Result};
2
+ use ndarray::{Array2, ArrayView2, ArrayView3};
3
+
4
+ pub fn mean_pool(
5
+ hidden_states: ArrayView3<'_, f32>,
6
+ attention_mask: ArrayView2<'_, i64>,
7
+ ) -> Result<Array2<f32>> {
8
+ let (batch, seq, dim) = hidden_states.dim();
9
+ if attention_mask.dim() != (batch, seq) {
10
+ return Err(GteError::Inference(format!(
11
+ "attention mask shape {:?} does not match hidden state shape ({batch}, {seq}, {dim})",
12
+ attention_mask.dim()
13
+ )));
14
+ }
15
+
16
+ let mut pooled = Array2::<f32>::zeros((batch, dim));
17
+
18
+ if let (Some(hidden), Some(mask), Some(output)) = (
19
+ hidden_states.as_slice_memory_order(),
20
+ attention_mask.as_slice_memory_order(),
21
+ pooled.as_slice_memory_order_mut(),
22
+ ) {
23
+ mean_pool_contiguous(hidden, mask, output, batch, seq, dim);
24
+ return Ok(pooled);
25
+ }
26
+
27
+ for batch_index in 0..batch {
28
+ let mut weight_sum = 0.0f32;
29
+ for token_index in 0..seq {
30
+ let weight = attention_mask[[batch_index, token_index]];
31
+ if weight <= 0 {
32
+ continue;
33
+ }
34
+
35
+ let weight = weight as f32;
36
+ for dim_index in 0..dim {
37
+ pooled[[batch_index, dim_index]] +=
38
+ hidden_states[[batch_index, token_index, dim_index]] * weight;
39
+ }
40
+ weight_sum += weight;
41
+ }
42
+
43
+ if weight_sum > 0.0 {
44
+ let inverse = weight_sum.recip();
45
+ pooled
46
+ .row_mut(batch_index)
47
+ .map_inplace(|value| *value *= inverse);
48
+ }
49
+ }
50
+
51
+ Ok(pooled)
52
+ }
53
+
54
+ pub fn normalize_l2(mut embeddings: Array2<f32>) -> Array2<f32> {
55
+ let cols = embeddings.ncols();
56
+ if let Some(data) = embeddings.as_slice_mut() {
57
+ for row in data.chunks_mut(cols) {
58
+ let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
59
+ if norm > 0.0 {
60
+ let inv = norm.recip();
61
+ for v in row.iter_mut() {
62
+ *v *= inv;
63
+ }
64
+ }
65
+ }
66
+ return embeddings;
67
+ }
68
+ // non-contiguous fallback
69
+ for mut row in embeddings.rows_mut() {
70
+ let norm = row.iter().map(|value| value * value).sum::<f32>().sqrt();
71
+ if norm > 0.0 {
72
+ row.map_inplace(|value| *value *= norm.recip());
73
+ }
74
+ }
75
+ embeddings
76
+ }
77
+
78
/// Weighted mean-pool over contiguous row-major buffers.
///
/// `hidden` is (batch, seq, dim) flattened, `attention_mask` is (batch, seq)
/// flattened, `output` is (batch, dim) flattened and assumed zero-initialized
/// by the caller. Rows whose mask weights sum to zero are left untouched.
fn mean_pool_contiguous(
    hidden: &[f32],
    attention_mask: &[i64],
    output: &mut [f32],
    batch: usize,
    seq: usize,
    dim: usize,
) {
    if dim == 0 {
        // Nothing to accumulate; also keeps the chunk size below non-zero.
        return;
    }

    for batch_index in 0..batch {
        let out_row = &mut output[batch_index * dim..(batch_index + 1) * dim];
        let mask_row = &attention_mask[batch_index * seq..(batch_index + 1) * seq];
        let token_rows =
            hidden[batch_index * seq * dim..(batch_index + 1) * seq * dim].chunks(dim);
        let mut total_weight = 0.0f32;

        for (&mask_value, token_row) in mask_row.iter().zip(token_rows) {
            if mask_value <= 0 {
                // Padding token: contributes nothing.
                continue;
            }
            let w = mask_value as f32;
            for (acc, &component) in out_row.iter_mut().zip(token_row) {
                *acc += component * w;
            }
            total_weight += w;
        }

        if total_weight > 0.0 {
            let scale = total_weight.recip();
            for acc in out_row.iter_mut() {
                *acc *= scale;
            }
        }
    }
}
@@ -0,0 +1,222 @@
1
+ #![cfg(feature = "ruby-ffi")]
2
+
3
+ use crate::embedder::{normalize_l2, Embedder};
4
+ use crate::error::GteError;
5
+ use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
6
+ use std::os::raw::c_void;
7
+ use std::panic::{catch_unwind, AssertUnwindSafe};
8
+ use std::sync::Arc;
9
+
10
/// Ruby-visible wrapper (`GTE::Embedder`) around the native embedder.
#[wrap(class = "GTE::Embedder", free_immediately, size)]
pub struct RbEmbedder {
    // Arc so inference can be run through a raw pointer while the GVL is
    // released, without tying the embedder's lifetime to the Ruby wrapper.
    inner: Arc<Embedder>,
}
14
+
15
/// Ruby-visible wrapper (`GTE::Tensor`) holding a dense (rows x cols)
/// matrix of f32 embeddings.
#[wrap(class = "GTE::Tensor", free_immediately, size)]
pub struct RbTensor {
    rows: usize,
    cols: usize,
    // Row-major: element (r, c) lives at data[r * cols + c].
    data: Vec<f32>,
}
21
+
22
/// Argument bundle passed by raw pointer through
/// `rb_thread_call_without_gvl` to the no-GVL callback; `result` is filled
/// in by the callback before it returns.
struct InferArgs {
    embedder: *const Embedder,
    texts: *const Vec<String>,
    result: Option<Result<ndarray::Array2<f32>, GteError>>,
}

// SAFETY: the raw pointers reference an Arc-backed embedder and a local Vec
// that `infer_without_gvl` keeps alive for the full, blocking duration of
// the rb_thread_call_without_gvl call. NOTE(review): soundness also relies
// on Embedder being safe to use from the no-GVL context (effectively Sync)
// — confirm.
unsafe impl Send for InferArgs {}
29
+
30
/// Renders a caught panic payload as a human-readable message.
///
/// Panics raised via `panic!("...")` carry a `&str`, formatted panics carry
/// a `String`; anything else falls back to a generic message.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|msg| msg.to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
39
+
40
/// Runs tokenization + inference for `texts` with the Ruby GVL released, so
/// other Ruby threads can make progress during the (potentially long) model
/// call.
///
/// # Errors
/// Propagates tokenizer/inference failures as `magnus::Error`, and reports
/// a dedicated error if the callback never produced a result.
fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
    let embeddings = unsafe {
        // SAFETY: both raw pointers stay valid for the whole call —
        // `embedder` is borrowed for the duration of this function and
        // `texts` is owned by this stack frame, which blocks until
        // rb_thread_call_without_gvl returns.
        let mut args = InferArgs {
            embedder: Arc::as_ptr(embedder),
            texts: &texts as *const Vec<String>,
            result: None,
        };
        rb_sys::rb_thread_call_without_gvl(
            Some(run_without_gvl),
            &mut args as *mut InferArgs as *mut c_void,
            None, // no unblock function: the call is not interruptible
            std::ptr::null_mut(),
        );
        // The callback always stores Some(..); None here means it never ran.
        let result = args.result.take().ok_or_else(|| {
            magnus::Error::from(GteError::Inference(
                "inference did not return a result".to_string(),
            ))
        })?;
        result.map_err(magnus::Error::from)?
    };
    Ok(embeddings)
}
62
+
63
/// C callback executed while the GVL is released. It must not touch Ruby
/// state; panics are caught and converted into an error value so they never
/// unwind across the FFI boundary.
unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
    // SAFETY: `ptr` is the &mut InferArgs passed by infer_without_gvl and
    // is valid and exclusively ours for the duration of this call.
    let args = &mut *(ptr as *mut InferArgs);
    let run_result = catch_unwind(AssertUnwindSafe(|| {
        let tokenized = (*args.embedder).tokenize(&*args.texts)?;
        let embeddings = (*args.embedder).run(&tokenized)?;
        Ok(normalize_l2(embeddings))
    }));
    args.result = Some(match run_result {
        Ok(result) => result,
        Err(payload) => Err(GteError::Inference(format!(
            "panic during inference: {}",
            panic_payload_to_string(payload),
        ))),
    });
    std::ptr::null_mut()
}
79
+
80
+ fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
81
+ let rows = embeddings.nrows();
82
+ let cols = embeddings.ncols();
83
+ let (data, offset) = embeddings.into_raw_vec_and_offset();
84
+ if let Some(off) = offset.filter(|&o| o != 0) {
85
+ return Err(magnus::Error::from(GteError::Inference(format!(
86
+ "unexpected non-zero tensor offset: {}",
87
+ off
88
+ ))));
89
+ }
90
+ Ok(RbTensor { rows, cols, data })
91
+ }
92
+
93
impl RbEmbedder {
    /// `GTE::Embedder.new(dir_path, num_threads, optimization_level)` —
    /// loads the model and tokenizer from `dir_path`.
    pub fn rb_new(
        _ruby: &Ruby,
        dir_path: String,
        num_threads: usize,
        optimization_level: u8,
    ) -> Result<Self, Error> {
        let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level)
            .map_err(magnus::Error::from)?;
        Ok(RbEmbedder {
            inner: Arc::new(embedder),
        })
    }

    /// `embedder.embed(texts)` — embeds an array of strings; one tensor row
    /// per input text. Runs inference with the GVL released.
    pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
        let texts: Vec<String> = texts.to_vec()?;
        let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
        tensor_from_array(embeddings)
    }

    /// `embedder.embed_one(text)` — convenience wrapper producing a
    /// single-row tensor.
    pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
        let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
        tensor_from_array(embeddings)
    }
}
118
+
119
impl RbTensor {
    /// Row count; backs Ruby's `size`/`length`.
    pub fn len(&self) -> usize {
        self.rows
    }

    /// Row count; backs Ruby's `rows`.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Embedding dimensionality (column count); backs Ruby's `dim`.
    pub fn dim(&self) -> usize {
        self.cols
    }

    /// `[rows, cols]` as a Ruby array; backs Ruby's `shape`.
    pub fn shape(ruby: &Ruby, rb_self: &Self) -> Result<RArray, Error> {
        let out = ruby.ary_new_capa(2);
        out.push(rb_self.rows)?;
        out.push(rb_self.cols)?;
        Ok(out)
    }

    /// Row `index` as a Ruby array of floats.
    ///
    /// # Errors
    /// Out-of-bounds indices return a `GTE::Error` instead of panicking.
    pub fn row(ruby: &Ruby, rb_self: &Self, index: usize) -> Result<RArray, Error> {
        if index >= rb_self.rows {
            return Err(magnus::Error::from(GteError::Inference(format!(
                "row index {} out of bounds for {} rows",
                index, rb_self.rows
            ))));
        }

        // Row-major layout: row `index` occupies data[start..end].
        let start = index * rb_self.cols;
        let end = start + rb_self.cols;
        let out = ruby.ary_new_capa(rb_self.cols);
        for &value in &rb_self.data[start..end] {
            out.push(value)?;
        }
        Ok(out)
    }

    /// First row; errors on an empty tensor (same bounds check as `row`).
    pub fn first(ruby: &Ruby, rb_self: &Self) -> Result<RArray, Error> {
        Self::row(ruby, rb_self, 0)
    }

    /// Row `index` packed as raw native-endian f32 bytes in a Ruby string.
    pub fn row_binary_f32(
        ruby: &Ruby,
        rb_self: &Self,
        index: usize,
    ) -> Result<magnus::RString, Error> {
        if index >= rb_self.rows {
            return Err(magnus::Error::from(GteError::Inference(format!(
                "row index {} out of bounds for {} rows",
                index, rb_self.rows
            ))));
        }

        let start = index * rb_self.cols;
        let end = start + rb_self.cols;
        // SAFETY: reinterpreting an in-bounds f32 slice as bytes is valid —
        // the pointer is non-null, u8 has no alignment requirement, and the
        // byte length equals the element count times size_of::<f32>().
        let bytes = unsafe {
            std::slice::from_raw_parts(
                rb_self.data[start..end].as_ptr() as *const u8,
                rb_self.cols * std::mem::size_of::<f32>(),
            )
        };
        Ok(ruby.str_from_slice(bytes))
    }

    /// Whole tensor as a Ruby array of row arrays.
    pub fn to_a(ruby: &Ruby, rb_self: &Self) -> Result<RArray, Error> {
        let outer = ruby.ary_new_capa(rb_self.rows);
        for row_idx in 0..rb_self.rows {
            outer.push(Self::row(ruby, rb_self, row_idx)?)?;
        }
        Ok(outer)
    }

    /// Whole tensor packed as raw native-endian f32 bytes, row-major.
    pub fn to_binary_f32(ruby: &Ruby, rb_self: &Self) -> Result<magnus::RString, Error> {
        // SAFETY: same f32-to-byte reinterpretation as `row_binary_f32`,
        // applied to the full backing buffer.
        let bytes = unsafe {
            std::slice::from_raw_parts(
                rb_self.data.as_ptr() as *const u8,
                rb_self.data.len() * std::mem::size_of::<f32>(),
            )
        };
        Ok(ruby.str_from_slice(bytes))
    }
}
201
+
202
/// Registers the `GTE::Embedder` and `GTE::Tensor` Ruby classes and their
/// methods. Called from the extension's `init` entry point.
pub fn register(ruby: &Ruby) -> Result<(), Error> {
    let module = ruby.define_module("GTE")?;
    let embedder_class = module.define_class("Embedder", ruby.class_object())?;
    embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 3))?;
    embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
    embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;

    let tensor_class = module.define_class("Tensor", ruby.class_object())?;
    tensor_class.define_method("rows", method!(RbTensor::rows, 0))?;
    // size/length are Ruby-idiomatic aliases for the row count.
    tensor_class.define_method("size", method!(RbTensor::len, 0))?;
    tensor_class.define_method("length", method!(RbTensor::len, 0))?;
    tensor_class.define_method("dim", method!(RbTensor::dim, 0))?;
    tensor_class.define_method("shape", method!(RbTensor::shape, 0))?;
    // `[]` indexes rows, same as `row`.
    tensor_class.define_method("[]", method!(RbTensor::row, 1))?;
    tensor_class.define_method("row", method!(RbTensor::row, 1))?;
    tensor_class.define_method("first", method!(RbTensor::first, 0))?;
    tensor_class.define_method("row_binary_f32", method!(RbTensor::row_binary_f32, 1))?;
    tensor_class.define_method("to_a", method!(RbTensor::to_a, 0))?;
    tensor_class.define_method("to_binary_f32", method!(RbTensor::to_binary_f32, 0))?;
    Ok(())
}
@@ -0,0 +1,123 @@
1
+ use crate::error::{GteError, Result};
2
+ use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::postprocess::mean_pool;
4
+ use crate::tokenizer::Tokenized;
5
+ use ndarray::{Array2, ArrayView2, Ix2};
6
+ use ort::execution_providers::{
7
+ CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
8
+ };
9
+ use ort::session::Session;
10
+ use ort::session::SessionInputValue;
11
+ use ort::value::Value;
12
+ use std::path::Path;
13
+
14
+ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
15
+ let opt_level = match config.optimization_level {
16
+ 0 => ort::session::builder::GraphOptimizationLevel::Disable,
17
+ 1 => ort::session::builder::GraphOptimizationLevel::Level1,
18
+ 2 => ort::session::builder::GraphOptimizationLevel::Level2,
19
+ _ => ort::session::builder::GraphOptimizationLevel::Level3,
20
+ };
21
+
22
+ let mut builder = Session::builder()?
23
+ .with_optimization_level(opt_level)?
24
+ .with_memory_pattern(true)?;
25
+
26
+ let providers = preferred_execution_providers();
27
+ if !providers.is_empty() {
28
+ builder = builder.with_execution_providers(providers)?;
29
+ }
30
+
31
+ if config.num_threads > 0 {
32
+ builder = builder.with_intra_threads(config.num_threads)?;
33
+ }
34
+
35
+ Ok(builder.commit_from_file(model_path)?)
36
+ }
37
+
38
+ fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
39
+ let default_providers = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
40
+ "xnnpack,coreml"
41
+ } else {
42
+ "xnnpack"
43
+ };
44
+ let order = std::env::var("GTE_EXECUTION_PROVIDERS")
45
+ .unwrap_or_else(|_| default_providers.to_string())
46
+ .to_ascii_lowercase();
47
+
48
+ let mut providers = Vec::new();
49
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
50
+ match provider {
51
+ "xnnpack" => providers.push(XNNPACKExecutionProvider::default().build().fail_silently()),
52
+ "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
53
+ "none" => {}
54
+ _ => {}
55
+ }
56
+ }
57
+ providers
58
+ }
59
+
60
/// Runs one forward pass for `tokenized` and extracts (batch, dim)
/// embeddings according to `config.mode`.
///
/// # Errors
/// Returns `GteError` when input buffer shapes are inconsistent, the
/// configured output tensor is missing, or the output rank does not match
/// the extraction mode.
pub fn run_session(
    session: &Session,
    tokenized: &Tokenized,
    config: &ModelConfig,
) -> Result<Array2<f32>> {
    // Borrowing views over the flat row-major buffers; from_shape fails if
    // a buffer length does not equal rows * cols.
    let input_ids_view: ArrayView2<'_, i64> =
        ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
    let attn_masks_view: ArrayView2<'_, i64> =
        ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;

    let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
    inputs.push((
        "input_ids",
        SessionInputValue::from(Value::from_array(input_ids_view)?),
    ));
    if config.with_attention_mask {
        inputs.push((
            "attention_mask",
            SessionInputValue::from(Value::from_array(attn_masks_view)?),
        ));
    }
    // NOTE(review): whether token_type_ids are fed is decided by their
    // presence in `tokenized`, not by `config.with_type_ids` — confirm the
    // tokenizer is always configured consistently with the model config.
    if let Some(type_ids) = tokenized.type_ids.as_deref() {
        let type_ids_view: ArrayView2<'_, i64> =
            ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
        inputs.push((
            "token_type_ids",
            SessionInputValue::from(Value::from_array(type_ids_view)?),
        ));
    }

    let outputs = session.run(inputs)?;
    let tensor_value = outputs.get(config.output_tensor.as_str()).ok_or_else(|| {
        GteError::Inference(format!(
            "output tensor '{}' not found in model outputs",
            &config.output_tensor
        ))
    })?;

    let array = tensor_value.try_extract_tensor::<f32>()?;

    match config.mode {
        // Rank-3 (batch, seq, dim) output: keep token `idx`'s hidden state.
        ExtractorMode::Token(idx) => {
            let shape = array.shape();
            if shape.len() != 3 || idx >= shape[1] {
                return Err(GteError::Inference(format!(
                    "token extraction index {} out of bounds for output shape {:?}",
                    idx, shape
                )));
            }
            Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
        }
        // Rank-3 output: attention-weighted mean over the sequence axis.
        ExtractorMode::MeanPool => {
            let ndim = array.ndim();
            let hidden_states = array.into_dimensionality::<ndarray::Ix3>().map_err(|_| {
                GteError::Inference(format!(
                    "mean pooling requires rank-3 output, got rank {}",
                    ndim
                ))
            })?;
            mean_pool(hidden_states, attn_masks_view)
        }
        // Model already emits rank-2 (batch, dim).
        ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
    }
}
@@ -0,0 +1,130 @@
1
+ use crate::error::{GteError, Result};
2
+ use std::path::Path;
3
+ use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
4
+
5
/// Flat, row-major tokenization result for a batch of texts.
pub struct Tokenized {
    /// Number of texts in the batch.
    pub rows: usize,
    /// Padded sequence length shared by every row.
    pub cols: usize,
    /// Token ids, length rows * cols, row-major.
    pub input_ids: Vec<i64>,
    /// Attention mask (non-zero marks real tokens), same layout as ids.
    pub attn_masks: Vec<i64>,
    /// Optional token type ids, emitted when the tokenizer is configured
    /// with `with_type_ids`.
    pub type_ids: Option<Vec<i64>>,
}
12
+
13
/// Wrapper around a HuggingFace tokenizer preconfigured with truncation
/// and batch-longest padding.
pub struct Tokenizer {
    tokenizer: tokenizers::Tokenizer,
    // Whether tokenize() should also emit token type ids.
    with_type_ids: bool,
}
17
+
18
impl Tokenizer {
    /// Loads a tokenizer file (e.g. `tokenizer.json`) from `tokenizer_path`
    /// and configures truncation to `max_length` plus batch-longest padding.
    ///
    /// # Errors
    /// Returns `GteError::Tokenizer` when the file cannot be loaded or the
    /// truncation parameters are rejected.
    pub fn new<P: AsRef<Path>>(
        tokenizer_path: P,
        max_length: usize,
        with_type_ids: bool,
    ) -> Result<Self> {
        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| GteError::Tokenizer(e.to_string()))?;

        let truncation = TruncationParams {
            max_length,
            ..Default::default()
        };
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| GteError::Tokenizer(e.to_string()))?;

        // Pad every sequence in a batch to the longest member so the flat
        // row-major buffers in `Tokenized` stay rectangular.
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        Ok(Self {
            tokenizer,
            with_type_ids,
        })
    }

    /// Tokenizes `texts` into flat row-major id/mask buffers.
    /// Single-text inputs take a cheaper non-batch path.
    ///
    /// # Errors
    /// Returns `GteError::Tokenizer` when encoding fails.
    pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
        if texts.len() == 1 {
            let encoding = self
                .tokenizer
                .encode_fast(texts[0].as_str(), true)
                .map_err(|e| GteError::Tokenizer(e.to_string()))?;
            return build_tokenized_single(&encoding, self.with_type_ids);
        }

        let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let encodings = self
            .tokenizer
            .encode_batch_fast(encode_inputs, true)
            .map_err(|e| GteError::Tokenizer(e.to_string()))?;

        build_tokenized(&encodings, self.with_type_ids)
    }
}
65
+
66
+ fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Result<Tokenized> {
67
+ let cols = encoding.len();
68
+
69
+ let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&value| i64::from(value)).collect();
70
+ let attn_masks: Vec<i64> = encoding
71
+ .get_attention_mask()
72
+ .iter()
73
+ .map(|&value| i64::from(value))
74
+ .collect();
75
+ let type_ids: Option<Vec<i64>> = with_type_ids.then(|| {
76
+ encoding
77
+ .get_type_ids()
78
+ .iter()
79
+ .map(|&value| i64::from(value))
80
+ .collect()
81
+ });
82
+
83
+ Ok(Tokenized {
84
+ rows: 1,
85
+ cols,
86
+ input_ids,
87
+ attn_masks,
88
+ type_ids,
89
+ })
90
+ }
91
+
92
+ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
93
+ let rows = encodings.len();
94
+ let cols = encodings
95
+ .first()
96
+ .map(|encoding| encoding.len())
97
+ .unwrap_or(0);
98
+ let len = rows * cols;
99
+
100
+ let mut input_ids = Vec::with_capacity(len);
101
+ let mut attn_masks = Vec::with_capacity(len);
102
+ let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
103
+
104
+ for encoding in encodings {
105
+ input_ids.extend(encoding.get_ids().iter().map(|&value| i64::from(value)));
106
+ attn_masks.extend(
107
+ encoding
108
+ .get_attention_mask()
109
+ .iter()
110
+ .map(|&value| i64::from(value)),
111
+ );
112
+
113
+ if let Some(type_ids) = type_ids.as_mut() {
114
+ type_ids.extend(
115
+ encoding
116
+ .get_type_ids()
117
+ .iter()
118
+ .map(|&value| i64::from(value)),
119
+ );
120
+ }
121
+ }
122
+
123
+ Ok(Tokenized {
124
+ rows,
125
+ cols,
126
+ input_ids,
127
+ attn_masks,
128
+ type_ids,
129
+ })
130
+ }
@@ -0,0 +1,39 @@
1
+ use gte::embedder::normalize_l2;
2
+ use ndarray::array;
3
+
4
#[test]
fn test_normalize_l2_basic() {
    let input = array![[3.0f32, 4.0], [1.0, 0.0]];
    let result = normalize_l2(input);

    // 3-4-5 triangle row normalizes to (0.6, 0.8).
    let row0 = result.row(0);
    assert!((row0[0] - 0.6).abs() < 1e-6);
    assert!((row0[1] - 0.8).abs() < 1e-6);

    // Previously constructed but never checked: the already-unit second row
    // must come back unchanged.
    let row1 = result.row(1);
    assert!((row1[0] - 1.0).abs() < 1e-6);
    assert!(row1[1].abs() < 1e-6);
}
13
+
14
#[test]
fn test_normalize_l2_zero_vector_unchanged() {
    // A zero vector has zero norm; normalize_l2 must skip it (no divide by
    // zero, no NaNs) and return it unchanged.
    let input = array![[0.0f32, 0.0, 0.0]];
    let result = normalize_l2(input);
    let row = result.row(0);
    assert!(row.iter().all(|&x| x == 0.0));
}
21
+
22
#[test]
fn test_normalize_l2_unit_norm() {
    // Every non-zero row must have an L2 norm of exactly 1 afterwards.
    let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let result = normalize_l2(input);

    for row in result.rows() {
        let norm: f32 = row.mapv(|x: f32| x * x).sum().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
}
32
+
33
+ #[test]
34
+ fn test_normalize_l2_already_unit_unchanged() {
35
+ let input = array![[1.0f32, 0.0, 0.0]];
36
+ let result = normalize_l2(input.clone());
37
+ let row = result.row(0);
38
+ assert!((row[0] - 1.0).abs() < 1e-6 && row[1] == 0.0 && row[2] == 0.0);
39
+ }