gte 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 566cb32a193255c0cf3d087a1a907cdbf0a96292ee58d4676fd18745d55ec1b2
4
- data.tar.gz: 56479bb218282bf189ace46852a42766d4c2211230be635b6a79414cf3eb82c8
3
+ metadata.gz: 46dc6316dd04fd0a15bd0f0be71ecb1dca4f177612da7e2f7992a3030aa51e88
4
+ data.tar.gz: d05a7dc0230a9b2c670cd3001afad0724ba8543ec9de16f82d5d1abb4df73853
5
5
  SHA512:
6
- metadata.gz: f873b83b16e4cf2685f26b84d26f4c3b8abd90a9d32d5e60046d1aad577d409704838d3da1329f741535fef1f1a90f1edeeeb6d4fbab559569117869ae42677e
7
- data.tar.gz: d5d5b49b8f51cbf3222409b941fa39b3f00074bf34aafdfaf74d4a3fd37ab99ec8a90bba9be641efb42176b182fe191925042ae09d324c851324b27bd62031ce
6
+ metadata.gz: acdc81039070f6307548ea7438aeea69df4d9d5d6884f2046be1c62424d4529fa7ca10c7de7b6c60ca150f388dcda88936ec5a04bb34f67e022bf96a78682b73
7
+ data.tar.gz: fc4052f0c7d4b99c5c9bfa9d04f44c9cecd52e32b496e176dc8a72305b856683e5081a40146b7cbc9cd9352ae7e42a173360f7165711994f594802dbe03022ee
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.1
1
+ 0.0.2
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.1"
3
+ version = "0.0.2"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -40,10 +40,14 @@ impl Embedder {
40
40
  dir: P,
41
41
  num_threads: usize,
42
42
  optimization_level: u8,
43
+ model_name: Option<&str>,
43
44
  ) -> Result<Self> {
44
45
  let dir = dir.as_ref();
45
46
  let tokenizer_path = dir.join("tokenizer.json");
46
- let model_path = resolve_model_path(dir)?;
47
+ let model_path = match model_name.filter(|s| !s.is_empty()) {
48
+ Some(name) => resolve_named_model(dir, name)?,
49
+ None => resolve_model_path(dir)?,
50
+ };
47
51
 
48
52
  if !tokenizer_path.exists() {
49
53
  return Err(GteError::Tokenizer(format!(
@@ -179,6 +183,20 @@ fn host_parallelism() -> usize {
179
183
  .unwrap_or(1)
180
184
  }
181
185
 
186
+ fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
187
+ let candidates = [dir.join("onnx").join(name), dir.join(name)];
188
+ for path in &candidates {
189
+ if path.exists() {
190
+ return Ok(path.clone());
191
+ }
192
+ }
193
+ Err(GteError::Inference(format!(
194
+ "model '{}' not found in {} (checked onnx/{0} and {0})",
195
+ name,
196
+ dir.display()
197
+ )))
198
+ }
199
+
182
200
  fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
183
201
  let candidates = [
184
202
  dir.join("onnx").join("text_model.onnx"),
@@ -96,8 +96,14 @@ impl RbEmbedder {
96
96
  dir_path: String,
97
97
  num_threads: usize,
98
98
  optimization_level: u8,
99
+ model_name: String,
99
100
  ) -> Result<Self, Error> {
100
- let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level)
101
+ let name = if model_name.is_empty() {
102
+ None
103
+ } else {
104
+ Some(model_name.as_str())
105
+ };
106
+ let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, name)
101
107
  .map_err(magnus::Error::from)?;
102
108
  Ok(RbEmbedder {
103
109
  inner: Arc::new(embedder),
@@ -202,7 +208,7 @@ impl RbTensor {
202
208
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
203
209
  let module = ruby.define_module("GTE")?;
204
210
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
205
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 3))?;
211
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 4))?;
206
212
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
207
213
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
208
214
 
@@ -36,13 +36,8 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
36
36
  }
37
37
 
38
38
  fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
39
- let default_providers = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
40
- "xnnpack,coreml"
41
- } else {
42
- "xnnpack"
43
- };
44
39
  let order = std::env::var("GTE_EXECUTION_PROVIDERS")
45
- .unwrap_or_else(|_| default_providers.to_string())
40
+ .unwrap_or_else(|_| "xnnpack".to_string())
46
41
  .to_ascii_lowercase();
47
42
 
48
43
  let mut providers = Vec::new();
@@ -5,7 +5,7 @@ use gte::embedder::Embedder;
5
5
  fn test_e5_single_embedding_shape() {
6
6
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
7
7
 
8
- let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
8
+ let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
9
9
  let result = embedder
10
10
  .embed(vec!["query: Hello world".to_string()])
11
11
  .expect("embed should succeed");
@@ -19,7 +19,7 @@ fn test_e5_single_embedding_shape() {
19
19
  fn test_clip_single_embedding_shape() {
20
20
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
21
21
 
22
- let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
22
+ let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
23
23
  let result = embedder
24
24
  .embed(vec!["a photo of a cat".to_string()])
25
25
  .expect("embed should succeed");
@@ -33,7 +33,7 @@ fn test_clip_single_embedding_shape() {
33
33
  fn test_e5_batch_embedding_shape() {
34
34
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
35
35
 
36
- let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
36
+ let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
37
37
  let texts = vec![
38
38
  "query: first sentence".to_string(),
39
39
  "query: second sentence".to_string(),
@@ -51,7 +51,7 @@ fn test_e5_batch_embedding_shape() {
51
51
  fn test_e5_long_input_truncation_no_error() {
52
52
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
53
53
 
54
- let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
54
+ let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
55
55
  let very_long_text = "word ".repeat(1000);
56
56
  let result = embedder
57
57
  .embed(vec![very_long_text])
data/lib/gte.rb CHANGED
@@ -1,13 +1,17 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'gte/gte'
3
# Try an extension built for this Ruby's major.minor version
# (gte/<major.minor>/gte) first, falling back to the generic gte/gte build.
# The version is taken with a regexp rather than String#to_f, because to_f
# turns "3.10.0" into 3.1 and would pick the wrong directory.
begin
  require "gte/#{RUBY_VERSION[/\d+\.\d+/]}/gte"
rescue LoadError
  require 'gte/gte'
end
4
8
 
5
9
  module GTE
6
10
  VERSION = File.read(File.expand_path('../VERSION', __dir__)).strip
7
11
 
8
12
  class Model
9
- def initialize(dir, num_threads: 0, optimization_level: 3)
10
- @embedder = GTE::Embedder.new(dir, num_threads, optimization_level)
13
+ def initialize(dir, num_threads: 0, optimization_level: 3, model_name: nil)
14
+ @embedder = GTE::Embedder.new(dir, num_threads, optimization_level, model_name.to_s)
11
15
  end
12
16
 
13
17
  def embed(texts)
@@ -26,7 +30,7 @@ module GTE
26
30
  end
27
31
  end
28
32
 
29
- def self.new(dir, num_threads: 0, optimization_level: 3)
30
- Model.new(dir, num_threads: num_threads, optimization_level: optimization_level)
33
+ def self.new(dir, num_threads: 0, optimization_level: 3, model_name: nil)
34
+ Model.new(dir, num_threads: num_threads, optimization_level: optimization_level, model_name: model_name)
31
35
  end
32
36
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo