red-candle 1.3.0 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- use magnus::{class, function, method, prelude::*, Error, RModule, Float, RArray};
1
+ use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
2
2
  use candle_transformers::models::bert::{BertModel, Config};
3
3
  use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
4
4
  use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
@@ -25,6 +25,9 @@ impl Reranker {
25
25
  }
26
26
 
27
27
  fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
28
+ let ruby = Ruby::get().unwrap();
29
+ let runtime_error = ruby.exception_runtime_error();
30
+
28
31
  let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
29
32
  let api = Api::new()?;
30
33
  let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
@@ -64,50 +67,56 @@ impl Reranker {
64
67
  Ok((model, tokenizer, pooler, classifier)) => {
65
68
  Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
66
69
  }
67
- Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
70
+ Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
68
71
  }
69
72
  }
70
73
 
71
74
  /// Extract CLS embeddings from the model output, handling Metal device workarounds
72
75
  fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
76
+ let ruby = Ruby::get().unwrap();
77
+ let runtime_error = ruby.exception_runtime_error();
78
+
73
79
  let cls_embeddings = if self.device.is_metal() {
74
80
  // Metal has issues with tensor indexing, use a different approach
75
81
  let (batch_size, seq_len, hidden_size) = embeddings.dims3()
76
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
82
+ .map_err(|e| Error::new(runtime_error, format!("Failed to get dims: {}", e)))?;
77
83
 
78
84
  // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
79
85
  let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
80
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
86
+ .map_err(|e| Error::new(runtime_error, format!("Failed to reshape: {}", e)))?;
81
87
 
82
88
  // Extract CLS tokens (first token of each sequence)
83
89
  let mut cls_vecs = Vec::new();
84
90
  for i in 0..batch_size {
85
91
  let start_idx = i * seq_len;
86
92
  let cls_vec = reshaped.narrow(0, start_idx, 1)
87
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
93
+ .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?;
88
94
  cls_vecs.push(cls_vec);
89
95
  }
90
96
 
91
97
  // Stack the CLS vectors
92
98
  Tensor::cat(&cls_vecs, 0)
93
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
99
+ .map_err(|e| Error::new(runtime_error, format!("Failed to cat CLS tokens: {}", e)))?
94
100
  } else {
95
101
  embeddings.i((.., 0))
96
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
102
+ .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS token: {}", e)))?
97
103
  };
98
104
 
99
105
  // Ensure tensor is contiguous for downstream operations
100
106
  cls_embeddings.contiguous()
101
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
107
+ .map_err(|e| Error::new(runtime_error, format!("Failed to make CLS embeddings contiguous: {}", e)))
102
108
  }
103
109
 
104
- pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
110
+ pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
111
+ let ruby = Ruby::get().unwrap();
112
+ let runtime_error = ruby.exception_runtime_error();
113
+
105
114
  // Create query-document pair for cross-encoder
106
115
  let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
107
116
 
108
117
  // Tokenize using the inner tokenizer for detailed info
109
118
  let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
110
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
119
+ .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
111
120
 
112
121
  // Get token information
113
122
  let token_ids = encoding.get_ids().to_vec();
