gte 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c0a8a756612408d0a4c8ae4597336ececd610e468764b5c76a3b0c7ba7b52e85
4
- data.tar.gz: 70e27e75cc17f8ca7acca5048581cf8dccf20d40188bb527211240350114f79d
3
+ metadata.gz: ae83f737b57f798d39cf1fdc895d67948de27d36b46ea02c211a440d3acaa8c9
4
+ data.tar.gz: 9eaf9651b2ccf1fdb93efe4666ed70537628453a8cf92e234b454560560a83e8
5
5
  SHA512:
6
- metadata.gz: 0da54bb7d8b58a189463f6a078f315947d4de06d1c1bb7a55f7ab645f5821aec16139e70b210c5608983b50c0fedf1b548a2d49dd437b9bcce7b534d30dac4b9
7
- data.tar.gz: 453b8e43d37b5c7a6b5ced440a5bd993c1e7b9a46145337dee4e0277da30fe5f341fec77f797b14d587edbb4e40cc75aa1821e7cbe644da5f8b6272dc29c03e7
6
+ metadata.gz: a262194a53bf804e47b0ef9c5910c1e2b814a9824823a92a73867a631c7b26310b3163e61997d9c163dab402a40d49946b76a64cc0421741ae235f623180cb95
7
+ data.tar.gz: 6acf5b58140012df9fa25971ed0f1fdfa707cc3efbe5f7f22104e35ad57877778a08cf9f8b311017be8f40e255289e3249e35c1e3780ae231f9f66e08cbb6ac3
data/README.md CHANGED
@@ -9,14 +9,105 @@ Inspired by https://github.com/fbilhaut/gte-rs
9
9
  ```ruby
10
10
  require "gte"
11
11
 
12
- model = GTE.new(ENV.fetch("GTE_MODEL_DIR"))
13
- vector = model["query: hello world"]
12
+ model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
13
+
14
+ # String input => GTE::Tensor (1 row)
15
+ tensor = model.embed("query: hello world")
16
+ vector = tensor.row(0)
17
+
18
+ # [] with string => Array<Float> (single vector)
19
+ single = model["query: nearest coffee shop"]
20
+
21
+ # [] with array => GTE::Tensor (batch)
22
+ batch = model[["query: hello", "query: world"]]
23
+ ```
24
+
25
+ ## Embedding Config (`GTE.config`)
26
+
27
+ `GTE.config(model_dir)` builds (and caches) a `GTE::Model`.
28
+
29
+ ```ruby
30
+ default_model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
31
+
32
+ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
+ config.with(normalize: false)
34
+ end
35
+
36
+ full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
+ config.with(threads: 0)
38
+ end
39
+
40
+ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
41
+ config.with(
42
+ output_tensor: "last_hidden_state",
43
+ max_length: 256,
44
+ optimization_level: 3
45
+ )
46
+ end
47
+ ```
48
+
49
+ Config fields and defaults:
50
+
51
+ - `model_dir`: absolute path to model directory
52
+ - `threads`: `3` (set `0` to use the ONNX Runtime full-throttle thread pool)
53
+ - `optimization_level`: `3`
54
+ - `model_name`: `nil`
55
+ - `normalize`: `true` (L2 normalization applied at the Ruby-facing API)
56
+ - `output_tensor`: `nil` (auto-select output tensor)
57
+ - `max_length`: `nil` (uses tokenizer/model defaults)
58
+
59
+ Notes:
60
+
61
+ - Return a `Config::Text` from the block (for example, `config.with(...)`).
62
+ - Model instances are cached by full config key; different config values create different cached instances.
63
+
64
+ ## Reranker
65
+
66
+ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
67
+
68
+ ```ruby
69
+ reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
70
+ config.with(sigmoid: true, threads: 0)
71
+ end
72
+
73
+ query = "how to train a neural network?"
74
+ candidates = [
75
+ "Backpropagation and gradient descent are core techniques.",
76
+ "This recipe uses flour and eggs."
77
+ ]
78
+
79
+ # Raw scores aligned with input order
80
+ scores = reranker.score(query, candidates)
81
+ # => [0.93, 0.07]
82
+
83
+ # Ranked output sorted by score desc
84
+ ranked = reranker.rerank(query: query, candidates: candidates)
85
+ # => [
86
+ # { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
87
+ # { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
88
+ # ]
14
89
  ```
15
90
 
