gte 0.0.1 → 0.0.3
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published in their respective public registries.
- checksums.yaml +4 -4
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +19 -1
- data/ext/gte/src/ruby_embedder.rs +8 -2
- data/ext/gte/src/session.rs +1 -6
- data/ext/gte/tests/inference_integration_test.rs +4 -4
- data/lib/gte.rb +9 -5
- metadata +2 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: b7ce34f894403d3d2767d9c7f694aa712b42af251b0babf741e2dcd9dd6c7a27
|
|
4
|
+
data.tar.gz: c91aa21b10b2a20358c5d56c511623927c6e4cd4e0667cc7f40cdca405a4d10f
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 87e824d3fa79dc67a9584b902d17329aa85eb4f8fc4a358a6350c7f19e3d4e3c170a59b852abd16332caada49106bbba3356b6a5486bbb52c97b8bef22b1b9a0
|
|
7
|
+
data.tar.gz: 0dfeb1f6b4223f7ee88609411b94548740b588d89a92b55ba7e093564417086f24a12ebbf98bfee6a9fbd4c74d0f55dc0d66c2a6095d0d7ad7d9b1adca1b2eb7
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.1
|
|
1
|
+
0.0.3
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -40,10 +40,14 @@ impl Embedder {
|
|
|
40
40
|
dir: P,
|
|
41
41
|
num_threads: usize,
|
|
42
42
|
optimization_level: u8,
|
|
43
|
+
model_name: Option<&str>,
|
|
43
44
|
) -> Result<Self> {
|
|
44
45
|
let dir = dir.as_ref();
|
|
45
46
|
let tokenizer_path = dir.join("tokenizer.json");
|
|
46
|
-
let model_path =
|
|
47
|
+
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
48
|
+
Some(name) => resolve_named_model(dir, name)?,
|
|
49
|
+
None => resolve_model_path(dir)?,
|
|
50
|
+
};
|
|
47
51
|
|
|
48
52
|
if !tokenizer_path.exists() {
|
|
49
53
|
return Err(GteError::Tokenizer(format!(
|
|
@@ -179,6 +183,20 @@ fn host_parallelism() -> usize {
|
|
|
179
183
|
.unwrap_or(1)
|
|
180
184
|
}
|
|
181
185
|
|
|
186
|
+
fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
|
|
187
|
+
let candidates = [dir.join("onnx").join(name), dir.join(name)];
|
|
188
|
+
for path in &candidates {
|
|
189
|
+
if path.exists() {
|
|
190
|
+
return Ok(path.clone());
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
Err(GteError::Inference(format!(
|
|
194
|
+
"model '{}' not found in {} (checked onnx/{0} and {0})",
|
|
195
|
+
name,
|
|
196
|
+
dir.display()
|
|
197
|
+
)))
|
|
198
|
+
}
|
|
199
|
+
|
|
182
200
|
fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
|
|
183
201
|
let candidates = [
|
|
184
202
|
dir.join("onnx").join("text_model.onnx"),
|
|
@@ -96,8 +96,14 @@ impl RbEmbedder {
|
|
|
96
96
|
dir_path: String,
|
|
97
97
|
num_threads: usize,
|
|
98
98
|
optimization_level: u8,
|
|
99
|
+
model_name: String,
|
|
99
100
|
) -> Result<Self, Error> {
|
|
100
|
-
let
|
|
101
|
+
let name = if model_name.is_empty() {
|
|
102
|
+
None
|
|
103
|
+
} else {
|
|
104
|
+
Some(model_name.as_str())
|
|
105
|
+
};
|
|
106
|
+
let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, name)
|
|
101
107
|
.map_err(magnus::Error::from)?;
|
|
102
108
|
Ok(RbEmbedder {
|
|
103
109
|
inner: Arc::new(embedder),
|
|
@@ -202,7 +208,7 @@ impl RbTensor {
|
|
|
202
208
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
203
209
|
let module = ruby.define_module("GTE")?;
|
|
204
210
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
205
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
211
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 4))?;
|
|
206
212
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
207
213
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
208
214
|
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -36,13 +36,8 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
|
|
39
|
-
let default_providers = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
|
|
40
|
-
"xnnpack,coreml"
|
|
41
|
-
} else {
|
|
42
|
-
"xnnpack"
|
|
43
|
-
};
|
|
44
39
|
let order = std::env::var("GTE_EXECUTION_PROVIDERS")
|
|
45
|
-
.unwrap_or_else(|_|
|
|
40
|
+
.unwrap_or_else(|_| "xnnpack".to_string())
|
|
46
41
|
.to_ascii_lowercase();
|
|
47
42
|
|
|
48
43
|
let mut providers = Vec::new();
|
|
@@ -5,7 +5,7 @@ use gte::embedder::Embedder;
|
|
|
5
5
|
fn test_e5_single_embedding_shape() {
|
|
6
6
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
7
|
|
|
8
|
-
let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
|
|
8
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
|
|
9
9
|
let result = embedder
|
|
10
10
|
.embed(vec!["query: Hello world".to_string()])
|
|
11
11
|
.expect("embed should succeed");
|
|
@@ -19,7 +19,7 @@ fn test_e5_single_embedding_shape() {
|
|
|
19
19
|
fn test_clip_single_embedding_shape() {
|
|
20
20
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
21
21
|
|
|
22
|
-
let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
|
|
22
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
|
|
23
23
|
let result = embedder
|
|
24
24
|
.embed(vec!["a photo of a cat".to_string()])
|
|
25
25
|
.expect("embed should succeed");
|
|
@@ -33,7 +33,7 @@ fn test_clip_single_embedding_shape() {
|
|
|
33
33
|
fn test_e5_batch_embedding_shape() {
|
|
34
34
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
35
35
|
|
|
36
|
-
let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
|
|
36
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
|
|
37
37
|
let texts = vec![
|
|
38
38
|
"query: first sentence".to_string(),
|
|
39
39
|
"query: second sentence".to_string(),
|
|
@@ -51,7 +51,7 @@ fn test_e5_batch_embedding_shape() {
|
|
|
51
51
|
fn test_e5_long_input_truncation_no_error() {
|
|
52
52
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
53
53
|
|
|
54
|
-
let embedder = Embedder::from_dir(DIR, 0, 3).expect("embedder should initialize");
|
|
54
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None).expect("embedder should initialize");
|
|
55
55
|
let very_long_text = "word ".repeat(1000);
|
|
56
56
|
let result = embedder
|
|
57
57
|
.embed(vec![very_long_text])
|
data/lib/gte.rb
CHANGED
|
@@ -1,13 +1,17 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
begin
|
|
4
|
+
require "gte/#{RUBY_VERSION.to_f}/gte"
|
|
5
|
+
rescue LoadError
|
|
6
|
+
require 'gte/gte'
|
|
7
|
+
end
|
|
4
8
|
|
|
5
9
|
module GTE
|
|
6
10
|
VERSION = File.read(File.expand_path('../VERSION', __dir__)).strip
|
|
7
11
|
|
|
8
12
|
class Model
|
|
9
|
-
def initialize(dir, num_threads: 0, optimization_level: 3)
|
|
10
|
-
@embedder = GTE::Embedder.new(dir, num_threads, optimization_level)
|
|
13
|
+
def initialize(dir, num_threads: 0, optimization_level: 3, model_name: nil)
|
|
14
|
+
@embedder = GTE::Embedder.new(dir, num_threads, optimization_level, model_name.to_s)
|
|
11
15
|
end
|
|
12
16
|
|
|
13
17
|
def embed(texts)
|
|
@@ -26,7 +30,7 @@ module GTE
|
|
|
26
30
|
end
|
|
27
31
|
end
|
|
28
32
|
|
|
29
|
-
def self.new(dir, num_threads: 0, optimization_level: 3)
|
|
30
|
-
Model.new(dir, num_threads: num_threads, optimization_level: optimization_level)
|
|
33
|
+
def self.new(dir, num_threads: 0, optimization_level: 3, model_name: nil)
|
|
34
|
+
Model.new(dir, num_threads: num_threads, optimization_level: optimization_level, model_name: model_name)
|
|
31
35
|
end
|
|
32
36
|
end
|
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.1
|
|
4
|
+
version: 0.0.3
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
@@ -45,7 +45,7 @@ dependencies:
|
|
|
45
45
|
- - ">="
|
|
46
46
|
- !ruby/object:Gem::Version
|
|
47
47
|
version: '0'
|
|
48
|
-
type: :development
|
|
48
|
+
type: :runtime
|
|
49
49
|
prerelease: false
|
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
|
51
51
|
requirements:
|