gte 0.0.3 → 0.0.5

This diff shows the changes between publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b7ce34f894403d3d2767d9c7f694aa712b42af251b0babf741e2dcd9dd6c7a27
4
- data.tar.gz: c91aa21b10b2a20358c5d56c511623927c6e4cd4e0667cc7f40cdca405a4d10f
3
+ metadata.gz: ae83f737b57f798d39cf1fdc895d67948de27d36b46ea02c211a440d3acaa8c9
4
+ data.tar.gz: 9eaf9651b2ccf1fdb93efe4666ed70537628453a8cf92e234b454560560a83e8
5
5
  SHA512:
6
- metadata.gz: 87e824d3fa79dc67a9584b902d17329aa85eb4f8fc4a358a6350c7f19e3d4e3c170a59b852abd16332caada49106bbba3356b6a5486bbb52c97b8bef22b1b9a0
7
- data.tar.gz: 0dfeb1f6b4223f7ee88609411b94548740b588d89a92b55ba7e093564417086f24a12ebbf98bfee6a9fbd4c74d0f55dc0d66c2a6095d0d7ad7d9b1adca1b2eb7
6
+ metadata.gz: a262194a53bf804e47b0ef9c5910c1e2b814a9824823a92a73867a631c7b26310b3163e61997d9c163dab402a40d49946b76a64cc0421741ae235f623180cb95
7
+ data.tar.gz: 6acf5b58140012df9fa25971ed0f1fdfa707cc3efbe5f7f22104e35ad57877778a08cf9f8b311017be8f40e255289e3249e35c1e3780ae231f9f66e08cbb6ac3
data/README.md CHANGED
@@ -9,8 +9,105 @@ Inspired by https://github.com/fbilhaut/gte-rs
9
9
  ```ruby
10
10
  require "gte"
11
11
 
12
- model = GTE.new(ENV.fetch("GTE_MODEL_DIR"))
13
- vector = model["query: hello world"]
12
+ model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
13
+
14
+ # String input => GTE::Tensor (1 row)
15
+ tensor = model.embed("query: hello world")
16
+ vector = tensor.row(0)
17
+
18
+ # [] with string => Array<Float> (single vector)
19
+ single = model["query: nearest coffee shop"]
20
+
21
+ # [] with array => GTE::Tensor (batch)
22
+ batch = model[["query: hello", "query: world"]]
23
+ ```
24
+
25
+ ## Embedding Config (`GTE.config`)
26
+
27
+ `GTE.config(model_dir)` builds (and caches) a `GTE::Model`.
28
+
29
+ ```ruby
30
+ default_model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
31
+
32
+ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
+ config.with(normalize: false)
34
+ end
35
+
36
+ full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
+ config.with(threads: 0)
38
+ end
39
+
40
+ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
41
+ config.with(
42
+ output_tensor: "last_hidden_state",
43
+ max_length: 256,
44
+ optimization_level: 3
45
+ )
46
+ end
47
+ ```
48
+
49
+ Config fields and defaults:
50
+
51
+ - `model_dir`: absolute path to model directory
52
+ - `threads`: `3` (set `0` for ONNX Runtime full-throttle threadpool)
53
+ - `optimization_level`: `3`
54
+ - `model_name`: `nil`
55
+ - `normalize`: `true` (L2 normalization at Ruby-facing API)
56
+ - `output_tensor`: `nil` (auto-select output tensor)
57
+ - `max_length`: `nil` (uses tokenizer/model defaults)
58
+
59
+ Notes:
60
+
61
+ - Return a `Config::Text` from the block (for example, `config.with(...)`).
62
+ - Model instances are cached by full config key; different config values create different cached instances.
63
+
64
+ ## Reranker
65
+
66
+ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
67
+
68
+ ```ruby
69
+ reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
70
+ config.with(sigmoid: true, threads: 0)
71
+ end
72
+
73
+ query = "how to train a neural network?"
74
+ candidates = [
75
+ "Backpropagation and gradient descent are core techniques.",
76
+ "This recipe uses flour and eggs."
77
+ ]
78
+
79
+ # Raw scores aligned with input order
80
+ scores = reranker.score(query, candidates)
81
+ # => [0.93, 0.07]
82
+
83
+ # Ranked output sorted by score desc
84
+ ranked = reranker.rerank(query: query, candidates: candidates)
85
+ # => [
86
+ # { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
87
+ # { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
88
+ # ]
89
+ ```
90
+
91
+ Reranker config fields and defaults:
92
+
93
+ - `model_dir`: absolute path to model directory
94
+ - `threads`: `3`
95
+ - `optimization_level`: `3`
96
+ - `model_name`: `nil`
97
+ - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
98
+ - `output_tensor`: `nil`
99
+ - `max_length`: `nil`
100
+
101
+ ## Runtime + Result Examples
102
+
103
+ Process-local reuse (recommended for Puma/web servers):
104
+
105
+ ```ruby
106
+ EMBEDDER = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
107
+
108
+ def embed_query(text)
109
+ EMBEDDER[text] # Array<Float>
110
+ end
14
111
  ```
