gte 0.0.7 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 29659e3ab6072d858b1710a779c3d2e5981f7749782182d141ccd5e9790a1fbb
4
- data.tar.gz: c42d51cfa1a2ba6a2e83249e8a725c978b11c7ef80c6d69f09a64e884be42031
3
+ metadata.gz: 2c754b4675ee105e9a280cd9deafa00a81b9e02ee629131f3e908400006b6ae4
4
+ data.tar.gz: 40a0d3e04c3d2943ae50910164d644ecb763eac99a02044dc962cc141a0e13c5
5
5
  SHA512:
6
- metadata.gz: ff2c2b1450a6e82c07aacd2ec98437f03678d56eef9c5516f904021a54f59b2ba5c42b8f6af22b5c4b2dacea98615b99bc54d2c7cdc4e8fbccc1abc195fe9975
7
- data.tar.gz: 04ca056458d40e2ba7fabcdbcab415a087d54802fb3bd86748dc901c2cf0ecb44072fd1820a73e3dcaca097f165df3e70bab747b38340cd738876af5f0ea7645
6
+ metadata.gz: 16614e01e7a33a53339ba9fe7cf32fe7606041518a24177258d7a6e5550516e8cff741d0f0df02b7e5863fc763c02ae81b943dc4b18295701a4cafdec6627cb0
7
+ data.tar.gz: 348e1fd1d9f4c44214b5101ba339109b5ececfbef18b48b7c11324a64481f476d8da831cc5148d17a85c41b525ee753c296d4421a4fb2adda269a3f5fe38cda6
data/README.md CHANGED
@@ -33,8 +33,8 @@ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
33
  config.with(normalize: false)
34
34
  end
35
35
 
36
- full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
- config.with(threads: 0)
36
+ single_thread = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
+ config.with(threads: 1)
38
38
  end
39
39
 
40
40
  custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
@@ -50,7 +50,7 @@ end
50
50
  Config fields and defaults:
51
51
 
52
52
  - `model_dir`: absolute path to model directory
53
- - `threads`: `3` (set `0` for ONNX Runtime full-throttle threadpool)
53
+ - `threads`: `1` (default tuned for p95 latency; use `0` for ONNX Runtime auto-thread mode)
54
54
  - `optimization_level`: `3`
55
55
  - `model_name`: `nil`
56
56
  - `normalize`: `true` (L2 normalization at Ruby-facing API)
@@ -68,7 +68,7 @@ Low-level embedder setup (without model cache):
68
68
 
69
69
  ```ruby
70
70
  embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
71
- config.with(threads: 0, execution_providers: "cpu")
71
+ config.with(threads: 1, execution_providers: "cpu")
72
72
  end
73
73
  ```
74
74
 
@@ -78,7 +78,7 @@ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
78
78
 
79
79
  ```ruby
80
80
  reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
81
- config.with(sigmoid: true, threads: 0)
81
+ config.with(sigmoid: true, threads: 1)
82
82
  end
83
83
 
84
84
  query = "how to train a neural network?"
@@ -102,7 +102,7 @@ ranked = reranker.rerank(query: query, candidates: candidates)
102
102
  Reranker config fields and defaults:
103
103
 
104
104
  - `model_dir`: absolute path to model directory
105
- - `threads`: `3`
105
+ - `threads`: `1`
106
106
  - `optimization_level`: `3`
107
107
  - `model_name`: `nil`
108
108
  - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
@@ -174,7 +174,7 @@ make ci
174
174
 
175
175
  ## Benchmark
176
176
 
177
- The repo includes two benchmark paths:
177
+ The repo includes a shared multi-runtime benchmark harness:
178
178
 
179
179
  ```bash
180
180
  make bench
@@ -183,6 +183,11 @@ nix develop -c bundle exec rake bench:matrix_sweep
183
183
  nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
184
184
  ```
185
185
 
186
+ - `make bench`: Puma-like single-request comparison at concurrency `16`
187
+ - `rake bench:pure_compare`: batch amortization comparison
188
+ - `rake bench:matrix_sweep`: GTE provider/thread sweep using the shared result schema
189
+ - Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
190
+
186
191
  To run benchmark + append a `RUNS.md` entry + enforce goal checks:
187
192
 
188
193
  ```bash
@@ -191,5 +196,5 @@ make bench-record
191
196
 
192
197
  `bench/runs_ledger.rb check` is goal-focused by default:
193
198
 
194
- - Enforces goal metric (`response_time_p95` ratio threshold).
199
+ - Enforces the goal metric (`response_time_p95`) across every enabled competitor.
195
200
  - Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
data/Rakefile CHANGED
@@ -10,17 +10,52 @@ rescue LoadError
10
10
  end
11
11
 
12
12
  spec = Gem::Specification.load('gte.gemspec')
13
+ cross_target = ENV.fetch('RUBY_TARGET', nil)
13
14
 
14
- Rake::ExtensionTask.new('gte', spec) do |ext|
15
+ if cross_target == 'arm64-darwin'
16
+ # rb-sys-dock's darwin image can expose an unusable default LIBRARY_PATH.
17
+ # Force the compiler-rt darwin runtime directory so -lclang_rt.osx resolves.
18
+ ENV['LIBRARY_PATH'] = '/usr/lib/llvm-10/lib/clang/10.0.0/lib/darwin'
19
+ end
20
+
21
+ extension_task = Rake::ExtensionTask.new('gte', spec) do |ext|
15
22
  ext.lib_dir = 'lib/gte'