16
- For Puma or other thread pools, prefer process-local reuse:
91
+ Reranker config fields and defaults:
92
+
93
+ - `model_dir`: absolute path to model directory
94
+ - `threads`: `3`
95
+ - `optimization_level`: `3`
96
+ - `model_name`: `nil`
97
+ - `sigmoid`: `false` (set `true` for scores bounded to the [0, 1] range)
98
+ - `output_tensor`: `nil`
99
+ - `max_length`: `nil`
100
+
101
+ ## Runtime + Result Examples
102
+
103
+ Process-local reuse (recommended for Puma/web servers):
17
104
 
18
105
  ```ruby
19
- MODEL = GTE.new(ENV.fetch("GTE_MODEL_DIR"))
106
+ EMBEDDER = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
107
+
108
+ def embed_query(text)
109
+ EMBEDDER[text] # Array<Float>
110
+ end
20
111
  ```
21
112
 
22
113
  ## Model Directory
@@ -28,14 +119,28 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
28
119
  3. `onnx/model.onnx`
29
120
  4. `model.onnx`
30
121
 
122
+ Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
123
+
124
+ ## Execution Providers
125
+
126
+ The default execution provider is `xnnpack` on all platforms (including macOS arm64).
127
+
128
+ To opt in to CoreML explicitly:
129
+
130
+ ```bash
131
+ export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
132
+ ```
133
+
31
134
  ## Development
32
135
 
33
- Run commands inside `nix develop`.
136
+ Run commands inside `nix develop` via Make targets:
34
137
 
35
138
  ```bash
36
- bundle exec rake compile
37
- cargo test --manifest-path ext/gte/Cargo.toml --no-default-features
38
- bundle exec rspec
139
+ make setup
140
+ make compile
141
+ make test
142
+ make lint
143
+ make ci
39
144
  ```
40
145
 
41
146
  ## Benchmark
@@ -43,14 +148,14 @@ bundle exec rspec
43
148
  The repo includes two benchmark paths:
44
149
 
45
150
  ```bash
46
- bundle exec rake bench:pure_compare
47
- bundle exec rake bench:puma_compare
48
- bundle exec rake bench:matrix_sweep
49
- bundle exec ruby bench/memory_probe.rb --compare-pure
151
+ make bench
152
+ nix develop -c bundle exec rake bench:pure_compare
153
+ nix develop -c bundle exec rake bench:matrix_sweep
154
+ nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
50
155
  ```
51
156
 
52
157
  For release tracking and regression detection, record a run entry in `RUNS.md`:
53
158
 
54
159
  ```bash
55
- bundle exec rake bench:record_run
160
+ make bench-record
56
161
  ```
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.4
1
+ 0.0.5
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
data/ext/gte/src/embedder.rs CHANGED
@@ -1,19 +1,15 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::model_profile::{
4
+ has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
5
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
+ };
3
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
4
8
  use crate::session::{build_session, run_session};
5
9
  use crate::tokenizer::{Tokenized, Tokenizer};
6
10
  use ndarray::Array2;
7
11
  use ort::session::Session;
8
- use std::path::{Path, PathBuf};
9
-
10
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
11
- pub enum ModelFamily {
12
- E5Like,
13
- SiglipLike,
14
- ClipLike,
15
- Other,
16
- }
12
+ use std::path::Path;
17
13
 
