red-candle 1.3.1 → 1.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +24 -33
- data/ext/candle/Cargo.toml +1 -1
- data/ext/candle/src/lib.rs +4 -2
- data/ext/candle/src/llm/gemma.rs +1 -1
- data/ext/candle/src/llm/llama.rs +1 -1
- data/ext/candle/src/llm/mistral.rs +1 -1
- data/ext/candle/src/llm/phi.rs +1 -1
- data/ext/candle/src/llm/quantized_gguf.rs +1 -1
- data/ext/candle/src/llm/qwen.rs +1 -1
- data/ext/candle/src/ruby/device.rs +8 -7
- data/ext/candle/src/ruby/dtype.rs +3 -2
- data/ext/candle/src/ruby/embedding_model.rs +31 -14
- data/ext/candle/src/ruby/errors.rs +6 -4
- data/ext/candle/src/ruby/llm.rs +78 -68
- data/ext/candle/src/ruby/ner.rs +106 -95
- data/ext/candle/src/ruby/reranker.rs +51 -38
- data/ext/candle/src/ruby/structured.rs +13 -12
- data/ext/candle/src/ruby/tensor.rs +7 -6
- data/ext/candle/src/ruby/tokenizer.rs +101 -84
- data/ext/candle/src/ruby/utils.rs +26 -0
- data/lib/candle/version.rb +1 -1
- data/lib/candle.rb +1 -1
- metadata +45 -6
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 56022d9cb3677fa07bedae0272b5855374e2a6a5af863e055b546e1a0a9ddd5d
|
|
4
|
+
data.tar.gz: b027603292ba34d1cf75460e74edf1d2472ea5651f8690d86adc50dc7000bdcc
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: cbc00f433b91bddc9d117e836488f1991d42a4e7ca8c9968949adfe671277a22e36a2a65f26b7cd17036c24c6e24d6a45ce1f4684f6281811d5b25f5830b0558
|
|
7
|
+
data.tar.gz: 2970cd8e5d7d59c6d6be5e69219da7c518962e12b59f2c18f2d4fa806a50d264198b1c027f82263e185987e2f75d0fd4ee81194ed272a257fdd804575756b82c
|
data/Cargo.lock
CHANGED
|
@@ -167,7 +167,7 @@ dependencies = [
|
|
|
167
167
|
"bitflags 2.9.4",
|
|
168
168
|
"cexpr",
|
|
169
169
|
"clang-sys",
|
|
170
|
-
"itertools 0.
|
|
170
|
+
"itertools 0.11.0",
|
|
171
171
|
"lazy_static",
|
|
172
172
|
"lazycell",
|
|
173
173
|
"proc-macro2",
|
|
@@ -300,9 +300,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
|
|
300
300
|
|
|
301
301
|
[[package]]
|
|
302
302
|
name = "bytes"
|
|
303
|
-
version = "1.
|
|
303
|
+
version = "1.11.1"
|
|
304
304
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
305
|
-
checksum = "
|
|
305
|
+
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
|
|
306
306
|
|
|
307
307
|
[[package]]
|
|
308
308
|
name = "candle"
|
|
@@ -819,7 +819,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
|
819
819
|
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
|
820
820
|
dependencies = [
|
|
821
821
|
"libc",
|
|
822
|
-
"windows-sys 0.
|
|
822
|
+
"windows-sys 0.59.0",
|
|
823
823
|
]
|
|
824
824
|
|
|
825
825
|
[[package]]
|
|
@@ -1750,15 +1750,6 @@ dependencies = [
|
|
|
1750
1750
|
"either",
|
|
1751
1751
|
]
|
|
1752
1752
|
|
|
1753
|
-
[[package]]
|
|
1754
|
-
name = "itertools"
|
|
1755
|
-
version = "0.12.1"
|
|
1756
|
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1757
|
-
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
|
1758
|
-
dependencies = [
|
|
1759
|
-
"either",
|
|
1760
|
-
]
|
|
1761
|
-
|
|
1762
1753
|
[[package]]
|
|
1763
1754
|
name = "itertools"
|
|
1764
1755
|
version = "0.13.0"
|
|
@@ -1890,9 +1881,9 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
|
|
|
1890
1881
|
|
|
1891
1882
|
[[package]]
|
|
1892
1883
|
name = "magnus"
|
|
1893
|
-
version = "0.
|
|
1884
|
+
version = "0.8.2"
|
|
1894
1885
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1895
|
-
checksum = "
|
|
1886
|
+
checksum = "3b36a5b126bbe97eb0d02d07acfeb327036c6319fd816139a49824a83b7f9012"
|
|
1896
1887
|
dependencies = [
|
|
1897
1888
|
"magnus-macros",
|
|
1898
1889
|
"rb-sys",
|
|
@@ -1902,9 +1893,9 @@ dependencies = [
|
|
|
1902
1893
|
|
|
1903
1894
|
[[package]]
|
|
1904
1895
|
name = "magnus-macros"
|
|
1905
|
-
version = "0.
|
|
1896
|
+
version = "0.8.0"
|
|
1906
1897
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1907
|
-
checksum = "
|
|
1898
|
+
checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892"
|
|
1908
1899
|
dependencies = [
|
|
1909
1900
|
"proc-macro2",
|
|
1910
1901
|
"quote",
|
|
@@ -2477,9 +2468,9 @@ dependencies = [
|
|
|
2477
2468
|
|
|
2478
2469
|
[[package]]
|
|
2479
2470
|
name = "quinn-proto"
|
|
2480
|
-
version = "0.11.
|
|
2471
|
+
version = "0.11.14"
|
|
2481
2472
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2482
|
-
checksum = "
|
|
2473
|
+
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
|
|
2483
2474
|
dependencies = [
|
|
2484
2475
|
"bytes",
|
|
2485
2476
|
"getrandom 0.3.3",
|
|
@@ -2507,7 +2498,7 @@ dependencies = [
|
|
|
2507
2498
|
"once_cell",
|
|
2508
2499
|
"socket2",
|
|
2509
2500
|
"tracing",
|
|
2510
|
-
"windows-sys 0.
|
|
2501
|
+
"windows-sys 0.59.0",
|
|
2511
2502
|
]
|
|
2512
2503
|
|
|
2513
2504
|
[[package]]
|
|
@@ -2656,18 +2647,18 @@ dependencies = [
|
|
|
2656
2647
|
|
|
2657
2648
|
[[package]]
|
|
2658
2649
|
name = "rb-sys"
|
|
2659
|
-
version = "0.9.
|
|
2650
|
+
version = "0.9.124"
|
|
2660
2651
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2661
|
-
checksum = "
|
|
2652
|
+
checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489"
|
|
2662
2653
|
dependencies = [
|
|
2663
2654
|
"rb-sys-build",
|
|
2664
2655
|
]
|
|
2665
2656
|
|
|
2666
2657
|
[[package]]
|
|
2667
2658
|
name = "rb-sys-build"
|
|
2668
|
-
version = "0.9.
|
|
2659
|
+
version = "0.9.124"
|
|
2669
2660
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2670
|
-
checksum = "
|
|
2661
|
+
checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3"
|
|
2671
2662
|
dependencies = [
|
|
2672
2663
|
"bindgen 0.69.5",
|
|
2673
2664
|
"lazy_static",
|
|
@@ -2680,9 +2671,9 @@ dependencies = [
|
|
|
2680
2671
|
|
|
2681
2672
|
[[package]]
|
|
2682
2673
|
name = "rb-sys-env"
|
|
2683
|
-
version = "0.
|
|
2674
|
+
version = "0.2.3"
|
|
2684
2675
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2685
|
-
checksum = "
|
|
2676
|
+
checksum = "cca7ad6a7e21e72151d56fe2495a259b5670e204c3adac41ee7ef676ea08117a"
|
|
2686
2677
|
|
|
2687
2678
|
[[package]]
|
|
2688
2679
|
name = "reborrow"
|
|
@@ -2828,7 +2819,7 @@ dependencies = [
|
|
|
2828
2819
|
"errno",
|
|
2829
2820
|
"libc",
|
|
2830
2821
|
"linux-raw-sys",
|
|
2831
|
-
"windows-sys 0.
|
|
2822
|
+
"windows-sys 0.59.0",
|
|
2832
2823
|
]
|
|
2833
2824
|
|
|
2834
2825
|
[[package]]
|
|
@@ -2859,9 +2850,9 @@ dependencies = [
|
|
|
2859
2850
|
|
|
2860
2851
|
[[package]]
|
|
2861
2852
|
name = "rustls-webpki"
|
|
2862
|
-
version = "0.103.
|
|
2853
|
+
version = "0.103.10"
|
|
2863
2854
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2864
|
-
checksum = "
|
|
2855
|
+
checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef"
|
|
2865
2856
|
dependencies = [
|
|
2866
2857
|
"aws-lc-rs",
|
|
2867
2858
|
"ring",
|
|
@@ -3204,9 +3195,9 @@ dependencies = [
|
|
|
3204
3195
|
|
|
3205
3196
|
[[package]]
|
|
3206
3197
|
name = "tar"
|
|
3207
|
-
version = "0.4.
|
|
3198
|
+
version = "0.4.45"
|
|
3208
3199
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
3209
|
-
checksum = "
|
|
3200
|
+
checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973"
|
|
3210
3201
|
dependencies = [
|
|
3211
3202
|
"filetime",
|
|
3212
3203
|
"libc",
|
|
@@ -3223,7 +3214,7 @@ dependencies = [
|
|
|
3223
3214
|
"getrandom 0.3.3",
|
|
3224
3215
|
"once_cell",
|
|
3225
3216
|
"rustix",
|
|
3226
|
-
"windows-sys 0.
|
|
3217
|
+
"windows-sys 0.59.0",
|
|
3227
3218
|
]
|
|
3228
3219
|
|
|
3229
3220
|
[[package]]
|
|
@@ -3897,7 +3888,7 @@ version = "0.1.11"
|
|
|
3897
3888
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
3898
3889
|
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
|
3899
3890
|
dependencies = [
|
|
3900
|
-
"windows-sys 0.
|
|
3891
|
+
"windows-sys 0.48.0",
|
|
3901
3892
|
]
|
|
3902
3893
|
|
|
3903
3894
|
[[package]]
|
data/ext/candle/Cargo.toml
CHANGED
|
@@ -15,7 +15,7 @@ candle-transformers = { version = "0.9.1" }
|
|
|
15
15
|
tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
|
|
16
16
|
hf-hub = "0.4.1"
|
|
17
17
|
half = "2.6.0"
|
|
18
|
-
magnus = "0.
|
|
18
|
+
magnus = "0.8"
|
|
19
19
|
safetensors = "0.3"
|
|
20
20
|
serde_json = "1.0"
|
|
21
21
|
serde = { version = "1.0", features = ["derive"] }
|
data/ext/candle/src/lib.rs
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
use magnus::{function, prelude::*, Ruby};
|
|
2
2
|
|
|
3
3
|
use crate::ruby::candle_utils;
|
|
4
|
+
use crate::ruby::utils::ensure_hf_cache_dir;
|
|
4
5
|
use crate::ruby::Result;
|
|
5
6
|
|
|
6
7
|
pub mod llm;
|
|
@@ -22,19 +23,20 @@ const DEFAULT_DEVICE: &str = "cpu";
|
|
|
22
23
|
pub fn get_build_info() -> magnus::RHash {
|
|
23
24
|
let ruby = magnus::Ruby::get().unwrap();
|
|
24
25
|
let hash = ruby.hash_new();
|
|
25
|
-
|
|
26
|
+
|
|
26
27
|
let _ = hash.aset("default_device", DEFAULT_DEVICE);
|
|
27
28
|
let _ = hash.aset("cuda_available", cfg!(feature = "cuda"));
|
|
28
29
|
let _ = hash.aset("metal_available", cfg!(feature = "metal"));
|
|
29
30
|
let _ = hash.aset("mkl_available", cfg!(feature = "mkl"));
|
|
30
31
|
let _ = hash.aset("accelerate_available", cfg!(feature = "accelerate"));
|
|
31
32
|
let _ = hash.aset("cudnn_available", cfg!(feature = "cudnn"));
|
|
32
|
-
|
|
33
|
+
|
|
33
34
|
hash
|
|
34
35
|
}
|
|
35
36
|
|
|
36
37
|
#[magnus::init]
|
|
37
38
|
fn init(ruby: &Ruby) -> Result<()> {
|
|
39
|
+
ensure_hf_cache_dir();
|
|
38
40
|
let rb_candle = ruby.define_module("Candle")?;
|
|
39
41
|
|
|
40
42
|
// Export build info
|
data/ext/candle/src/llm/gemma.rs
CHANGED
|
@@ -34,7 +34,7 @@ impl Gemma {
|
|
|
34
34
|
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
35
|
let api = Api::new()
|
|
36
36
|
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
-
|
|
37
|
+
|
|
38
38
|
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
39
39
|
|
|
40
40
|
// Download model files
|
data/ext/candle/src/llm/llama.rs
CHANGED
|
@@ -41,7 +41,7 @@ impl Llama {
|
|
|
41
41
|
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
42
42
|
let api = Api::new()
|
|
43
43
|
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
44
|
-
|
|
44
|
+
|
|
45
45
|
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
46
46
|
|
|
47
47
|
// Download model files
|
|
@@ -34,7 +34,7 @@ impl Mistral {
|
|
|
34
34
|
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
35
|
let api = Api::new()
|
|
36
36
|
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
-
|
|
37
|
+
|
|
38
38
|
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
39
39
|
|
|
40
40
|
// Download model files
|
data/ext/candle/src/llm/phi.rs
CHANGED
|
@@ -42,7 +42,7 @@ impl Phi {
|
|
|
42
42
|
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
43
43
|
let api = Api::new()
|
|
44
44
|
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
45
|
-
|
|
45
|
+
|
|
46
46
|
let repo = api.model(model_id.to_string());
|
|
47
47
|
|
|
48
48
|
// Download configuration
|
data/ext/candle/src/llm/qwen.rs
CHANGED
|
@@ -34,7 +34,7 @@ impl Qwen {
|
|
|
34
34
|
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
35
|
let api = Api::new()
|
|
36
36
|
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
-
|
|
37
|
+
|
|
38
38
|
let repo = api.model(model_id.to_string());
|
|
39
39
|
|
|
40
40
|
// Download configuration
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
use magnus::Error;
|
|
2
|
-
use magnus::{function, method,
|
|
2
|
+
use magnus::{function, method, RModule, Module, Object, Ruby};
|
|
3
3
|
|
|
4
4
|
use ::candle_core::Device as CoreDevice;
|
|
5
5
|
use crate::ruby::Result;
|
|
@@ -101,7 +101,7 @@ impl Device {
|
|
|
101
101
|
#[cfg(not(feature = "cuda"))]
|
|
102
102
|
{
|
|
103
103
|
return Err(Error::new(
|
|
104
|
-
|
|
104
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
105
105
|
"CUDA support not compiled in. Rebuild with CUDA available.",
|
|
106
106
|
));
|
|
107
107
|
}
|
|
@@ -115,7 +115,7 @@ impl Device {
|
|
|
115
115
|
#[cfg(not(feature = "metal"))]
|
|
116
116
|
{
|
|
117
117
|
return Err(Error::new(
|
|
118
|
-
|
|
118
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
119
119
|
"Metal support not compiled in. Rebuild on macOS.",
|
|
120
120
|
));
|
|
121
121
|
}
|
|
@@ -139,7 +139,7 @@ impl Device {
|
|
|
139
139
|
#[cfg(not(feature = "cuda"))]
|
|
140
140
|
{
|
|
141
141
|
return Err(Error::new(
|
|
142
|
-
|
|
142
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
143
143
|
"CUDA support not compiled in. Rebuild with CUDA available.",
|
|
144
144
|
));
|
|
145
145
|
}
|
|
@@ -161,7 +161,7 @@ impl Device {
|
|
|
161
161
|
#[cfg(not(feature = "metal"))]
|
|
162
162
|
{
|
|
163
163
|
return Err(Error::new(
|
|
164
|
-
|
|
164
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
165
165
|
"Metal support not compiled in. Rebuild on macOS.",
|
|
166
166
|
));
|
|
167
167
|
}
|
|
@@ -211,14 +211,15 @@ impl magnus::TryConvert for Device {
|
|
|
211
211
|
"cpu" => Device::Cpu,
|
|
212
212
|
"cuda" => Device::Cuda,
|
|
213
213
|
"metal" => Device::Metal,
|
|
214
|
-
_ => return Err(Error::new(
|
|
214
|
+
_ => return Err(Error::new(Ruby::get().unwrap().exception_arg_error(), "invalid device")),
|
|
215
215
|
};
|
|
216
216
|
Ok(device)
|
|
217
217
|
}
|
|
218
218
|
}
|
|
219
219
|
|
|
220
220
|
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
221
|
-
let
|
|
221
|
+
let ruby = Ruby::get().unwrap();
|
|
222
|
+
let rb_device = rb_candle.define_class("Device", ruby.class_object())?;
|
|
222
223
|
rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
|
|
223
224
|
rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
|
|
224
225
|
rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
use magnus::value::ReprValue;
|
|
2
|
-
use magnus::{method,
|
|
2
|
+
use magnus::{method, RModule, Module, Ruby};
|
|
3
3
|
|
|
4
4
|
use ::candle_core::DType as CoreDType;
|
|
5
5
|
use crate::ruby::Result;
|
|
@@ -30,7 +30,8 @@ impl DType {
|
|
|
30
30
|
}
|
|
31
31
|
|
|
32
32
|
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
33
|
-
let
|
|
33
|
+
let ruby = Ruby::get().unwrap();
|
|
34
|
+
let rb_dtype = rb_candle.define_class("DType", ruby.class_object())?;
|
|
34
35
|
rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
|
|
35
36
|
rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
|
|
36
37
|
Ok(())
|
|
@@ -13,7 +13,7 @@ use candle_transformers::models::{
|
|
|
13
13
|
jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
|
|
14
14
|
distilbert::{DistilBertModel, Config as DistilBertConfig}
|
|
15
15
|
};
|
|
16
|
-
use magnus::{
|
|
16
|
+
use magnus::{function, method, prelude::*, Error, RModule, RHash, Ruby};
|
|
17
17
|
use std::path::Path;
|
|
18
18
|
use serde_json;
|
|
19
19
|
|
|
@@ -103,28 +103,30 @@ impl EmbeddingModel {
|
|
|
103
103
|
/// &RETURNS&: Tensor
|
|
104
104
|
/// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
|
|
105
105
|
pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
|
|
106
|
+
let ruby = Ruby::get().unwrap();
|
|
106
107
|
match &self.0.model {
|
|
107
108
|
Some(model) => {
|
|
108
109
|
match &self.0.tokenizer {
|
|
109
110
|
Some(tokenizer) => Ok(Tensor(self.compute_embedding(input, model, tokenizer, &pooling_method)?)),
|
|
110
|
-
None => Err(magnus::Error::new(
|
|
111
|
+
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
|
|
111
112
|
}
|
|
112
113
|
}
|
|
113
|
-
None => Err(magnus::Error::new(
|
|
114
|
+
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
|
|
114
115
|
}
|
|
115
116
|
}
|
|
116
117
|
|
|
117
118
|
/// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
|
|
118
119
|
/// &RETURNS&: Tensor
|
|
119
120
|
pub fn embeddings(&self, input: String) -> Result<Tensor> {
|
|
121
|
+
let ruby = Ruby::get().unwrap();
|
|
120
122
|
match &self.0.model {
|
|
121
123
|
Some(model) => {
|
|
122
124
|
match &self.0.tokenizer {
|
|
123
125
|
Some(tokenizer) => Ok(Tensor(self.compute_embeddings(input, model, tokenizer)?)),
|
|
124
|
-
None => Err(magnus::Error::new(
|
|
126
|
+
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
|
|
125
127
|
}
|
|
126
128
|
}
|
|
127
|
-
None => Err(magnus::Error::new(
|
|
129
|
+
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
|
|
128
130
|
}
|
|
129
131
|
}
|
|
130
132
|
|
|
@@ -165,7 +167,10 @@ impl EmbeddingModel {
|
|
|
165
167
|
},
|
|
166
168
|
Err(_) => None
|
|
167
169
|
};
|
|
168
|
-
inferred_emb_dim.ok_or_else(||
|
|
170
|
+
inferred_emb_dim.ok_or_else(|| {
|
|
171
|
+
let ruby = Ruby::get().unwrap();
|
|
172
|
+
magnus::Error::new(ruby.exception_runtime_error(), "Could not infer embedding size from model file. Please specify embedding_size explicitly.")
|
|
173
|
+
})
|
|
169
174
|
}
|
|
170
175
|
}
|
|
171
176
|
}
|
|
@@ -178,8 +183,9 @@ impl EmbeddingModel {
|
|
|
178
183
|
EmbeddingModelType::JinaBert => {
|
|
179
184
|
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
180
185
|
if !std::path::Path::new(&model_path).exists() {
|
|
186
|
+
let ruby = Ruby::get().unwrap();
|
|
181
187
|
return Err(magnus::Error::new(
|
|
182
|
-
|
|
188
|
+
ruby.exception_runtime_error(),
|
|
183
189
|
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
184
190
|
));
|
|
185
191
|
}
|
|
@@ -196,8 +202,9 @@ impl EmbeddingModel {
|
|
|
196
202
|
EmbeddingModelType::StandardBert => {
|
|
197
203
|
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
198
204
|
if !std::path::Path::new(&model_path).exists() {
|
|
205
|
+
let ruby = Ruby::get().unwrap();
|
|
199
206
|
return Err(magnus::Error::new(
|
|
200
|
-
|
|
207
|
+
ruby.exception_runtime_error(),
|
|
201
208
|
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
202
209
|
));
|
|
203
210
|
}
|
|
@@ -214,8 +221,9 @@ impl EmbeddingModel {
|
|
|
214
221
|
EmbeddingModelType::DistilBert => {
|
|
215
222
|
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
216
223
|
if !std::path::Path::new(&model_path).exists() {
|
|
224
|
+
let ruby = Ruby::get().unwrap();
|
|
217
225
|
return Err(magnus::Error::new(
|
|
218
|
-
|
|
226
|
+
ruby.exception_runtime_error(),
|
|
219
227
|
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
220
228
|
));
|
|
221
229
|
}
|
|
@@ -235,8 +243,9 @@ impl EmbeddingModel {
|
|
|
235
243
|
EmbeddingModelType::MiniLM => {
|
|
236
244
|
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
237
245
|
if !std::path::Path::new(&model_path).exists() {
|
|
246
|
+
let ruby = Ruby::get().unwrap();
|
|
238
247
|
return Err(magnus::Error::new(
|
|
239
|
-
|
|
248
|
+
ruby.exception_runtime_error(),
|
|
240
249
|
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
241
250
|
));
|
|
242
251
|
}
|
|
@@ -357,7 +366,10 @@ impl EmbeddingModel {
|
|
|
357
366
|
"pooled" => Self::pooled_embedding(&result),
|
|
358
367
|
"pooled_normalized" => Self::pooled_normalized_embedding(&result),
|
|
359
368
|
"cls" => Self::pooled_cls_embedding(&result),
|
|
360
|
-
_ =>
|
|
369
|
+
_ => {
|
|
370
|
+
let ruby = Ruby::get().unwrap();
|
|
371
|
+
Err(magnus::Error::new(ruby.exception_runtime_error(), "Unknown pooling method"))
|
|
372
|
+
},
|
|
361
373
|
}
|
|
362
374
|
}
|
|
363
375
|
|
|
@@ -390,7 +402,10 @@ impl EmbeddingModel {
|
|
|
390
402
|
pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
|
|
391
403
|
match &self.0.tokenizer {
|
|
392
404
|
Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
|
|
393
|
-
None =>
|
|
405
|
+
None => {
|
|
406
|
+
let ruby = Ruby::get().unwrap();
|
|
407
|
+
Err(magnus::Error::new(ruby.exception_runtime_error(), "No tokenizer loaded for this model"))
|
|
408
|
+
}
|
|
394
409
|
}
|
|
395
410
|
}
|
|
396
411
|
|
|
@@ -409,7 +424,8 @@ impl EmbeddingModel {
|
|
|
409
424
|
|
|
410
425
|
/// Get all options as a hash
|
|
411
426
|
pub fn options(&self) -> Result<RHash> {
|
|
412
|
-
let
|
|
427
|
+
let ruby = Ruby::get().unwrap();
|
|
428
|
+
let hash = ruby.hash_new();
|
|
413
429
|
|
|
414
430
|
// Add model_id
|
|
415
431
|
if let Some(model_id) = &self.0.model_id {
|
|
@@ -439,7 +455,8 @@ impl EmbeddingModel {
|
|
|
439
455
|
}
|
|
440
456
|
|
|
441
457
|
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
442
|
-
let
|
|
458
|
+
let ruby = Ruby::get().unwrap();
|
|
459
|
+
let rb_embedding_model = rb_candle.define_class("EmbeddingModel", ruby.class_object())?;
|
|
443
460
|
rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
|
|
444
461
|
// Expose embedding with an optional pooling_method argument (default: "pooled")
|
|
445
462
|
rb_embedding_model.define_method("_embedding", method!(EmbeddingModel::embedding, 2))?;
|
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
use magnus::Error;
|
|
2
2
|
|
|
3
3
|
pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
|
|
4
|
-
|
|
4
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
5
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
5
6
|
}
|
|
6
7
|
|
|
7
8
|
pub fn wrap_candle_err(err: candle_core::Error) -> Error {
|
|
8
|
-
|
|
9
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
10
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
9
11
|
}
|
|
10
12
|
|
|
11
13
|
pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
|
|
12
|
-
|
|
14
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
15
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
13
16
|
}
|
|
14
|
-
|