16
23
  ext.cross_compile = true
17
- ext.cross_platform = %w[x86_64-linux arm64-darwin]
24
+ # rb-sys-dock invokes `rake native:$RUBY_TARGET gem` without the `cross` task,
25
+ # so scope platforms during dock builds to avoid host-Ruby fallback copy tasks.
26
+ cross_platforms = if cross_target && !cross_target.empty?
27
+ [cross_target]
28
+ else
29
+ %w[x86_64-linux aarch64-linux arm64-darwin]
30
+ end
31
+ ext.cross_platform = cross_platforms
32
+ end
33
+
34
+ if cross_target && !cross_target.empty? && ENV['RUBY_CC_VERSION']
35
+ ruby_version = ENV['RUBY_CC_VERSION'].split(':').first
36
+ lib_binary_path = File.join(extension_task.lib_dir, File.basename(extension_task.binary(cross_target)))
37
+ copy_task = "copy:gte:#{cross_target}:#{ruby_version}"
38
+
39
+ if Rake::Task.task_defined?(lib_binary_path) && Rake::Task.task_defined?(copy_task)
40
+ Rake::Task[lib_binary_path].prerequisites.clear
41
+ Rake::Task[lib_binary_path].enhance([copy_task])
42
+ end
18
43
  end
19
44
 
20
45
  task default: %i[compile spec]
21
46
 
47
+ def bundler_env
48
+ root = File.expand_path(__dir__)
49
+ {
50
+ 'BUNDLE_DISABLE_SHARED_GEMS' => '1',
51
+ 'GEM_HOME' => File.join(root, '.bundle-gems'),
52
+ 'GEM_PATH' => File.join(root, '.bundle-gems'),
53
+ 'BUNDLE_PATH' => File.join(root, 'vendor/bundle')
54
+ }
55
+ end
56
+
22
57
  def run_in_nix(*command)
23
- sh('nix', 'develop', '-c', *command)
58
+ sh(bundler_env, 'nix', 'develop', '-c', *command)
24
59
  end
25
60
 
26
61
  namespace :bench do
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.7
1
+ 0.0.8
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.7"
3
+ version = "0.0.8"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -21,10 +21,10 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
21
21
  [dependencies]
22
22
  rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
23
23
  magnus = { version = "0.8", optional = true }
24
- ort = { version = "=2.0.0-rc.9", features = ["ndarray"] }
25
- ort-sys = "=2.0.0-rc.9"
24
+ ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
25
+ ort-sys = "=2.0.0-rc.12"
26
26
  tokenizers = "0.21.0"
27
- ndarray = "0.16.0"
27
+ ndarray = "0.17"
28
28
  half = "2"
29
29
  serde = { version = "1", features = ["derive"] }
30
30
  serde_json = "1"
@@ -5,15 +5,14 @@ use crate::model_profile::{
5
5
  resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
8
- use crate::session::{build_session, run_session};
8
+ use crate::session::{build_session, run_session, SessionPool};
9
9
  use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
10
10
  use ndarray::Array2;
11
- use ort::session::Session;
12
- use std::path::Path;
11
+ use std::path::{Path, PathBuf};
13
12
 
14
13
  pub struct Embedder {
15
14
  tokenizer: Tokenizer,
16
- session: Session,
15
+ pool: SessionPool,
17
16
  config: ModelConfig,
18
17
  }
19
18
 
@@ -30,12 +29,10 @@ impl Embedder {
30
29
  config.padding_mode,
31
30
  None,
32
31
  )?;
33
- let session = build_session(model_path, &config)?;
34
- Ok(Self {
35
- tokenizer,
36
- session,
37
- config,
38
- })
32
+ let model_path = model_path.as_ref().to_path_buf();
33
+ let session = build_session(&model_path, &config)?;
34
+ let pool = SessionPool::new(session, model_path, config.clone());
35
+ Ok(Self { tokenizer, pool, config })
39
36
  }
40
37
 
41
38
  pub fn from_dir<P: AsRef<Path>>(
@@ -53,7 +50,7 @@ impl Embedder {
53
50
 
54
51
  let dir = dir.as_ref();
55
52
  let tokenizer_path = resolve_tokenizer_path(dir)?;
56
- let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
53
+ let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
57
54
  Some(name) => resolve_named_model(dir, name)?,
58
55
  None => resolve_default_text_model(dir)?,
59
56
  };
@@ -117,11 +114,8 @@ impl Embedder {
117
114
  tokenizer_profile.fixed_padding_length,
118
115
  )?;
119
116
 
120
- Ok(Self {
121
- tokenizer,
122
- session,
123
- config,
124
- })
117
+ let pool = SessionPool::new(session, model_path, session_config);
118
+ Ok(Self { tokenizer, pool, config })
125
119
  }
126
120
 
127
121
  pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
@@ -134,7 +128,8 @@ impl Embedder {
134
128
  }
135
129
 
136
130
  pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
137
- run_session(&self.session, tokenized, &self.config)
131
+ let mut session = self.pool.acquire()?;
132
+ run_session(&mut session, tokenized, &self.config)
138
133
  }
139
134
  }
140
135
 
@@ -133,11 +133,9 @@ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Opt
133
133
  }
134
134
 
135
135
  pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
136
- let unsupported: Vec<String> = session
137
- .inputs
138
- .iter()
139
- .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
140
- .map(|i| i.name.clone())
136
+ let unsupported: Vec<String> = session.inputs().iter()
137
+ .filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
138
+ .map(|i| i.name().to_owned())
141
139
  .collect();