15
112
 
16
113
  ## Model Directory
@@ -22,14 +119,28 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
22
119
  3. `onnx/model.onnx`
23
120
  4. `model.onnx`
24
121
 
122
+ Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
123
+
124
+ ## Execution Providers
125
+
126
+ Default execution provider is `xnnpack` on all platforms (including macOS arm64).
127
+
128
+ To opt in to CoreML explicitly:
129
+
130
+ ```bash
131
+ export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
132
+ ```
133
+
25
134
  ## Development
26
135
 
27
- Run commands inside `nix develop`.
136
+ Run commands inside `nix develop` via Make targets:
28
137
 
29
138
  ```bash
30
- bundle exec rake compile
31
- cargo test --manifest-path ext/gte/Cargo.toml --no-default-features
32
- bundle exec rspec
139
+ make setup
140
+ make compile
141
+ make test
142
+ make lint
143
+ make ci
33
144
  ```
34
145
 
35
146
  ## Benchmark
@@ -37,13 +148,14 @@ bundle exec rspec
37
148
  The repo includes two benchmark paths:
38
149
 
39
150
  ```bash
40
- bundle exec rake bench:pure_compare
41
- bundle exec rake bench:puma_compare
42
- bundle exec rake bench:matrix_sweep
151
+ make bench
152
+ nix develop -c bundle exec rake bench:pure_compare
153
+ nix develop -c bundle exec rake bench:matrix_sweep
154
+ nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
43
155
  ```
44
156
 
45
157
  For release tracking and regression detection, record a run entry in `RUNS.md`:
46
158
 
47
159
  ```bash
48
- bundle exec rake bench:record_run
160
+ make bench-record
49
161
  ```
data/Rakefile CHANGED
@@ -48,6 +48,14 @@ namespace :bench do
48
48
  )
49
49
  end
50
50
 
51
+ desc 'Run memory probe for single-instance vs duplicate-instance behavior'
52
+ task :memory_probe do
53
+ run_in_nix(
54
+ 'bundle', 'exec', 'ruby', 'bench/memory_probe.rb',
55
+ '--compare-pure'
56
+ )
57
+ end
58
+
51
59
  desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal/regression checks'
52
60
  task :record_run do
53
61
  run_in_nix(
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.3
1
+ 0.0.5
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.3"
3
+ version = "0.0.5"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -1,19 +1,15 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ExtractorMode, ModelConfig};
3
+ use crate::model_profile::{
4
+ has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
5
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
+ };
3
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
4
8
  use crate::session::{build_session, run_session};
5
9
  use crate::tokenizer::{Tokenized, Tokenizer};
6
10
  use ndarray::Array2;
7
11
  use ort::session::Session;
8
- use std::path::{Path, PathBuf};
9
-
10
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
11
- pub enum ModelFamily {
12
- E5Like,
13
- SiglipLike,
14
- ClipLike,
15
- Other,
16
- }
12
+ use std::path::Path;
17
13
 