@@ -116,16 +125,18 @@ impl Reranker {
116
125
  let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
117
126
 
118
127
  // Create result hash
119
- let result = magnus::RHash::new();
120
- result.aset("token_ids", RArray::from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
121
- result.aset("token_type_ids", RArray::from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
122
- result.aset("attention_mask", RArray::from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
123
- result.aset("tokens", RArray::from_vec(tokens))?;
128
+ let result = ruby.hash_new();
129
+ result.aset("token_ids", ruby.ary_from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
130
+ result.aset("token_type_ids", ruby.ary_from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
131
+ result.aset("attention_mask", ruby.ary_from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
132
+ result.aset("tokens", ruby.ary_from_vec(tokens))?;
124
133
 
125
134
  Ok(result)
126
135
  }
127
136
 
128
137
  pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
138
+ let ruby = Ruby::get().unwrap();
139
+ let runtime_error = ruby.exception_runtime_error();
129
140
  let documents: Vec<String> = documents.to_vec()?;
130
141
 
131
142
  // Create query-document pairs for cross-encoder
@@ -136,7 +147,7 @@ impl Reranker {
136
147
 
137
148
  // Tokenize batch using inner tokenizer for access to token type IDs
138
149
  let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
139
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
150
+ .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
140
151
 
141
152
  // Convert to tensors
142
153
  let token_ids = encodings
@@ -150,15 +161,15 @@ impl Reranker {
150
161
  .collect::<Vec<_>>();
151
162
 
152
163
  let token_ids = Tensor::new(token_ids, &self.device)
153
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create tensor: {}", e)))?;
164
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
154
165
  let token_type_ids = Tensor::new(token_type_ids, &self.device)
155
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create token type ids tensor: {}", e)))?;
166
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
156
167
  let attention_mask = token_ids.ne(0u32)
157
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create attention mask: {}", e)))?;
168
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
158
169
 
159
170
  // Forward pass through BERT
160
171
  let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
161
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Model forward pass failed: {}", e)))?;
172
+ .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
162
173
 
163
174
  // Apply pooling based on the specified method
164
175
  let pooled_embeddings = match pooling_method.as_str() {
@@ -166,9 +177,9 @@ impl Reranker {
166
177
  // Extract [CLS] token and apply pooler (dense + tanh)
167
178
  let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
168
179
  let pooled = self.pooler.forward(&cls_embeddings)
169
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
180
+ .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
170
181
  pooled.tanh()
171
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tanh activation failed: {}", e)))?
182
+ .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
172
183
  },
173
184
  "cls" => {
174
185
  // Just use the [CLS] token embeddings directly (no pooler layer)
@@ -177,35 +188,35 @@ impl Reranker {
177
188
  "mean" => {
178
189
  // Mean pooling across all tokens
179
190
  let (_batch, seq_len, _hidden) = embeddings.dims3()
180
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get tensor dimensions: {}", e)))?;
191
+ .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
181
192
  let sum = embeddings.sum(1)
182
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to sum embeddings: {}", e)))?;
193
+ .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
183
194
  (sum / (seq_len as f64))
184
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to compute mean: {}", e)))?
195
+ .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
185
196
  },
186
- _ => return Err(Error::new(magnus::exception::runtime_error(),
197
+ _ => return Err(Error::new(runtime_error,
187
198
  format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
188
199
  };
189
200
 
190
201
  // Apply classifier to get relevance scores (raw logits)
191
202
  // Ensure tensor is contiguous before linear layer
192
203
  let pooled_embeddings = pooled_embeddings.contiguous()
193
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
204
+ .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
194
205
  let logits = self.classifier.forward(&pooled_embeddings)
195
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Classifier forward failed: {}", e)))?;
206
+ .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
196
207
  let scores = logits.squeeze(1)
197
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to squeeze tensor: {}", e)))?;
208
+ .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
198
209
 
199
210
  // Optionally apply sigmoid activation
200
211
  let scores = if apply_sigmoid {
201
212
  sigmoid(&scores)
202
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Sigmoid failed: {}", e)))?
213
+ .map_err(|e| Error::new(runtime_error, format!("Sigmoid failed: {}", e)))?
203
214
  } else {
204
215
  scores
205
216
  };
206
217
 
207
218
  let scores_vec: Vec<f32> = scores.to_vec1()
208
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to convert scores to vec: {}", e)))?;
219
+ .map_err(|e| Error::new(runtime_error, format!("Failed to convert scores to vec: {}", e)))?;
209
220
 
210
221
  // Create tuples with document, score, and original index
211
222
  let mut ranked_docs: Vec<(String, f32, usize)> = documents
@@ -219,11 +230,11 @@ impl Reranker {
219
230
  ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
220
231
 
221
232
  // Build result array with [doc, score, doc_id]
222
- let result_array = RArray::new();
233
+ let result_array = ruby.ary_new();
223
234
  for (doc, score, doc_id) in ranked_docs {
224
- let tuple = RArray::new();
235
+ let tuple = ruby.ary_new();
225
236
  tuple.push(doc)?;
226
- tuple.push(Float::from_f64(score as f64))?;
237
+ tuple.push(ruby.float_from_f64(score as f64))?;
227
238
  tuple.push(doc_id)?;
228
239
  result_array.push(tuple)?;
229
240
  }
@@ -246,8 +257,9 @@ impl Reranker {
246
257
  }
247
258
 
248
259
  /// Get all options as a hash
249
- pub fn options(&self) -> std::result::Result<magnus::RHash, Error> {
250
- let hash = magnus::RHash::new();
260
+ pub fn options(&self) -> std::result::Result<RHash, Error> {
261
+ let ruby = Ruby::get().unwrap();
262
+ let hash = ruby.hash_new();
251
263
  hash.aset("model_id", self.model_id.clone())?;
252
264
  hash.aset("device", self.device().__str__())?;
253
265
  Ok(hash)
@@ -255,7 +267,8 @@ impl Reranker {
255
267
  }
256
268
 
257
269
  pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
258
- let c_reranker = rb_candle.define_class("Reranker", class::object())?;
270
+ let ruby = Ruby::get().unwrap();
271
+ let c_reranker = rb_candle.define_class("Reranker", ruby.class_object())?;
259
272
  c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
260
273
  c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
261
274
  c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
@@ -264,4 +277,4 @@ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
264
277
  c_reranker.define_method("device", method!(Reranker::device, 0))?;
265
278
  c_reranker.define_method("options", method!(Reranker::options, 0))?;
266
279
  Ok(())
267
- }
280
+ }
@@ -1,7 +1,7 @@
1
- use magnus::{Error, Module, RModule, function, Object};
1
+ use magnus::{Error, Module, RModule, function, Object, Ruby};
2
2
  use std::sync::Arc;
3
3
 
4
- use crate::structured::{SchemaProcessor, VocabularyAdapter, Index};
4
+ use crate::structured::{SchemaProcessor, VocabularyAdapter, Index, Vocabulary};
5
5
  use crate::ruby::{Result, tokenizer::Tokenizer};
6
6
 
7
7
  /// Ruby wrapper for structured generation constraints
@@ -12,36 +12,81 @@ pub struct StructuredConstraint {
12
12
  }
13
13
 
14
14
  impl StructuredConstraint {
15
- /// Create a constraint from a JSON schema
15
+ /// Create a constraint from a JSON schema using a model ID
16
+ /// This uses Vocabulary::from_pretrained which handles tokenizer byte encoding correctly
17
+ pub fn from_schema_with_model(schema: String, model_id: String) -> Result<Self> {
18
+ // Use tokio runtime for async vocabulary loading
19
+ let rt = tokio::runtime::Runtime::new()
20
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
21
+
22
+ let vocabulary = rt.block_on(async {
23
+ Vocabulary::from_pretrained(&model_id, None)
24
+ })
25
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
26
+
27
+ let processor = SchemaProcessor::new();
28
+ let index = processor.process_schema(&schema, &vocabulary)
29
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
30
+
31
+ Ok(Self { index })
32
+ }
33
+
34
+ /// Create a constraint from a regex pattern using a model ID
35
+ pub fn from_regex_with_model(pattern: String, model_id: String) -> Result<Self> {
36
+ // Use tokio runtime for async vocabulary loading
37
+ let rt = tokio::runtime::Runtime::new()
38
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
39
+
40
+ let vocabulary = rt.block_on(async {
41
+ Vocabulary::from_pretrained(&model_id, None)
42
+ })
43
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
44
+
45
+ let processor = SchemaProcessor::new();
46
+ let index = processor.process_regex(&pattern, &vocabulary)
47
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
48
+
49
+ Ok(Self { index })
50
+ }
51
+
52
+ /// Create a constraint from a JSON schema (legacy method using tokenizer directly)
53
+ /// Note: This may not handle all tokenizer byte encodings correctly
16
54
  pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
17
55
  let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
18
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
19
-
56
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
57
+
20
58
  let processor = SchemaProcessor::new();
21
59
  let index = processor.process_schema(&schema, &vocabulary)
22
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process schema: {}", e)))?;
23
-
60
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
61
+
24
62
  Ok(Self { index })
25
63
  }
26
-
27
- /// Create a constraint from a regex pattern
64
+
65
+ /// Create a constraint from a regex pattern (legacy method using tokenizer directly)
66
+ /// Note: This may not handle all tokenizer byte encodings correctly
28
67
  pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
29
68
  let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
30
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
31
-
69
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
70
+
32
71
  let processor = SchemaProcessor::new();
33
72
  let index = processor.process_regex(&pattern, &vocabulary)
34
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process regex: {}", e)))?;
35
-
73
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
74
+
36
75
  Ok(Self { index })
37
76
  }
38
77
  }
39
78
 
40
79
  pub fn init_structured(rb_candle: RModule) -> Result<()> {
41
- let class = rb_candle.define_class("StructuredConstraint", magnus::class::object())?;
42
-
80
+ let ruby = Ruby::get().unwrap();
81
+ let class = rb_candle.define_class("StructuredConstraint", ruby.class_object())?;
82
+
83
+ // New methods using model_id for proper vocabulary loading
84
+ class.define_singleton_method("from_schema_with_model", function!(StructuredConstraint::from_schema_with_model, 2))?;
85
+ class.define_singleton_method("from_regex_with_model", function!(StructuredConstraint::from_regex_with_model, 2))?;
86
+
87
+ // Legacy methods using tokenizer directly (may have byte encoding issues with some models)
43
88
  class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
44
89
  class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
45
-
90
+
46
91
  Ok(())
47
92
  }
@@ -1,5 +1,5 @@
1
1
  use magnus::prelude::*;
2
- use magnus::{function, method, class, RModule, Module, Object};
2
+ use magnus::{function, method, RModule, Module, Object, Ruby};
3
3
 
4
4
  use crate::ruby::{
5
5
  errors::wrap_candle_err,
@@ -84,7 +84,7 @@ impl Tensor {
84
84
  Ok(f.to_f64() as i64)
85
85
  } else {
86
86
  Err(magnus::Error::new(
87
- magnus::exception::type_error(),
87
+ Ruby::get().unwrap().exception_type_error(),
88
88
  "Cannot convert to i64"
89
89
  ))
90
90
  }
@@ -143,7 +143,7 @@ impl Tensor {
143
143
  Ok(values)
144
144
  }
145
145
  _ => Err(magnus::Error::new(
146
- magnus::exception::runtime_error(),
146
+ Ruby::get().unwrap().exception_runtime_error(),
147
147
  "Tensor must be F32 dtype for values_f32",
148
148
  )),
149
149
  }
@@ -153,7 +153,7 @@ impl Tensor {
153
153
  pub fn item(&self) -> Result<f64> {
154
154
  if self.0.rank() != 0 {
155
155
  return Err(magnus::Error::new(
156
- magnus::exception::runtime_error(),
156
+ Ruby::get().unwrap().exception_runtime_error(),
157
157
  format!("item() can only be called on scalar tensors (rank 0), but tensor has rank {}", self.0.rank()),
158
158
  ));
159
159
  }
@@ -384,7 +384,7 @@ impl Tensor {
384
384
  let scalar = CoreTensor::from_vec(vec![i as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
385
385
  Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
386
386
  } else {
387
- Err(magnus::Error::new(magnus::exception::type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
387
+ Err(magnus::Error::new(Ruby::get().unwrap().exception_type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
388
388
  }
389
389
  }
390
390
 
@@ -650,7 +650,8 @@ impl Tensor {
650
650
  }
651
651
 
652
652
  pub fn init(rb_candle: RModule) -> Result<()> {
653
- let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
653
+ let ruby = Ruby::get().unwrap();
654
+ let rb_tensor = rb_candle.define_class("Tensor", ruby.class_object())?;
654
655
  rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
655
656
  // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
656
657
  // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;