142
140
 
143
141
  if unsupported.is_empty() {
@@ -160,7 +158,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Res
160
158
  }
161
159
 
162
160
  pub fn has_input(session: &Session, name: &str) -> bool {
163
- session.inputs.iter().any(|input| input.name == name)
161
+ session.inputs().iter().any(|input| input.name() == name)
164
162
  }
165
163
 
166
164
  fn output_name_matches(name: &str, preferred: &str) -> bool {
@@ -175,16 +173,16 @@ pub fn select_output_tensor(
175
173
  ) -> Result<String> {
176
174
  if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
177
175
  if let Some(output) = session
178
- .outputs
176
+ .outputs()
179
177
  .iter()
180
- .find(|o| output_name_matches(o.name.as_str(), requested_name))
178
+ .find(|o| output_name_matches(o.name(), requested_name))
181
179
  {
182
- return Ok(output.name.clone());
180
+ return Ok(output.name().to_owned());
183
181
  }
184
182
  let available = session
185
- .outputs
183
+ .outputs()
186
184
  .iter()
187
- .map(|o| o.name.as_str())
185
+ .map(|o| o.name())
188
186
  .collect::<Vec<_>>()
189
187
  .join(", ");
190
188
  return Err(GteError::Inference(format!(
@@ -195,18 +193,18 @@ pub fn select_output_tensor(
195
193
 
196
194
  for preferred in preferred_outputs {
197
195
  if let Some(output) = session
198
- .outputs
196
+ .outputs()
199
197
  .iter()
200
- .find(|o| output_name_matches(o.name.as_str(), preferred))
198
+ .find(|o| output_name_matches(o.name(), preferred))
201
199
  {
202
- return Ok(output.name.clone());
200
+ return Ok(output.name().to_owned());
203
201
  }
204
202
  }
205
203
 
206
204
  session
207
- .outputs
205
+ .outputs()
208
206
  .first()
209
- .map(|o| o.name.clone())
207
+ .map(|o| o.name().to_owned())
210
208
  .ok_or_else(|| GteError::Inference("model has no outputs".into()))
211
209
  }
212
210
 
@@ -216,9 +214,9 @@ fn output_basename(name: &str) -> &str {
216
214
 
217
215
  pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
218
216
  let output = session
219
- .outputs
217
+ .outputs()
220
218
  .iter()
221
- .find(|o| o.name == output_tensor)
219
+ .find(|o| o.name() == output_tensor)
222
220
  .ok_or_else(|| {
223
221
  GteError::Inference(format!(
224
222
  "output tensor '{}' not found in model outputs",
@@ -226,8 +224,8 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
226
224
  ))
227
225
  })?;
228
226
 
229
- let ndims = match &output.output_type {
230
- ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
227
+ let ndims = match output.dtype() {
228
+ ort::value::ValueType::Tensor { shape, .. } => shape.len(),
231
229
  other => {
232
230
  return Err(GteError::Inference(format!(
233
231
  "output is not a tensor: {:?}",
@@ -1,8 +1,8 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::tokenizer::Tokenized;
3
- use ndarray::ArrayView2;
3
+ use ndarray::{ArrayView2, ArrayViewD};
4
4
  use ort::session::SessionInputValue;
5
- use ort::value::Value;
5
+ use ort::value::TensorRef;
6
6
 
7
7
  pub struct InputTensors<'a> {
8
8
  pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
@@ -23,13 +23,13 @@ impl<'a> InputTensors<'a> {
23
23
  let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
24
24
  inputs.push((
25
25
  "input_ids",
26
- SessionInputValue::from(Value::from_array(input_ids_view)?),
26
+ SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
27
27
  ));
28
28
 
29
29
  if with_attention_mask {
30
30
  inputs.push((
31
31
  "attention_mask",
32
- SessionInputValue::from(Value::from_array(attention_mask)?),
32
+ SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
33
33
  ));
34
34
  }
35
35
 
@@ -38,7 +38,7 @@ impl<'a> InputTensors<'a> {
38
38
  ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
39
39
  inputs.push((
40
40
  "token_type_ids",
41
- SessionInputValue::from(Value::from_array(type_ids_view)?),
41
+ SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
42
42
  ));
43
43
  }
44
44
 
@@ -50,11 +50,14 @@ impl<'a> InputTensors<'a> {
50
50
  }
51
51
 
52
52
  pub fn extract_output_tensor<'a>(
53
- outputs: &'a ort::session::SessionOutputs<'a, 'a>,
53
+ outputs: &'a ort::session::SessionOutputs<'_>,
54
54
  output_name: &str,
55
- ) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
55
+ ) -> Result<ArrayViewD<'a, f32>> {
56
56
  let tensor_value = outputs.get(output_name).ok_or_else(|| {
57
- GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
57
+ GteError::Inference(format!(
58
+ "output tensor '{}' not found in model outputs",
59
+ output_name
60
+ ))
58
61
  })?;
59
- Ok(tensor_value.try_extract_tensor::<f32>()?.into())
62
+ Ok(tensor_value.try_extract_array::<f32>()?)
60
63
  }
@@ -6,10 +6,9 @@ use crate::model_profile::{
6
6
  };
7
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
8
8
  use crate::postprocess::sigmoid_scores;
9
- use crate::session::build_session;
9
+ use crate::session::{build_session, SessionPool};
10
10
  use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
11
- use ort::session::Session;
12
- use std::path::Path;
11
+ use std::path::{Path, PathBuf};
13
12
 
14
13
  #[derive(Debug, Clone)]
15
14
  struct RerankerConfig {
@@ -22,7 +21,7 @@ struct RerankerConfig {
22
21
 
23
22
  pub struct Reranker {
24
23
  tokenizer: Tokenizer,
25
- session: Session,
24
+ pool: SessionPool,
26
25
  config: RerankerConfig,
27
26
  }
28
27
 
@@ -35,7 +34,7 @@ impl Reranker {
35
34
  ) -> Result<Self> {
36
35
  let dir = dir.as_ref();
37
36
  let tokenizer_path = resolve_tokenizer_path(dir)?;
38
- let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
37
+ let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
39
38
  Some(name) => resolve_named_model(dir, name)?,
40
39
  None => resolve_default_text_model(dir)?,
41
40
  };
@@ -88,11 +87,8 @@ impl Reranker {
88
87
  tokenizer_profile.fixed_padding_length,
89
88
  )?;
90
89
 
91
- Ok(Self {
92
- tokenizer,
93
- session,
94
- config,
95
- })
90
+ let pool = SessionPool::new(session, model_path, probe_config);
91
+ Ok(Self { tokenizer, pool, config })
96
92
  }
97
93
 
98
94
  pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
@@ -111,7 +107,8 @@ impl Reranker {
111
107
  apply_sigmoid: bool,
112
108
  ) -> Result<Vec<f32>> {
113
109
  let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
114
- let outputs = self.session.run(input_tensors.inputs)?;
110
+ let mut session = self.pool.acquire()?;
111
+ let outputs = session.run(input_tensors.inputs).map_err(|e| GteError::Ort(e.to_string()))?;
115
112
  let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
116
113
 
117
114
  let mut scores = match array.ndim() {
@@ -28,21 +28,24 @@ pub struct RbTensor {
28
28
  data: Vec<f32>,
29
29
  }
30
30
 
31
+ // ---------------------------------------------------------------------------
32
+ // GVL-release helpers
33
+ // ---------------------------------------------------------------------------
34
+
31
35
  struct InferArgs {
32
36
  embedder: *const Embedder,
33
37
  texts: *const Vec<String>,
34
38
  normalize: bool,
35
- result: Option<Result<ndarray::Array2<f32>, GteError>>,
39
+ result: Option<crate::error::Result<ndarray::Array2<f32>>>,
36
40
  }
37
41
 
38
42
  unsafe impl Send for InferArgs {}
39
43
 
40
44
  struct ScoreArgs {
41
45
  reranker: *const Reranker,
42
- query: *const String,
43
- candidates: *const Vec<String>,
46
+ pairs: *const Vec<(String, String)>,
44
47
  apply_sigmoid: bool,
45
- result: Option<Result<Vec<f32>, GteError>>,
48
+ result: Option<crate::error::Result<Vec<f32>>>,
46
49
  }
47
50
 
48
51
  unsafe impl Send for ScoreArgs {}
@@ -57,6 +60,38 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
57
60
  }
58
61
  }
59
62
 
63
+ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
64
+ let args = &mut *(ptr as *mut InferArgs);
65
+ let run_result = catch_unwind(AssertUnwindSafe(|| {
66
+ let tokenized = (*args.embedder).tokenize(&*args.texts)?;
67
+ let embeddings = (*args.embedder).run(&tokenized)?;
68
+ if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
69
+ }));
70
+ args.result = Some(match run_result {
71
+ Ok(result) => result,
72
+ Err(payload) => Err(GteError::Inference(format!(
73
+ "panic during inference: {}",
74
+ panic_payload_to_string(payload),
75
+ ))),
76
+ });
77
+ std::ptr::null_mut()
78
+ }
79
+
80
+ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
81
+ let args = &mut *(ptr as *mut ScoreArgs);
82
+ let run_result = catch_unwind(AssertUnwindSafe(|| {
83
+ (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)
84
+ }));
85
+ args.result = Some(match run_result {
86
+ Ok(result) => result,
87
+ Err(payload) => Err(GteError::Inference(format!(
88
+ "panic during reranking: {}",
89
+ panic_payload_to_string(payload),
90
+ ))),
91
+ });
92
+ std::ptr::null_mut()
93
+ }
94
+
60
95
  fn infer_without_gvl(
61
96
  embedder: &Arc<Embedder>,
62
97
  normalize: bool,
@@ -87,15 +122,13 @@ fn infer_without_gvl(
87
122
 
88
123
  fn score_without_gvl(
89
124
  reranker: &Arc<Reranker>,
90
- query: String,
91
- candidates: Vec<String>,
125
+ pairs: Vec<(String, String)>,
92
126
  apply_sigmoid: bool,
93
127
  ) -> Result<Vec<f32>, Error> {
94
128
  let scores = unsafe {
95
129
  let mut args = ScoreArgs {
96
130
  reranker: Arc::as_ptr(reranker),
97
- query: &query as *const String,
98
- candidates: &candidates as *const Vec<String>,
131
+ pairs: &pairs as *const Vec<(String, String)>,
99
132
  apply_sigmoid,
100
133
  result: None,
101
134
  };
@@ -115,41 +148,7 @@ fn score_without_gvl(
115
148
  Ok(scores)
116
149
  }
117
150
 
118
- unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
119
- let args = &mut *(ptr as *mut InferArgs);
120
- let run_result = catch_unwind(AssertUnwindSafe(|| {
121
- let tokenized = (*args.embedder).tokenize(&*args.texts)?;
122
- let embeddings = (*args.embedder).run(&tokenized)?;
123
- if args.normalize {
124
- Ok(normalize_l2(embeddings))
125
- } else {
126
- Ok(embeddings)
127
- }
128
- }));
129
- args.result = Some(match run_result {
130
- Ok(result) => result,
131
- Err(payload) => Err(GteError::Inference(format!(
132
- "panic during inference: {}",
133
- panic_payload_to_string(payload),
134
- ))),
135
- });
136
- std::ptr::null_mut()
137
- }
138
-
139
- unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
140
- let args = &mut *(ptr as *mut ScoreArgs);
141
- let run_result = catch_unwind(AssertUnwindSafe(|| {
142
- (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
143
- }));
144
- args.result = Some(match run_result {
145
- Ok(result) => result,
146
- Err(payload) => Err(GteError::Inference(format!(
147
- "panic during reranking: {}",
148
- panic_payload_to_string(payload),
149
- ))),
150
- });
151
- std::ptr::null_mut()
152
- }
151
+ // ---------------------------------------------------------------------------
153
152
 
154
153
  fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
155
154
  let rows = embeddings.nrows();
@@ -177,31 +176,11 @@ impl RbEmbedder {
177
176
  padding: String,
178
177
  execution_providers: String,
179
178
  ) -> Result<Self, Error> {
180
- let name = if model_name.is_empty() {
181
- None
182
- } else {
183
- Some(model_name.as_str())
184
- };
185
- let output_override = if output_tensor.is_empty() {
186
- None
187
- } else {
188
- Some(output_tensor.as_str())
189
- };
190
- let max_length_override = if max_length == 0 {
191
- None
192
- } else {
193
- Some(max_length)
194
- };
195
- let execution_providers_override = if execution_providers.is_empty() {
196
- None
197
- } else {
198
- Some(execution_providers.as_str())
199
- };
200
- let padding_override = if padding.is_empty() {
201
- None
202
- } else {
203
- Some(padding.as_str())
204
- };
179
+ let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
180
+ let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
181
+ let max_length_override = if max_length == 0 { None } else { Some(max_length) };
182
+ let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
183
+ let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
205
184
  let overrides = ModelLoadOverrides {
206
185
  model_name: name,
207
186
  output_tensor: output_override,
@@ -209,17 +188,9 @@ impl RbEmbedder {
209
188
  padding: padding_override,
210
189
  execution_providers: execution_providers_override,
211
190
  };
212
- let embedder = Embedder::from_dir(
213
- &dir_path,
214
- num_threads,
215
- optimization_level,
216
- overrides,
217
- )
218
- .map_err(magnus::Error::from)?;
219
- Ok(RbEmbedder {
220
- inner: Arc::new(embedder),
221
- normalize,
222
- })
191
+ let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, overrides)
192
+ .map_err(magnus::Error::from)?;
193
+ Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
223
194
  }
224
195
 
225
196
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
@@ -247,31 +218,11 @@ impl RbReranker {
247
218
  padding: String,
248
219
  execution_providers: String,
249
220
  ) -> Result<Self, Error> {
250
- let name = if model_name.is_empty() {
251
- None
252
- } else {
253
- Some(model_name.as_str())
254
- };
255
- let output_override = if output_tensor.is_empty() {
256
- None
257
- } else {
258
- Some(output_tensor.as_str())
259
- };
260
- let max_length_override = if max_length == 0 {
261
- None
262
- } else {
263
- Some(max_length)
264
- };
265
- let execution_providers_override = if execution_providers.is_empty() {
266
- None
267
- } else {
268
- Some(execution_providers.as_str())
269
- };
270
- let padding_override = if padding.is_empty() {
271
- None
272
- } else {
273
- Some(padding.as_str())
274
- };
221
+ let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
222
+ let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
223
+ let max_length_override = if max_length == 0 { None } else { Some(max_length) };
224
+ let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
225
+ let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
275
226
  let overrides = ModelLoadOverrides {
276
227
  model_name: name,
277
228
  output_tensor: output_override,
@@ -279,17 +230,9 @@ impl RbReranker {
279
230
  padding: padding_override,
280
231
  execution_providers: execution_providers_override,
281
232
  };
282
- let reranker = Reranker::from_dir(
283
- &dir_path,
284
- num_threads,
285
- optimization_level,
286
- overrides,
287
- )
288
- .map_err(magnus::Error::from)?;
289
- Ok(RbReranker {
290
- inner: Arc::new(reranker),
291
- sigmoid,
292
- })
233
+ let reranker = Reranker::from_dir(&dir_path, num_threads, optimization_level, overrides)
234
+ .map_err(magnus::Error::from)?;
235
+ Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
293
236
  }
294
237
 
295
238
  pub fn rb_score(
@@ -299,8 +242,8 @@ impl RbReranker {
299
242
  candidates: RArray,
300
243
  ) -> Result<RArray, Error> {
301
244
  let candidates: Vec<String> = candidates.to_vec()?;
302
- let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
303
-
245
+ let pairs: Vec<(String, String)> = candidates.into_iter().map(|c| (query.clone(), c)).collect();
246
+ let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
304
247
  let out = ruby.ary_new_capa(scores.len());
305
248
  for score in scores {
306
249
  out.push(score)?;
@@ -336,7 +279,6 @@ impl RbTensor {
336
279
  index, rb_self.rows
337
280
  ))));
338
281
  }
339
-
340
282
  let start = index * rb_self.cols;
341
283
  let end = start + rb_self.cols;
342
284
  let out = ruby.ary_new_capa(rb_self.cols);
@@ -361,7 +303,6 @@ impl RbTensor {
361
303
  index, rb_self.rows
362
304
  ))));
363
305
  }
364
-
365
306
  let start = index * rb_self.cols;
366
307
  let end = start + rb_self.cols;
367
308
  let bytes = unsafe {
@@ -3,12 +3,14 @@ use crate::model_config::{ExtractorMode, ModelConfig};
3
3
  use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
- use ndarray::{Array2, Ix2};
6
+ use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
7
  use ort::execution_providers::{
8
8
  CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
9
  };
10
10
  use ort::session::Session;
11
- use std::path::Path;
11
+ use std::path::{Path, PathBuf};
12
+ use std::sync::atomic::{AtomicUsize, Ordering};
13
+ use std::sync::{Condvar, Mutex};
12
14
 
13
15
  pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
14
16
  let opt_level = match config.optimization_level {
@@ -18,22 +20,176 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
18
20
  _ => ort::session::builder::GraphOptimizationLevel::Level3,
19
21
  };
20
22
 
21
- let mut builder = Session::builder()?
22
- .with_optimization_level(opt_level)?
23
- .with_memory_pattern(true)?;
23
+ fn ort_err(e: impl std::fmt::Display) -> GteError {
24
+ GteError::Ort(e.to_string())
25
+ }
26
+
27
+ let mut builder = Session::builder()
28
+ .map_err(ort_err)?
29
+ .with_optimization_level(opt_level)
30
+ .map_err(ort_err)?
31
+ .with_memory_pattern(true)
32
+ .map_err(ort_err)?;
24
33
 
25
34
  let providers = preferred_execution_providers(config.execution_providers.as_deref());
26
35
  if !providers.is_empty() {
27
- builder = builder.with_execution_providers(providers)?;
36
+ builder = builder
37
+ .with_execution_providers(providers)
38
+ .map_err(ort_err)?;
28
39
  }
29
40
 
30
41
  if config.num_threads > 0 {
31
- builder = builder.with_intra_threads(config.num_threads)?;
42
+ builder = builder
43
+ .with_intra_threads(config.num_threads)
44
+ .map_err(ort_err)?;
45
+ builder = builder
46
+ .with_inter_threads(config.num_threads)
47
+ .map_err(ort_err)?;
48
+ }
49
+
50
+ builder.commit_from_file(model_path).map_err(ort_err)
51
+ }
52
+
53
+ // ---------------------------------------------------------------------------
54
+ // Session pool
55
+ // ---------------------------------------------------------------------------
56
+
57
+ const AUTO_THREAD_POOL_CAP: usize = 6;
58
+
59
+ /// Keep enough sessions to cover the configured thread budget without
60
+ /// oversubscribing CPU parallelism. In ORT auto-thread mode (`num_threads == 0`)
61
+ /// we still keep a modest pool because request-level concurrency benefits from
62
+ /// more than one session even when ORT manages thread counts internally.
63
+ fn pool_capacity(num_threads: usize) -> usize {
64
+ let available_parallelism = std::thread::available_parallelism()
65
+ .map(|n| n.get())
66
+ .unwrap_or(1);
67
+ pool_capacity_with_parallelism(num_threads, available_parallelism)
68
+ }
69
+
70
+ fn pool_capacity_with_parallelism(num_threads: usize, available_parallelism: usize) -> usize {
71
+ if available_parallelism == 0 {
72
+ return 1;
73
+ }
74
+
75
+ if num_threads == 0 {
76
+ return available_parallelism.clamp(1, AUTO_THREAD_POOL_CAP);
77
+ }
78
+
79
+ available_parallelism.div_ceil(num_threads).max(1)
80
+ }
81
+
82
+ pub struct SessionPool {
83
+ sessions: Mutex<Vec<Session>>,
84
+ available: Condvar,
85
+ created: AtomicUsize,
86
+ capacity: usize,
87
+ model_path: PathBuf,
88
+ build_config: ModelConfig,
89
+ }
90
+
91
+ impl SessionPool {
92
+ pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
93
+ let capacity = pool_capacity(build_config.num_threads);
94
+ Self {
95
+ sessions: Mutex::new(vec![initial]),
96
+ available: Condvar::new(),
97
+ created: AtomicUsize::new(1),
98
+ capacity,
99
+ model_path,
100
+ build_config,
101
+ }
102
+ }
103
+
104
+ pub fn acquire(&self) -> Result<PooledSession<'_>> {
105
+ if let Some(session) = self.take_available_session() {
106
+ return Ok(PooledSession {
107
+ pool: self,
108
+ session: Some(session),
109
+ });
110
+ }
111
+
112
+ if let Some(session) = self.try_grow()? {
113
+ return Ok(PooledSession {
114
+ pool: self,
115
+ session: Some(session),
116
+ });
117
+ }
118
+
119
+ let session = self.wait_for_session();
120
+ Ok(PooledSession {
121
+ pool: self,
122
+ session: Some(session),
123
+ })
124
+ }
125
+
126
+ fn release(&self, session: Session) {
127
+ self.sessions.lock().unwrap().push(session);
128
+ self.available.notify_one();
129
+ }
130
+
131
+ fn take_available_session(&self) -> Option<Session> {
132
+ self.sessions.lock().unwrap().pop()
133
+ }
134
+
135
+ fn try_grow(&self) -> Result<Option<Session>> {
136
+ let grew = self
137
+ .created
138
+ .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
139
+ (count < self.capacity).then_some(count + 1)
140
+ });
141
+ if grew.is_err() {
142
+ return Ok(None);
143
+ }
144
+
145
+ match build_session(&self.model_path, &self.build_config) {
146
+ Ok(session) => Ok(Some(session)),
147
+ Err(error) => {
148
+ self.created.fetch_sub(1, Ordering::AcqRel);
149
+ Err(error)
150
+ }
151
+ }
152
+ }
153
+
154
+ fn wait_for_session(&self) -> Session {
155
+ let mut lock = self.sessions.lock().unwrap();
156
+ loop {
157
+ if let Some(session) = lock.pop() {
158
+ return session;
159
+ }
160
+ lock = self.available.wait(lock).unwrap();
161
+ }
162
+ }
163
+ }
164
+
165
+ pub struct PooledSession<'a> {
166
+ pool: &'a SessionPool,
167
+ session: Option<Session>,
168
+ }
169
+
170
+ impl std::ops::Deref for PooledSession<'_> {
171
+ type Target = Session;
172
+ fn deref(&self) -> &Session {
173
+ self.session.as_ref().unwrap()
32
174
  }
175
+ }
33
176
 
34
- Ok(builder.commit_from_file(model_path)?)
177
+ impl std::ops::DerefMut for PooledSession<'_> {
178
+ fn deref_mut(&mut self) -> &mut Session {
179
+ self.session.as_mut().unwrap()
180
+ }
35
181
  }
36
182
 
183
+ impl Drop for PooledSession<'_> {
184
+ fn drop(&mut self) {
185
+ if let Some(s) = self.session.take() {
186
+ self.pool.release(s);
187
+ }
188
+ }
189
+ }
190
+
191
+ // ---------------------------------------------------------------------------
192
+
37
193
  fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
