red-candle 0.0.6 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,410 @@
+ // MKL and Accelerate are handled by candle-core when their features are enabled
+
+ use crate::ruby::{
+     errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
+ };
+ use crate::ruby::{Tensor, Device, Result as RbResult};
+ use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
+ use safetensors::tensor::SafeTensors;
+ use candle_nn::VarBuilder;
+ use candle_transformers::models::{
+     bert::{BertModel as StdBertModel, Config as BertConfig},
+     jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
+     distilbert::{DistilBertModel, Config as DistilBertConfig},
+ };
+ use magnus::{class, function, method, prelude::*, Error, RModule};
+ use std::path::Path;
+ use tokenizers::Tokenizer;
+ use serde_json;
+
+ #[magnus::wrap(class = "Candle::EmbeddingModel", free_immediately, size)]
+ pub struct EmbeddingModel(pub EmbeddingModelInner);
+
+ /// Supported model types for embedding generation
+ #[derive(Debug, Clone, Copy, PartialEq)]
+ pub enum EmbeddingModelType {
+     JinaBert,
+     StandardBert,
+     DistilBert,
+     MiniLM,
+ }
+
+ impl EmbeddingModelType {
+     pub fn from_string(model_type: &str) -> Option<Self> {
+         match model_type.to_lowercase().as_str() {
+             "jina_bert" | "jinabert" | "jina" => Some(EmbeddingModelType::JinaBert),
+             "bert" | "standard_bert" | "standardbert" => Some(EmbeddingModelType::StandardBert),
+             "minilm" => Some(EmbeddingModelType::MiniLM),
+             "distilbert" => Some(EmbeddingModelType::DistilBert),
+             _ => None,
+         }
+     }
+ }
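
For illustration, `from_string` lowercases its input, so each family is matched case-insensitively under several aliases. A minimal sketch of the expected mappings (hypothetical checks, assuming the enum above is in scope):

```rust
// EmbeddingModelType derives Debug + PartialEq, so assert_eq! works here.
assert_eq!(EmbeddingModelType::from_string("JINA"), Some(EmbeddingModelType::JinaBert));
assert_eq!(EmbeddingModelType::from_string("StandardBert"), Some(EmbeddingModelType::StandardBert));
assert_eq!(EmbeddingModelType::from_string("distilbert"), Some(EmbeddingModelType::DistilBert));
assert_eq!(EmbeddingModelType::from_string("roberta"), None); // unrecognized names fall through to None
```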
+
+ /// Model variants that can produce embeddings
+ pub enum EmbeddingModelVariant {
+     JinaBert(JinaBertModel),
+     StandardBert(StdBertModel),
+     DistilBert(DistilBertModel),
+     MiniLM(StdBertModel),
+ }
+
+ impl EmbeddingModelVariant {
+     pub fn embedding_model_type(&self) -> EmbeddingModelType {
+         match self {
+             EmbeddingModelVariant::JinaBert(_) => EmbeddingModelType::JinaBert,
+             EmbeddingModelVariant::StandardBert(_) => EmbeddingModelType::StandardBert,
+             EmbeddingModelVariant::DistilBert(_) => EmbeddingModelType::DistilBert,
+             EmbeddingModelVariant::MiniLM(_) => EmbeddingModelType::MiniLM,
+         }
+     }
+ }
+
+ pub struct EmbeddingModelInner {
+     device: CoreDevice,
+     tokenizer_path: Option<String>,
+     model_path: Option<String>,
+     embedding_model_type: Option<EmbeddingModelType>,
+     model: Option<EmbeddingModelVariant>,
+     tokenizer: Option<Tokenizer>,
+     embedding_size: Option<usize>,
+ }
+
+ impl EmbeddingModel {
+     pub fn new(
+         model_path: Option<String>,
+         tokenizer_path: Option<String>,
+         device: Option<Device>,
+         embedding_model_type: Option<String>,
+         embedding_size: Option<usize>,
+     ) -> RbResult<Self> {
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         let embedding_model_type = embedding_model_type
+             .and_then(|mt| EmbeddingModelType::from_string(&mt))
+             .unwrap_or(EmbeddingModelType::JinaBert);
+         Ok(EmbeddingModel(EmbeddingModelInner {
+             device: device.clone(),
+             model_path: model_path.clone(),
+             tokenizer_path: tokenizer_path.clone(),
+             embedding_model_type: Some(embedding_model_type),
+             model: match model_path {
+                 Some(mp) => Some(Self::build_embedding_model(Path::new(&mp), device, embedding_model_type, embedding_size)?),
+                 None => None,
+             },
+             tokenizer: match tokenizer_path {
+                 Some(tp) => Some(Self::build_tokenizer(tp)?),
+                 None => None,
+             },
+             embedding_size,
+         }))
+     }
+
+     /// Generates an embedding vector for the input text using the specified pooling method.
+     /// &RETURNS&: Tensor
+     /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
+     pub fn embedding(&self, input: String, pooling_method: String) -> RbResult<Tensor> {
+         match &self.0.model {
+             Some(model) => {
+                 match &self.0.tokenizer {
+                     Some(tokenizer) => Ok(Tensor(self.compute_embedding(input, model, tokenizer, &pooling_method)?)),
+                     None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found")),
+                 }
+             }
+             None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found")),
+         }
+     }
+
+     /// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
+     /// &RETURNS&: Tensor
+     pub fn embeddings(&self, input: String) -> RbResult<Tensor> {
+         match &self.0.model {
+             Some(model) => {
+                 match &self.0.tokenizer {
+                     Some(tokenizer) => Ok(Tensor(self.compute_embeddings(input, model, tokenizer)?)),
+                     None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found")),
+                 }
+             }
+             None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found")),
+         }
+     }
+
+     /// Mean-pools a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM], without normalization
+     /// &RETURNS&: Tensor
+     pub fn pool_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+         let pooled = Self::pooled_embedding(&tensor.0)?;
+         Ok(Tensor(pooled))
+     }
+
+     /// Mean-pools and L2-normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
+     /// &RETURNS&: Tensor
+     pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+         let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
+         Ok(Tensor(pooled))
+     }
+
+     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
+     /// &RETURNS&: Tensor
+     pub fn pool_cls_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+         let pooled = Self::pooled_cls_embedding(&tensor.0)?;
+         Ok(Tensor(pooled))
+     }
+
+     /// Infers and validates the embedding size from a safetensors file
+     fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> Result<usize, magnus::Error> {
+         match embedding_size {
+             Some(user_dim) => Ok(user_dim),
+             None => {
+                 let bytes = std::fs::read(model_path).map_err(|e| wrap_std_err(Box::new(e)))?;
+                 let inferred_emb_dim = match SafeTensors::deserialize(&bytes) {
+                     Ok(st) => {
+                         if let Ok(tensor) = st.tensor("embeddings.word_embeddings.weight") {
+                             let shape = tensor.shape();
+                             if shape.len() == 2 { Some(shape[1] as usize) } else { None }
+                         } else { None }
+                     },
+                     Err(_) => None,
+                 };
+                 inferred_emb_dim.ok_or_else(|| magnus::Error::new(
+                     magnus::exception::runtime_error(),
+                     "Could not infer embedding size from model file. Please specify embedding_size explicitly.",
+                 ))
+             }
+         }
+     }
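
The inference branch reads the checkpoint's `embeddings.word_embeddings.weight` tensor and takes its second dimension as the hidden size, since BERT-style checkpoints store token embeddings as `[vocab_size, hidden_size]`. A standalone sketch of the same lookup, assuming a local `model.safetensors` file (the path and helper name are illustrative):

```rust
use safetensors::tensor::SafeTensors;
use std::path::Path;

// Returns the hidden size if the checkpoint stores [vocab_size, hidden_size] token embeddings.
fn infer_hidden_size(path: &Path) -> Option<usize> {
    let bytes = std::fs::read(path).ok()?;
    let st = SafeTensors::deserialize(&bytes).ok()?;
    let shape = st.tensor("embeddings.word_embeddings.weight").ok()?.shape().to_vec();
    (shape.len() == 2).then(|| shape[1])
}
```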
+
+     fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> RbResult<EmbeddingModelVariant> {
+         use hf_hub::{api::sync::Api, Repo, RepoType};
+         let api = Api::new().map_err(wrap_hf_err)?;
+         let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
+         match embedding_model_type {
+             EmbeddingModelType::JinaBert => {
+                 let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
+                 if !std::path::Path::new(&model_path).exists() {
+                     return Err(magnus::Error::new(
+                         magnus::exception::runtime_error(),
+                         "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
+                     ));
+                 }
+                 let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
+                 let mut config = JinaConfig::v2_base();
+                 config.hidden_size = final_emb_dim;
+                 let vb = unsafe {
+                     VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
+                         .map_err(wrap_candle_err)?
+                 };
+                 let model = JinaBertModel::new(vb, &config).map_err(wrap_candle_err)?;
+                 Ok(EmbeddingModelVariant::JinaBert(model))
+             },
+             EmbeddingModelType::StandardBert => {
+                 let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
+                 if !std::path::Path::new(&model_path).exists() {
+                     return Err(magnus::Error::new(
+                         magnus::exception::runtime_error(),
+                         "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
+                     ));
+                 }
+                 let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
+                 let mut config = BertConfig::default();
+                 config.hidden_size = final_emb_dim;
+                 let vb = unsafe {
+                     VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
+                         .map_err(wrap_candle_err)?
+                 };
+                 let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
+                 Ok(EmbeddingModelVariant::StandardBert(model))
+             },
+             EmbeddingModelType::DistilBert => {
+                 let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
+                 if !std::path::Path::new(&model_path).exists() {
+                     return Err(magnus::Error::new(
+                         magnus::exception::runtime_error(),
+                         "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
+                     ));
+                 }
+                 let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
+                 let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
+                 let mut config: DistilBertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
+                 if let Some(embedding_size) = embedding_size {
+                     config.dim = embedding_size;
+                 }
+                 let vb = unsafe {
+                     VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
+                         .map_err(wrap_candle_err)?
+                 };
+                 let model = DistilBertModel::load(vb, &config).map_err(wrap_candle_err)?;
+                 Ok(EmbeddingModelVariant::DistilBert(model))
+             },
+             EmbeddingModelType::MiniLM => {
+                 let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
+                 if !std::path::Path::new(&model_path).exists() {
+                     return Err(magnus::Error::new(
+                         magnus::exception::runtime_error(),
+                         "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
+                     ));
+                 }
+                 let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
+                 let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
+                 let mut config: BertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
+                 if let Some(embedding_size) = embedding_size {
+                     config.hidden_size = embedding_size;
+                 }
+                 let vb = unsafe {
+                     VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
+                         .map_err(wrap_candle_err)?
+                 };
+                 let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
+                 Ok(EmbeddingModelVariant::MiniLM(model))
+             },
+         }
+     }
+
+     fn build_tokenizer(tokenizer_path: String) -> RbResult<Tokenizer> {
+         use hf_hub::{api::sync::Api, Repo, RepoType};
+         let tokenizer_path = Api::new()
+             .map_err(wrap_hf_err)?
+             .repo(Repo::new(
+                 tokenizer_path,
+                 RepoType::Model,
+             ))
+             .get("tokenizer.json")
+             .map_err(wrap_hf_err)?;
+         let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
+             .map_err(wrap_std_err)?;
+         let pp = tokenizers::PaddingParams {
+             strategy: tokenizers::PaddingStrategy::BatchLongest,
+             ..Default::default()
+         };
+         tokenizer.with_padding(Some(pp));
+
+         Ok(tokenizer)
+     }
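
A minimal sketch of what a tokenizer built this way produces, assuming a `tokenizer.json` already on disk (the path and input string are illustrative):

```rust
use tokenizers::Tokenizer;

fn demo() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // `true` adds special tokens such as [CLS]/[SEP] for BERT-style vocabularies.
    let encoding = tokenizer.encode("Hello, world!", true)?;
    println!("ids: {:?}", encoding.get_ids());
    Ok(())
}
```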
280
+
281
+ /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
282
+ /// &RETURNS&: Tensor
283
+ fn pooled_cls_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
284
+ // 1) sanity-check that we have a 3D tensor
285
+ let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
286
+
287
+ // 2) slice out just the first token (CLS) along the sequence axis:
288
+ // [B, seq_len, H] → [B, 1, H]
289
+ let first = result
290
+ .narrow(1, 0, 1)
291
+ .map_err(wrap_candle_err)?;
292
+
293
+ // 3) remove that length-1 axis → [B, H]
294
+ let cls = first
295
+ .squeeze(1)
296
+ .map_err(wrap_candle_err)?;
297
+
298
+ Ok(cls)
299
+ }
+
+     fn pooled_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+         let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
+         let sum = result.sum(1).map_err(wrap_candle_err)?;
+         let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
+         Ok(mean)
+     }
+
+     fn pooled_normalized_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+         let mean = Self::pooled_embedding(result)?;
+         let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
+         Ok(norm)
+     }
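
All three pooling helpers reduce a `[batch, seq_len, hidden]` activation to `[batch, hidden]`: CLS pooling slices out the first token, mean pooling averages over the sequence axis. A toy shape check with candle (the dimensions are placeholders):

```rust
use candle_core::{DType, Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let hidden = Tensor::ones((1, 4, 8), DType::F32, &Device::Cpu)?; // [batch=1, seq_len=4, hidden=8]
    let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?; // CLS pooling: take token 0 -> [1, 8]
    let mean = (hidden.sum(1)? / 4f64)?;           // mean pooling: average the 4 tokens -> [1, 8]
    assert_eq!(cls.dims(), &[1, 8][..]);
    assert_eq!(mean.dims(), &[1, 8][..]);
    Ok(())
}
```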
+
+     fn compute_embeddings(
+         &self,
+         prompt: String,
+         model: &EmbeddingModelVariant,
+         tokenizer: &Tokenizer,
+     ) -> Result<CoreTensor, Error> {
+         let tokens = tokenizer
+             .encode(prompt, true)
+             .map_err(wrap_std_err)?
+             .get_ids()
+             .to_vec();
+         let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
+             .map_err(wrap_candle_err)?
+             .unsqueeze(0)
+             .map_err(wrap_candle_err)?;
+         let batch_size = token_ids.dims()[0];
+         let seq_len = token_ids.dims()[1];
+         let token_type_ids = CoreTensor::zeros(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
+             .map_err(wrap_candle_err)?;
+         let attention_mask = CoreTensor::ones(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
+             .map_err(wrap_candle_err)?;
+         match model {
+             EmbeddingModelVariant::JinaBert(model) => {
+                 model.forward(&token_ids).map_err(wrap_candle_err)
+             },
+             EmbeddingModelVariant::StandardBert(model) => {
+                 model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
+             },
+             EmbeddingModelVariant::DistilBert(model) => {
+                 model.forward(&token_ids, &attention_mask).map_err(wrap_candle_err)
+             },
+             EmbeddingModelVariant::MiniLM(model) => {
+                 model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
+             },
+         }
+     }
+
+     /// Computes an embedding for the prompt using the specified pooling method.
+     /// pooling_method: "pooled", "pooled_normalized", or "cls"
+     fn compute_embedding(
+         &self,
+         prompt: String,
+         model: &EmbeddingModelVariant,
+         tokenizer: &Tokenizer,
+         pooling_method: &str,
+     ) -> Result<CoreTensor, Error> {
+         let result = self.compute_embeddings(prompt, model, tokenizer)?;
+         match pooling_method {
+             "pooled" => Self::pooled_embedding(&result),
+             "pooled_normalized" => Self::pooled_normalized_embedding(&result),
+             "cls" => Self::pooled_cls_embedding(&result),
+             _ => Err(magnus::Error::new(magnus::exception::runtime_error(), "Unknown pooling method")),
+         }
+     }
+
+     // Used by pooled_normalized_embedding above, so no dead_code allowance is needed.
+     fn normalize_l2(v: &CoreTensor) -> Result<CoreTensor, candle_core::Error> {
+         v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
+     }
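
`normalize_l2` divides each row by its Euclidean norm (`sqrt(sum(x^2))` along dim 1), putting pooled vectors on the unit sphere so dot products behave like cosine similarity. A small verification sketch using the same candle ops (the values are illustrative):

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let v = Tensor::new(&[[3f32, 4f32]], &Device::Cpu)?; // row norm = 5
    let unit = v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    // The normalized row should be [0.6, 0.8], which has norm 1.0.
    println!("{:?}", unit.to_vec2::<f32>()?);
    Ok(())
}
```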
+
+     pub fn embedding_model_type(&self) -> String {
+         match self.0.embedding_model_type {
+             Some(model_type) => format!("{:?}", model_type),
+             None => "nil".to_string(),
+         }
+     }
+
+     pub fn __repr__(&self) -> String {
+         format!(
+             "#<Candle::EmbeddingModel embedding_model_type: {}, model_path: {}, tokenizer_path: {}, embedding_size: {}>",
+             self.embedding_model_type(),
+             self.0.model_path.as_deref().unwrap_or("nil"),
+             self.0.tokenizer_path.as_deref().unwrap_or("nil"),
+             self.0.embedding_size.map(|x| x.to_string()).unwrap_or_else(|| "nil".to_string()),
+         )
+     }
+
+     pub fn __str__(&self) -> String {
+         self.__repr__()
+     }
+ }
+
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
+     let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
+     rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
+     // Expose embedding with an optional pooling_method argument (default: "pooled")
+     rb_embedding_model.define_method("_embedding", method!(EmbeddingModel::embedding, 2))?;
+     rb_embedding_model.define_method("embeddings", method!(EmbeddingModel::embeddings, 1))?;
+     rb_embedding_model.define_method("pool_embedding", method!(EmbeddingModel::pool_embedding, 1))?;
+     rb_embedding_model.define_method("pool_and_normalize_embedding", method!(EmbeddingModel::pool_and_normalize_embedding, 1))?;
+     rb_embedding_model.define_method("pool_cls_embedding", method!(EmbeddingModel::pool_cls_embedding, 1))?;
+     rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
+     rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
+     rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
+     Ok(())
+ }
@@ -0,0 +1,13 @@
+ use magnus::Error;
+
+ pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
+     Error::new(magnus::exception::runtime_error(), err.to_string())
+ }
+
+ pub fn wrap_candle_err(err: candle_core::Error) -> Error {
+     Error::new(magnus::exception::runtime_error(), err.to_string())
+ }
+
+ pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
+     Error::new(magnus::exception::runtime_error(), err.to_string())
+ }
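
These wrappers all surface as Ruby `RuntimeError`s carrying the original message, which keeps the binding layer uniform. A hedged sketch of the call-site pattern used throughout the model code above (the function here is hypothetical):

```rust
// std errors are boxed to match wrap_std_err's signature, as at the call sites above.
fn read_config(path: &str) -> Result<String, magnus::Error> {
    std::fs::read_to_string(path).map_err(|e| wrap_std_err(Box::new(e)))
}
```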