18
14
  pub struct Embedder {
19
15
  tokenizer: Tokenizer,
@@ -41,39 +37,50 @@ impl Embedder {
41
37
  num_threads: usize,
42
38
  optimization_level: u8,
43
39
  model_name: Option<&str>,
40
+ output_tensor_override: Option<&str>,
41
+ max_length_override: Option<usize>,
44
42
  ) -> Result<Self> {
43
+ const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
44
+ "pooler_output",
45
+ "text_embeds",
46
+ "sentence_embedding",
47
+ "last_hidden_state",
48
+ ];
49
+
45
50
  let dir = dir.as_ref();
46
- let tokenizer_path = dir.join("tokenizer.json");
51
+ let tokenizer_path = resolve_tokenizer_path(dir)?;
47
52
  let model_path = match model_name.filter(|s| !s.is_empty()) {
48
53
  Some(name) => resolve_named_model(dir, name)?,
49
- None => resolve_model_path(dir)?,
54
+ None => resolve_default_text_model(dir)?,
50
55
  };
51
56
 
52
- if !tokenizer_path.exists() {
53
- return Err(GteError::Tokenizer(format!(
54
- "tokenizer.json not found in {}",
55
- dir.display()
56
- )));
57
- }
57
+ let max_length = if let Some(override_value) = max_length_override {
58
+ if override_value == 0 {
59
+ return Err(GteError::Inference(
60
+ "max_length override must be greater than 0".to_string(),
61
+ ));
62
+ }
63
+ override_value
64
+ } else {
65
+ read_max_length(dir)
66
+ };
58
67
 
59
- let max_length = read_max_length(dir);
60
- let probe_num_threads = if num_threads == 0 { 1 } else { num_threads };
61
- let temp_config = ModelConfig {
68
+ let session_config = ModelConfig {
62
69
  max_length,
63
70
  output_tensor: String::new(),
64
71
  mode: ExtractorMode::Raw,
65
72
  with_type_ids: false,
66
73
  with_attention_mask: true,
67
- num_threads: probe_num_threads,
74
+ num_threads,
68
75
  optimization_level,
69
76
  };
70
- let mut session = build_session(&model_path, &temp_config)?;
77
+ let session = build_session(&model_path, &session_config)?;
71
78
 
72
- validate_supported_inputs(&session)?;
73
- let with_type_ids = session.inputs.iter().any(|i| i.name == "token_type_ids");
74
- let with_attention_mask = session.inputs.iter().any(|i| i.name == "attention_mask");
75
- let output_tensor = select_output_tensor(&session)?;
76
- let output_base = output_basename(output_tensor.as_str()).to_string();
79
+ validate_supported_text_inputs(&session, "text embedding")?;
80
+ let with_type_ids = has_input(&session, "token_type_ids");
81
+ let with_attention_mask = has_input(&session, "attention_mask");
82
+ let output_tensor =
83
+ select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
77
84
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
78
85
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
79
86
  return Err(GteError::Inference(
@@ -81,29 +88,16 @@ impl Embedder {
81
88
  ));
82
89
  }
83
90
 
84
- let tuned_num_threads = tune_num_threads(
85
- num_threads,
86
- with_attention_mask,
87
- with_type_ids,
88
- output_base.as_str(),
89
- );
90
-
91
91
  let config = ModelConfig {
92
92
  max_length,
93
93
  output_tensor,
94
94
  mode,
95
95
  with_type_ids,
96
96
  with_attention_mask,
97
- num_threads: tuned_num_threads,
97
+ num_threads,
98
98
  optimization_level,
99
99
  };
100
100
 
101
- if tuned_num_threads != probe_num_threads {
102
- // Release probe session before rebuilding to minimize transient peak RSS.
103
- drop(session);
104
- session = build_session(&model_path, &config)?;
105
- }
106
-
107
101
  let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
108
102
 
109
103
  Ok(Self {
@@ -125,218 +119,6 @@ impl Embedder {
125
119
  pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
126
120
  run_session(&self.session, tokenized, &self.config)
127
121
  }
128
-
129
- }
130
-
131
- fn tune_num_threads(
132
- requested: usize,
133
- with_attention_mask: bool,
134
- with_type_ids: bool,
135
- output_name: &str,
136
- ) -> usize {
137
- if requested > 0 {
138
- return requested;
139
- }
140
-
141
- let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
142
-
143
- match family {
144
- // Puma-like workloads typically run many concurrent single-item requests where
145
- // one intra-op thread per request gives the best tail behavior.
146
- ModelFamily::E5Like | ModelFamily::ClipLike => 1,
147
- // Siglip2 text path benefits from a small intra-op pool under concurrency.
148
- ModelFamily::SiglipLike => 3,
149
- ModelFamily::Other => 0,
150
- }
151
- }
152
-
153
- fn infer_model_family(
154
- with_attention_mask: bool,
155
- with_type_ids: bool,
156
- output_name: &str,
157
- ) -> ModelFamily {
158
- if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
159
- return ModelFamily::E5Like;
160
- }
161
- if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
162
- return ModelFamily::SiglipLike;
163
- }
164
- if output_name == "text_embeds" && !with_attention_mask {
165
- return ModelFamily::ClipLike;
166
- }
167
- ModelFamily::Other
168
- }
169
-
170
- fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
171
- let candidates = [dir.join("onnx").join(name), dir.join(name)];
172
- for path in &candidates {
173
- if path.exists() {
174
- return Ok(path.clone());
175
- }
176
- }
177
- Err(GteError::Inference(format!(
178
- "model '{}' not found in {} (checked onnx/{0} and {0})",
179
- name,
180
- dir.display()
181
- )))
182
- }
183
-
184
- fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
185
- let candidates = [
186
- dir.join("onnx").join("text_model.onnx"),
187
- dir.join("text_model.onnx"),
188
- dir.join("onnx").join("model.onnx"),
189
- dir.join("model.onnx"),
190
- ];
191
- for path in &candidates {
192
- if path.exists() {
193
- return Ok(path.clone());
194
- }
195
- }
196
- Err(GteError::Inference(format!(
197
- "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
198
- dir.display()
199
- )))
200
- }
201
-
202
- const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
203
-
204
- fn validate_supported_inputs(session: &Session) -> Result<()> {
205
- let unsupported: Vec<String> = session
206
- .inputs
207
- .iter()
208
- .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
209
- .map(|i| i.name.clone())
210
- .collect();
211
-
212
- if unsupported.is_empty() {
213
- return Ok(());
214
- }
215
-
216
- let mut message = format!(
217
- "unsupported model inputs for text embedding API: {}",
218
- unsupported.join(", ")
219
- );
220
- if unsupported.iter().any(|n| n == "pixel_values") {
221
- message.push_str(
222
- ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
223
- );
224
- } else {
225
- message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
226
- }
227
- Err(GteError::Inference(message))
228
- }
229
-
230
- fn output_name_matches(name: &str, preferred: &str) -> bool {
231
- let lower = name.to_ascii_lowercase();
232
- lower == preferred || lower.ends_with(&format!("/{}", preferred))
233
- }
234
-
235
- fn select_output_tensor(session: &Session) -> Result<String> {
236
- const PREFERRED: [&str; 4] = [
237
- "text_embeds",
238
- "pooler_output",
239
- "sentence_embedding",
240
- "last_hidden_state",
241
- ];
242
-
243
- for preferred in PREFERRED {
244
- if let Some(output) = session
245
- .outputs
246
- .iter()
247
- .find(|o| output_name_matches(o.name.as_str(), preferred))
248
- {
249
- return Ok(output.name.clone());
250
- }
251
- }
252
-
253
- session
254
- .outputs
255
- .first()
256
- .map(|o| o.name.clone())
257
- .ok_or_else(|| GteError::Inference("model has no outputs".into()))
258
- }
259
-
260
- fn read_max_length(dir: &Path) -> usize {
261
- (|| -> Option<usize> {
262
- let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
263
- let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
264
- let v = json.get("model_max_length")?;
265
- let n = v
266
- .as_u64()
267
- .or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
268
- Some((n as usize).min(8192))
269
- })()
270
- .unwrap_or(512)
271
- }
272
-
273
- #[cfg(test)]
274
- mod tests {
275
- use super::{infer_model_family, tune_num_threads, ModelFamily};
276
-
277
- #[test]
278
- fn infer_model_family_recognizes_known_signatures() {
279
- assert_eq!(
280
- infer_model_family(true, true, "last_hidden_state"),
281
- ModelFamily::E5Like
282
- );
283
- assert_eq!(
284
- infer_model_family(true, false, "last_hidden_state"),
285
- ModelFamily::SiglipLike
286
- );
287
- assert_eq!(
288
- infer_model_family(false, false, "text_embeds"),
289
- ModelFamily::ClipLike
290
- );
291
- assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
292
- }
293
-
294
- #[test]
295
- fn tune_num_threads_respects_requested_value() {
296
- assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
297
- }
298
-
299
- #[test]
300
- fn tune_num_threads_returns_ort_default_for_other_family() {
301
- assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
302
- }
303
- }
304
-
305
- fn output_basename(name: &str) -> &str {
306
- name.rsplit('/').next().unwrap_or(name)
307
- }
308
-
309
- fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
310
- let output = session
311
- .outputs
312
- .iter()
313
- .find(|o| o.name == output_tensor)
314
- .ok_or_else(|| {
315
- GteError::Inference(format!(
316
- "output tensor '{}' not found in model outputs",
317
- output_tensor
318
- ))
319
- })?;
320
-
321
- let ndims = match &output.output_type {
322
- ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
323
- other => {
324
- return Err(GteError::Inference(format!(
325
- "output is not a tensor: {:?}",
326
- other
327
- )))
328
- }
329
- };
330
-
331
- match (output_basename(output_tensor), ndims) {
332
- ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
333
- (_, 2) => Ok(ExtractorMode::Raw),
334
- (_, 3) => Ok(ExtractorMode::MeanPool),
335
- (_, n) => Err(GteError::Inference(format!(
336
- "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
337
- n, output_tensor
338
- ))),
339
- }
340
122
  }
341
123
 
342
124
  pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
data/ext/gte/src/lib.rs CHANGED
@@ -1,7 +1,10 @@
1
1
  pub mod embedder;
2
2
  pub mod error;
3
3
  pub mod model_config;
4
+ pub mod model_profile;
5
+ pub mod pipeline;
4
6
  pub mod postprocess;
7
+ pub mod reranker;
5
8
  pub mod session;
6
9
  pub mod tokenizer;
7
10