38
194
  let order = resolve_provider_order(order_override);
39
195
 
@@ -55,7 +211,10 @@ fn resolve_provider_order(order_override: Option<&str>) -> String {
55
211
  resolve_provider_order_with_env(order_override, env_order.as_deref())
56
212
  }
57
213
 
58
- fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
214
+ fn resolve_provider_order_with_env(
215
+ order_override: Option<&str>,
216
+ env_order: Option<&str>,
217
+ ) -> String {
59
218
  order_override
60
219
  .or(env_order)
61
220
  .unwrap_or("cpu")
@@ -75,14 +234,24 @@ fn parse_provider_registrations(order: &str) -> Vec<&str> {
75
234
  }
76
235
 
77
236
  pub fn run_session(
78
- session: &Session,
237
+ session: &mut Session,
79
238
  tokenized: &Tokenized,
80
239
  config: &ModelConfig,
81
240
  ) -> Result<Array2<f32>> {
82
241
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
83
- let outputs = session.run(input_tensors.inputs)?;
242
+ let outputs = session
243
+ .run(input_tensors.inputs)
244
+ .map_err(|e| GteError::Ort(e.to_string()))?;
84
245
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
85
246
 
247
+ extract_embeddings(array, input_tensors.attention_mask, config)
248
+ }
249
+
250
+ fn extract_embeddings(
251
+ array: ArrayViewD<'_, f32>,
252
+ attention_mask: ArrayView2<'_, i64>,
253
+ config: &ModelConfig,
254
+ ) -> Result<Array2<f32>> {
86
255
  match config.mode {
87
256
  ExtractorMode::Token(idx) => {
88
257
  let shape = array.shape();
@@ -102,15 +271,43 @@ pub fn run_session(
102
271
  ndim
103
272
  ))
104
273
  })?;
105
- mean_pool(hidden_states.view(), input_tensors.attention_mask)
274
+ mean_pool(hidden_states, attention_mask)
106
275
  }