18
14
  pub struct Embedder {
19
15
  tokenizer: Tokenizer,
@@ -41,23 +37,35 @@ impl Embedder {
41
37
  num_threads: usize,
42
38
  optimization_level: u8,
43
39
  model_name: Option<&str>,
40
+ output_tensor_override: Option<&str>,
41
+ max_length_override: Option<usize>,
44
42
  ) -> Result<Self> {
43
+ const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
44
+ "pooler_output",
45
+ "text_embeds",
46
+ "sentence_embedding",
47
+ "last_hidden_state",
48
+ ];
49
+
45
50
  let dir = dir.as_ref();
46
- let tokenizer_path = dir.join("tokenizer.json");
51
+ let tokenizer_path = resolve_tokenizer_path(dir)?;
47
52
  let model_path = match model_name.filter(|s| !s.is_empty()) {
48
53
  Some(name) => resolve_named_model(dir, name)?,
49
- None => resolve_model_path(dir)?,
54
+ None => resolve_default_text_model(dir)?,
50
55
  };
51
56
 
52
- if !tokenizer_path.exists() {
53
- return Err(GteError::Tokenizer(format!(
54
- "tokenizer.json not found in {}",
55
- dir.display()
56
- )));
57
- }
57
+ let max_length = if let Some(override_value) = max_length_override {
58
+ if override_value == 0 {
59
+ return Err(GteError::Inference(
60
+ "max_length override must be greater than 0".to_string(),
61
+ ));
62
+ }
63
+ override_value
64
+ } else {
65
+ read_max_length(dir)
66
+ };
58
67
 
59
- let max_length = read_max_length(dir);
60
- let temp_config = ModelConfig {
68
+ let session_config = ModelConfig {
61
69
  max_length,
62
70
  output_tensor: String::new(),
63
71
  mode: ExtractorMode::Raw,
@@ -66,13 +74,13 @@ impl Embedder {
66
74
  num_threads,
67
75
  optimization_level,
68
76
  };
69
- let session = build_session(&model_path, &temp_config)?;
77
+ let session = build_session(&model_path, &session_config)?;
70
78
 
71
- validate_supported_inputs(&session)?;
72
- let with_type_ids = session.inputs.iter().any(|i| i.name == "token_type_ids");
73
- let with_attention_mask = session.inputs.iter().any(|i| i.name == "attention_mask");
74
- let output_tensor = select_output_tensor(&session)?;
75
- let output_base = output_basename(output_tensor.as_str()).to_string();
79
+ validate_supported_text_inputs(&session, "text embedding")?;
80
+ let with_type_ids = has_input(&session, "token_type_ids");
81
+ let with_attention_mask = has_input(&session, "attention_mask");
82
+ let output_tensor =
83
+ select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
76
84
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
77
85
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
78
86
  return Err(GteError::Inference(
@@ -80,29 +88,16 @@ impl Embedder {
80
88
  ));
81
89
  }
82
90
 
83
- let tuned_num_threads = tune_num_threads(
84
- num_threads,
85
- with_attention_mask,
86
- with_type_ids,
87
- output_base.as_str(),
88
- );
89
-
90
91
  let config = ModelConfig {
91
92
  max_length,
92
93
  output_tensor,
93
94
  mode,
94
95
  with_type_ids,
95
96
  with_attention_mask,
96
- num_threads: tuned_num_threads,
97
+ num_threads,
97
98
  optimization_level,
98
99
  };
99
100
 
100
- let session = if tuned_num_threads != num_threads {
101
- build_session(&model_path, &config)?
102
- } else {
103
- session
104
- };
105
-
106
101
  let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
107
102
 
108
103
  Ok(Self {
@@ -124,235 +119,6 @@ impl Embedder {
124
119
  pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
125
120
  run_session(&self.session, tokenized, &self.config)
126
121
  }
127
-
128
- }
129
-
130
- fn tune_num_threads(
131
- requested: usize,
132
- with_attention_mask: bool,
133
- with_type_ids: bool,
134
- output_name: &str,
135
- ) -> usize {
136
- if requested > 0 {
137
- return requested;
138
- }
139
-
140
- let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
141
- let target_concurrency = puma_target_concurrency();
142
- let host_cores = host_parallelism();
143
- let budgeted_threads = (host_cores / target_concurrency).max(1);
144
-
145
- match family {
146
- // Puma-like workloads typically run many concurrent single-item requests where
147
- // one intra-op thread per request gives the best tail behavior.
148
- ModelFamily::E5Like | ModelFamily::ClipLike | ModelFamily::SiglipLike => {
149
- budgeted_threads.min(1)
150
- }
151
- ModelFamily::Other => 0,
152
- }
153
- }
154
-
155
- fn infer_model_family(
156
- with_attention_mask: bool,
157
- with_type_ids: bool,
158
- output_name: &str,
159
- ) -> ModelFamily {
160
- if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
161
- return ModelFamily::E5Like;
162
- }
163
- if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
164
- return ModelFamily::SiglipLike;
165
- }
166
- if output_name == "text_embeds" && !with_attention_mask {
167
- return ModelFamily::ClipLike;
168
- }
169
- ModelFamily::Other
170
- }
171
-
172
- fn puma_target_concurrency() -> usize {
173
- std::env::var("GTE_PUMA_CONCURRENCY")
174
- .ok()
175
- .and_then(|raw| raw.parse::<usize>().ok())
176
- .filter(|value| *value > 0)
177
- .unwrap_or(16)
178
- }
179
-
180
- fn host_parallelism() -> usize {
181
- std::thread::available_parallelism()
182
- .map(|n| n.get())
183
- .unwrap_or(1)
184
- }
185
-
186
- fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
187
- let candidates = [dir.join("onnx").join(name), dir.join(name)];
188
- for path in &candidates {
189
- if path.exists() {
190
- return Ok(path.clone());
191
- }
192
- }
193
- Err(GteError::Inference(format!(
194
- "model '{}' not found in {} (checked onnx/{0} and {0})",
195
- name,
196
- dir.display()
197
- )))
198
- }
199
-
200
- fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
201
- let candidates = [
202
- dir.join("onnx").join("text_model.onnx"),
203
- dir.join("text_model.onnx"),
204
- dir.join("onnx").join("model.onnx"),
205
- dir.join("model.onnx"),
206
- ];
207
- for path in &candidates {
208
- if path.exists() {
209
- return Ok(path.clone());
210
- }
211
- }
212
- Err(GteError::Inference(format!(
213
- "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
214
- dir.display()
215
- )))
216
- }
217
-
218
- const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
219
-
220
- fn validate_supported_inputs(session: &Session) -> Result<()> {
221
- let unsupported: Vec<String> = session
222
- .inputs
223
- .iter()
224
- .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
225
- .map(|i| i.name.clone())
226
- .collect();
227
-
228
- if unsupported.is_empty() {
229
- return Ok(());
230
- }
231
-
232
- let mut message = format!(
233
- "unsupported model inputs for text embedding API: {}",
234
- unsupported.join(", ")
235
- );
236
- if unsupported.iter().any(|n| n == "pixel_values") {
237
- message.push_str(
238
- ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
239
- );
240
- } else {
241
- message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
242
- }
243
- Err(GteError::Inference(message))
244
- }
245
-
246
- fn output_name_matches(name: &str, preferred: &str) -> bool {
247
- let lower = name.to_ascii_lowercase();
248
- lower == preferred || lower.ends_with(&format!("/{}", preferred))
249
- }
250
-
251
- fn select_output_tensor(session: &Session) -> Result<String> {
252
- const PREFERRED: [&str; 4] = [
253
- "text_embeds",
254
- "pooler_output",
255
- "sentence_embedding",
256
- "last_hidden_state",
257
- ];
258
-
259
- for preferred in PREFERRED {
260
- if let Some(output) = session
261
- .outputs
262
- .iter()
263
- .find(|o| output_name_matches(o.name.as_str(), preferred))
264
- {
265
- return Ok(output.name.clone());
266
- }
267
- }
268
-
269
- session
270
- .outputs
271
- .first()
272
- .map(|o| o.name.clone())
273
- .ok_or_else(|| GteError::Inference("model has no outputs".into()))
274
- }
275
-
276
- fn read_max_length(dir: &Path) -> usize {
277
- (|| -> Option<usize> {
278
- let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
279
- let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
280
- let v = json.get("model_max_length")?;
281
- let n = v
282
- .as_u64()
283
- .or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
284
- Some((n as usize).min(8192))
285
- })()
286
- .unwrap_or(512)
287
- }
288
-
289
- #[cfg(test)]
290
- mod tests {
291
- use super::{infer_model_family, tune_num_threads, ModelFamily};
292
-
293
- #[test]
294
- fn infer_model_family_recognizes_known_signatures() {
295
- assert_eq!(
296
- infer_model_family(true, true, "last_hidden_state"),
297
- ModelFamily::E5Like
298
- );
299
- assert_eq!(
300
- infer_model_family(true, false, "last_hidden_state"),
301
- ModelFamily::SiglipLike
302
- );
303
- assert_eq!(
304
- infer_model_family(false, false, "text_embeds"),
305
- ModelFamily::ClipLike
306
- );
307
- assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
308
- }
309
-
310
- #[test]
311
- fn tune_num_threads_respects_requested_value() {
312
- assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
313
- }
314
-
315
- #[test]
316
- fn tune_num_threads_returns_ort_default_for_other_family() {
317
- assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
318
- }
319
- }
320
-
321
- fn output_basename(name: &str) -> &str {
322
- name.rsplit('/').next().unwrap_or(name)
323
- }
324
-
325
- fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
326
- let output = session
327
- .outputs
328
- .iter()
329
- .find(|o| o.name == output_tensor)
330
- .ok_or_else(|| {
331
- GteError::Inference(format!(
332
- "output tensor '{}' not found in model outputs",
333
- output_tensor
334
- ))
335
- })?;
336
-
337
- let ndims = match &output.output_type {
338
- ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
339
- other => {
340
- return Err(GteError::Inference(format!(
341
- "output is not a tensor: {:?}",
342
- other
343
- )))
344
- }
345
- };
346
-
347
- match (output_basename(output_tensor), ndims) {
348
- ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
349
- (_, 2) => Ok(ExtractorMode::Raw),
350
- (_, 3) => Ok(ExtractorMode::MeanPool),
351
- (_, n) => Err(GteError::Inference(format!(
352
- "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
353
- n, output_tensor
354
- ))),
355
- }
356
122
  }
357
123
 
358
124
  pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
data/ext/gte/src/lib.rs CHANGED
@@ -1,7 +1,10 @@
1
1
  pub mod embedder;
2
2
  pub mod error;
3
3
  pub mod model_config;
4
+ pub mod model_profile;
5
+ pub mod pipeline;
4
6
  pub mod postprocess;
7
+ pub mod reranker;
5
8
  pub mod session;
6
9
  pub mod tokenizer;
7
10