107
- ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
276
+ ExtractorMode::Raw => array
277
+ .into_dimensionality::<Ix2>()
278
+ .map(|view| view.to_owned())
279
+ .map_err(|e| GteError::Shape(e.to_string())),
108
280
  }
109
281
  }
110
282
 
111
283
  #[cfg(test)]
112
284
  mod tests {
113
- use super::{parse_provider_registrations, resolve_provider_order_with_env};
285
+ use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
286
+ use ndarray::{array, ArrayView2};
287
+
288
+ use super::{
289
+ extract_embeddings, parse_provider_registrations, pool_capacity_with_parallelism,
290
+ resolve_provider_order_with_env,
291
+ };
292
+
293
+ fn test_config(mode: ExtractorMode) -> ModelConfig {
294
+ ModelConfig {
295
+ max_length: 8,
296
+ padding_mode: PaddingMode::BatchLongest,
297
+ output_tensor: "output".to_string(),
298
+ mode,
299
+ with_type_ids: false,
300
+ with_attention_mask: true,
301
+ num_threads: 1,
302
+ optimization_level: 3,
303
+ execution_providers: None,
304
+ }
305
+ }
306
+
307
+ fn empty_attention_mask() -> ArrayView2<'static, i64> {
308
+ static EMPTY: [i64; 0] = [];
309
+ ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
310
+ }
114
311
 
115
312
  #[test]
116
313
  fn parse_provider_registrations_keeps_supported_order() {
@@ -142,7 +339,74 @@ mod tests {
142
339
 
143
340
  #[test]
144
341
  fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
145
- assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
342
+ assert_eq!(
343
+ resolve_provider_order_with_env(None, Some("coreml")),
344
+ "coreml"
345
+ );
146
346
  assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
147
347
  }
348
+
349
+ #[test]
350
+ fn pool_capacity_uses_bounded_parallel_pool_for_auto_thread_mode() {
351
+ assert_eq!(pool_capacity_with_parallelism(0, 1), 1);
352
+ assert_eq!(pool_capacity_with_parallelism(0, 4), 4);
353
+ assert_eq!(pool_capacity_with_parallelism(0, 8), 6);
354
+ }
355
+
356
+ #[test]
357
+ fn pool_capacity_scales_with_available_parallelism() {
358
+ assert_eq!(pool_capacity_with_parallelism(1, 1), 1);
359
+ assert_eq!(pool_capacity_with_parallelism(1, 8), 8);
360
+ assert_eq!(pool_capacity_with_parallelism(2, 8), 4);
361
+ assert_eq!(pool_capacity_with_parallelism(3, 8), 3);
362
+ assert_eq!(pool_capacity_with_parallelism(8, 4), 1);
363
+ }
364
+
365
+ #[test]
366
+ fn extract_embeddings_raw_copies_only_final_matrix() {
367
+ let output = array![[1.0f32, 2.0], [3.0, 4.0]];
368
+ let extracted = extract_embeddings(
369
+ output.view().into_dyn(),
370
+ empty_attention_mask(),
371
+ &test_config(ExtractorMode::Raw),
372
+ )
373
+ .unwrap();
374
+
375
+ assert_eq!(extracted, output);
376
+ }
377
+
378
+ #[test]
379
+ fn extract_embeddings_token_selects_without_copying_full_sequence() {
380
+ let output = array![
381
+ [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
382
+ [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
383
+ ];
384
+ let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
385
+ let extracted = extract_embeddings(
386
+ output.view().into_dyn(),
387
+ empty_attention_mask(),
388
+ &test_config(ExtractorMode::Token(1)),
389
+ )
390
+ .unwrap();
391
+
392
+ assert_eq!(extracted, expected);
393
+ }
394
+
395
+ #[test]
396
+ fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
397
+ let output = array![
398
+ [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
399
+ [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
400
+ ];
401
+ let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
402
+ let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
403
+ let extracted = extract_embeddings(
404
+ output.view().into_dyn(),
405
+ attention_mask.view(),
406
+ &test_config(ExtractorMode::MeanPool),
407
+ )
408
+ .unwrap();
409
+
410
+ assert_eq!(extracted, expected);
411
+ }
148
412
  }
data/lib/gte/embedder.rb CHANGED
@@ -2,6 +2,9 @@
2
2
 
3
3
  module GTE
4
4
  class Embedder
5
+ DEFAULT_THREADS = 1
6
+ DEFAULT_OPTIMIZATION_LEVEL = 3
7
+
5
8
  class << self
6
9
  def config(model_dir)
7
10
  cfg = default_config(model_dir)
@@ -23,13 +26,11 @@ module GTE
23
26
  )
24
27
  end
25
28
 
26
- private
27
-
28
29
  def default_config(model_dir)
29
30
  Config::Text.new(
30
31
  model_dir: File.expand_path(model_dir),
31
- threads: 3,
32
- optimization_level: 3,
32
+ threads: DEFAULT_THREADS,
33
+ optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
33
34
  model_name: nil,
34
35
  normalize: true,
35
36
  output_tensor: nil,
data/lib/gte/reranker.rb CHANGED
@@ -19,7 +19,7 @@ module GTE
19
19
  def default_config(model_dir)
20
20
  Config::Reranker.new(
21
21
  model_dir: File.expand_path(model_dir),
22
- threads: 3,
22
+ threads: 1,
23
23
  optimization_level: 3,
24
24
  model_name: nil,
25
25
  sigmoid: false,
data/lib/gte.rb CHANGED
@@ -19,17 +19,7 @@ module GTE
19
19
 
20
20
  class << self
21
21
  def config(model_dir)
22
- cfg = Config::Text.new(
23
- model_dir: File.expand_path(model_dir),
24
- threads: 3,
25
- optimization_level: 3,
26
- model_name: nil,
27
- normalize: true,
28
- output_tensor: nil,
29
- max_length: nil,
30
- padding: nil,
31
- execution_providers: nil
32
- )
22
+ cfg = Embedder.default_config(model_dir)
33
23
 
34
24
  cfg = yield(cfg) if block_given?
35
25
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.7
4
+ version: 0.0.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2026-04-16 00:00:00.000000000 Z
11
+ date: 2026-04-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rake
@@ -42,16 +42,16 @@ dependencies:
42
42
  name: rb_sys
43
43
  requirement: !ruby/object:Gem::Requirement
44
44
  requirements:
45
- - - ">="
45
+ - - '='
46
46
  - !ruby/object:Gem::Version
47
- version: '0'
47
+ version: 0.9.126
48
48
  type: :runtime
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
- - - ">="
52
+ - - '='
53
53
  - !ruby/object:Gem::Version
54
- version: '0'
54
+ version: 0.9.126
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: rspec
57
57
  requirement: !ruby/object:Gem::Requirement