red-candle 1.8.0-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,284 @@
1
+ use candle_core::{Result as CandleResult, Tensor};
2
+ use candle_transformers::generation::LogitsProcessor;
3
+ use std::sync::Arc;
4
+
5
+ use super::GenerationConfig;
6
+ use crate::structured::Index;
7
+
8
+ /// Helper struct for text generation process
9
+ pub struct TextGeneration {
10
+ logits_processor: LogitsProcessor,
11
+ tokens: Vec<u32>,
12
+ eos_token_id: Option<u32>,
13
+ repetition_penalty: f32,
14
+ repetition_penalty_last_n: usize,
15
+ constraint: Option<Arc<Index>>,
16
+ constraint_state: Option<u32>,
17
+ constraint_completed: bool,
18
+ tokens_since_constraint_start: usize,
19
+ }
20
+
21
+ impl TextGeneration {
22
+ pub fn new(config: &GenerationConfig) -> Self {
23
+ let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
24
+
25
+ let mut text_gen = Self {
26
+ logits_processor,
27
+ tokens: Vec::new(),
28
+ eos_token_id: None,
29
+ repetition_penalty: config.repetition_penalty,
30
+ repetition_penalty_last_n: config.repetition_penalty_last_n,
31
+ constraint: None,
32
+ constraint_state: None,
33
+ constraint_completed: false,
34
+ tokens_since_constraint_start: 0,
35
+ };
36
+
37
+ // Set constraint if provided
38
+ if let Some(ref constraint) = config.constraint {
39
+ text_gen.set_constraint(Arc::clone(constraint));
40
+ }
41
+
42
+ text_gen
43
+ }
44
+
45
+ pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
46
+ self.eos_token_id = Some(eos_token_id);
47
+ }
48
+
49
+ pub fn set_tokens(&mut self, tokens: Vec<u32>) {
50
+ self.tokens = tokens;
51
+ }
52
+
53
+ pub fn get_tokens(&self) -> &[u32] {
54
+ &self.tokens
55
+ }
56
+
57
+ pub fn push_token(&mut self, token: u32) {
58
+ self.tokens.push(token);
59
+ }
60
+
61
+ pub fn set_constraint(&mut self, constraint: Arc<Index>) {
62
+ // Initialize with the first state
63
+ self.constraint_state = Some(constraint.initial_state());
64
+ self.constraint = Some(constraint);
65
+ self.constraint_completed = false;
66
+ self.tokens_since_constraint_start = self.tokens.len();
67
+ }
68
+
69
+ /// Apply constraints to logits by masking disallowed tokens
70
+ fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
71
+ if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
72
+ let device = logits.device();
73
+ let vocab_size = logits.dims1()?;
74
+
75
+ // Get allowed tokens from the constraint index for current state
76
+ if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
77
+ // Create a mask where allowed tokens have value 0 and others have -inf
78
+ let mut mask = vec![f32::NEG_INFINITY; vocab_size];
79
+ for &token_id in &allowed_tokens {
80
+ if (token_id as usize) < vocab_size {
81
+ mask[token_id as usize] = 0.0;
82
+ }
83
+ }
84
+
85
+ // Apply mask to logits
86
+ let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
87
+ *logits = logits.add(&mask_tensor)?;
88
+ }
89
+ }
90
+ Ok(())
91
+ }
92
+
93
+ /// Apply repetition penalty to logits
94
+ pub fn apply_repetition_penalty(
95
+ &self,
96
+ logits: &mut Tensor,
97
+ penalty: f32,
98
+ context_size: usize,
99
+ ) -> CandleResult<()> {
100
+ if penalty == 1.0 {
101
+ return Ok(());
102
+ }
103
+
104
+ let device = logits.device();
105
+ let vocab_size = logits.dims1()?;
106
+
107
+ // Get the context tokens to apply penalty to
108
+ let start = self.tokens.len().saturating_sub(context_size);
109
+ let context_tokens = &self.tokens[start..];
110
+
111
+ // Apply penalty to tokens that appear in the context
112
+ let mut logits_vec = logits.to_vec1::<f32>()?;
113
+ for &token in context_tokens {
114
+ if (token as usize) < vocab_size {
115
+ let idx = token as usize;
116
+ if logits_vec[idx] > 0.0 {
117
+ logits_vec[idx] /= penalty;
118
+ } else {
119
+ logits_vec[idx] *= penalty;
120
+ }
121
+ }
122
+ }
123
+
124
+ *logits = Tensor::from_vec(logits_vec, vocab_size, device)?;
125
+ Ok(())
126
+ }
127
+
128
+ /// Sample next token from logits
129
+ pub fn sample_next_token(
130
+ &mut self,
131
+ logits: &Tensor,
132
+ ) -> CandleResult<u32> {
133
+ let mut logits = logits.clone();
134
+
135
+ // Apply repetition penalty using stored parameters
136
+ if self.repetition_penalty != 1.0 {
137
+ self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
138
+ }
139
+
140
+ // Apply constraints if active
141
+ self.apply_constraints(&mut logits)?;
142
+
143
+ // Sample token
144
+ let next_token = self.logits_processor.sample(&logits)?;
145
+ self.tokens.push(next_token);
146
+
147
+ // Update constraint state if active
148
+ if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
149
+ // Get the next state
150
+ let next_state = constraint_index.next_state(&current_state, &next_token);
151
+
152
+ // Check if we're transitioning to a state with no allowed tokens (completion)
153
+ if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
154
+ // Check if next state has no allowed tokens at all - this is definitive completion
155
+ if let Some(next_state_val) = next_state {
156
+ if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
157
+ if allowed.is_empty() {
158
+ self.constraint_completed = true;
159
+ }
160
+ // Only mark as complete if ONLY EOS is allowed (not just if EOS is one of many options)
161
+ else if let Some(eos) = self.eos_token_id {
162
+ if allowed.len() == 1 && allowed.contains(&eos) {
163
+ self.constraint_completed = true;
164
+ }
165
+ }
166
+ } else {
167
+ // None means no tokens allowed - constraint is complete
168
+ self.constraint_completed = true;
169
+ }
170
+ }
171
+ }
172
+
173
+ self.constraint_state = next_state;
174
+ }
175
+
176
+ Ok(next_token)
177
+ }
178
+
179
+ /// Check if the constraint is satisfied (reached a valid completion state)
180
+ pub fn is_constraint_satisfied(&self) -> bool {
181
+ // If we've explicitly marked the constraint as completed, return true
182
+ if self.constraint_completed {
183
+ return true;
184
+ }
185
+
186
+ // Also check the current state
187
+ if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
188
+ // Check if the constraint has reached a state where it MUST end
189
+ // This happens when there are no more allowed tokens (constraint fully satisfied)
190
+ if let Some(allowed) = constraint_index.allowed_tokens(&state) {
191
+ // If no tokens are allowed, the constraint is fully satisfied
192
+ if allowed.is_empty() {
193
+ return true;
194
+ }
195
+
196
+ // For JSON schemas, check if ONLY the EOS token is allowed
197
+ // This means we've generated a complete, valid JSON structure
198
+ // Don't treat EOS as a satisfaction signal if other tokens are also allowed
199
+ if let Some(eos) = self.eos_token_id {
200
+ if allowed.len() == 1 && allowed.contains(&eos) {
201
+ return true;
202
+ }
203
+ }
204
+ } else {
205
+ // None means no tokens allowed - constraint is satisfied
206
+ return true;
207
+ }
208
+ }
209
+ false
210
+ }
211
+
212
+ /// Check if the constraint is satisfied when stop_on_match is true
213
+ /// NOTE: For JSON schemas, this should only return true when the JSON structure is complete,
214
+ /// not just because we're in a state with many allowed tokens (like inside a string).
215
+ pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
216
+ // When stop_on_match is true, we stop as soon as the constraint is completed
217
+ if self.constraint_completed {
218
+ return true;
219
+ }
220
+
221
+ // For JSON and other structured outputs, don't use the "large allowed set" heuristic.
222
+ // Instead, only consider the constraint satisfied when:
223
+ // 1. There are no allowed tokens (definitive completion)
224
+ // 2. Only EOS is allowed (completion with optional termination)
225
+ if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
226
+ if let Some(allowed) = constraint_index.allowed_tokens(&state) {
227
+ // No more tokens allowed - definitely complete
228
+ if allowed.is_empty() {
229
+ return true;
230
+ }
231
+
232
+ // Only EOS is allowed - complete JSON structure
233
+ if let Some(eos) = self.eos_token_id {
234
+ if allowed.len() == 1 && allowed.contains(&eos) {
235
+ return true;
236
+ }
237
+ }
238
+ } else {
239
+ // None means no tokens allowed - constraint is complete
240
+ return true;
241
+ }
242
+ }
243
+
244
+ false
245
+ }
246
+
247
+ /// Check if we should stop generation
248
+ pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
249
+ if self.tokens.len() >= max_length {
250
+ return true;
251
+ }
252
+
253
+ if let Some(eos) = self.eos_token_id {
254
+ if token == eos {
255
+ return true;
256
+ }
257
+ }
258
+
259
+ // Check if we've reached a final state in constraint
260
+ // A state is considered final if it has no allowed tokens
261
+ if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
262
+ if let Some(allowed) = constraint_index.allowed_tokens(&state) {
263
+ if allowed.is_empty() {
264
+ return true;
265
+ }
266
+ } else {
267
+ // None means no tokens allowed - we're done
268
+ return true;
269
+ }
270
+ }
271
+
272
+ false
273
+ }
274
+
275
+ /// Check if the generated text ends with any stop sequence
276
+ pub fn check_stop_sequences(&self, text: &str, stop_sequences: &[String]) -> bool {
277
+ for seq in stop_sequences {
278
+ if text.ends_with(seq) {
279
+ return true;
280
+ }
281
+ }
282
+ false
283
+ }
284
+ }
@@ -0,0 +1,234 @@
1
+ use magnus::Error;
2
+ use magnus::{function, method, RModule, Module, Object, Ruby};
3
+
4
+ use ::candle_core::Device as CoreDevice;
5
+ use crate::ruby::Result;
6
+
7
+ #[cfg(any(feature = "cuda", feature = "metal"))]
8
+ use crate::ruby::errors::wrap_candle_err;
9
+
10
/// Process-wide cache of the CUDA device (ordinal 0) that `Device::as_device` creates on
/// first use. Later conversions clone the cached handle instead of calling
/// `CoreDevice::new_cuda(0)` again.
#[cfg(feature = "cuda")]
static CUDA_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);

/// Process-wide cache of the Metal device (ordinal 0), filled the same way by
/// `Device::as_device` via `CoreDevice::new_metal(0)`.
#[cfg(feature = "metal")]
static METAL_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);
15
+
16
/// Lists the device names this build was compiled to support.
///
/// `"cpu"` is always present. `"cuda"` and `"metal"` follow, in that order, when their
/// Cargo feature is enabled and the `force_cpu` cfg is not set. No runtime probing is
/// done here.
pub fn available_devices() -> Vec<String> {
    let candidates = [
        ("cpu", true),
        ("cuda", cfg!(all(feature = "cuda", not(force_cpu)))),
        ("metal", cfg!(all(feature = "metal", not(force_cpu)))),
    ];

    candidates
        .iter()
        .filter(|(_, compiled_in)| *compiled_in)
        .map(|(name, _)| name.to_string())
        .collect()
}
36
+
37
+ /// Get the default device based on what's available
38
+ pub fn default_device() -> Device {
39
+ // Return based on compiled features, not detection
40
+ #[cfg(all(feature = "metal", not(force_cpu)))]
41
+ {
42
+ Device::Metal
43
+ }
44
+
45
+ #[cfg(all(feature = "cuda", not(feature = "metal"), not(force_cpu)))]
46
+ {
47
+ Device::Cuda
48
+ }
49
+
50
+ #[cfg(not(any(all(feature = "metal", not(force_cpu)), all(feature = "cuda", not(feature = "metal"), not(force_cpu)))))]
51
+ {
52
+ Device::Cpu
53
+ }
54
+ }
55
+
56
+ /// Get the best available device by checking runtime availability
57
+ pub fn best_device() -> Device {
58
+ // Try devices in order of preference
59
+
60
+ #[cfg(feature = "metal")]
61
+ {
62
+ // Check if Metal is actually available at runtime
63
+ if CoreDevice::new_metal(0).is_ok() {
64
+ return Device::Metal;
65
+ }
66
+ }
67
+
68
+ #[cfg(feature = "cuda")]
69
+ {
70
+ // Check if CUDA is actually available at runtime
71
+ if CoreDevice::new_cuda(0).is_ok() {
72
+ return Device::Cuda;
73
+ }
74
+ }
75
+
76
+ // Always fall back to CPU
77
+ Device::Cpu
78
+ }
79
+
80
/// Ruby-visible device selector, wrapped as `Candle::Device`.
///
/// This is only a tag and carries no device handle. `Device::as_device` turns it into a
/// real `candle_core::Device`, and that is where unavailable backends are reported as
/// errors.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[magnus::wrap(class = "Candle::Device")]
pub enum Device {
    /// Host CPU.
    Cpu,
    /// CUDA GPU (`as_device` materializes ordinal 0).
    Cuda,
    /// Apple Metal GPU (`as_device` materializes ordinal 0).
    Metal,
}
87
+
88
+ impl Device {
89
+ /// Create a CPU device
90
+ pub fn cpu() -> Self {
91
+ Self::Cpu
92
+ }
93
+
94
+ /// Get the best available device
95
+ pub fn best() -> Self {
96
+ best_device()
97
+ }
98
+
99
+ /// Create a CUDA device (GPU)
100
+ pub fn cuda() -> Result<Self> {
101
+ #[cfg(not(feature = "cuda"))]
102
+ {
103
+ return Err(Error::new(
104
+ Ruby::get().unwrap().exception_runtime_error(),
105
+ "CUDA support not compiled in. Rebuild with CUDA available.",
106
+ ));
107
+ }
108
+
109
+ #[cfg(feature = "cuda")]
110
+ Ok(Self::Cuda)
111
+ }
112
+
113
+ /// Create a Metal device (Apple GPU)
114
+ pub fn metal() -> Result<Self> {
115
+ #[cfg(not(feature = "metal"))]
116
+ {
117
+ return Err(Error::new(
118
+ Ruby::get().unwrap().exception_runtime_error(),
119
+ "Metal support not compiled in. Rebuild on macOS.",
120
+ ));
121
+ }
122
+
123
+ #[cfg(feature = "metal")]
124
+ Ok(Self::Metal)
125
+ }
126
+
127
+ pub fn from_device(device: &CoreDevice) -> Self {
128
+ match device {
129
+ CoreDevice::Cpu => Self::Cpu,
130
+ CoreDevice::Cuda(_) => Self::Cuda,
131
+ CoreDevice::Metal(_) => Self::Metal,
132
+ }
133
+ }
134
+
135
+ pub fn as_device(&self) -> Result<CoreDevice> {
136
+ match self {
137
+ Self::Cpu => Ok(CoreDevice::Cpu),
138
+ Self::Cuda => {
139
+ #[cfg(not(feature = "cuda"))]
140
+ {
141
+ return Err(Error::new(
142
+ Ruby::get().unwrap().exception_runtime_error(),
143
+ "CUDA support not compiled in. Rebuild with CUDA available.",
144
+ ));
145
+ }
146
+
147
+ #[cfg(feature = "cuda")]
148
+ {
149
+ let mut device = CUDA_DEVICE.lock().unwrap();
150
+ if let Some(device) = device.as_ref() {
151
+ return Ok(device.clone());
152
+ };
153
+ // Note: new_cuda() is used here (not cuda_if_available) because
154
+ // we want to fail if CUDA isn't available at runtime, not fall back to CPU
155
+ let d = CoreDevice::new_cuda(0).map_err(wrap_candle_err)?;
156
+ *device = Some(d.clone());
157
+ Ok(d)
158
+ }
159
+ }
160
+ Self::Metal => {
161
+ #[cfg(not(feature = "metal"))]
162
+ {
163
+ return Err(Error::new(
164
+ Ruby::get().unwrap().exception_runtime_error(),
165
+ "Metal support not compiled in. Rebuild on macOS.",
166
+ ));
167
+ }
168
+
169
+ #[cfg(feature = "metal")]
170
+ {
171
+ let mut device = METAL_DEVICE.lock().unwrap();
172
+ if let Some(device) = device.as_ref() {
173
+ return Ok(device.clone());
174
+ };
175
+ let d = CoreDevice::new_metal(0).map_err(wrap_candle_err)?;
176
+ *device = Some(d.clone());
177
+ Ok(d)
178
+ }
179
+ }
180
+ }
181
+ }
182
+
183
+ pub fn __repr__(&self) -> String {
184
+ match self {
185
+ Self::Cpu => "cpu".to_string(),
186
+ Self::Cuda => "cuda".to_string(),
187
+ Self::Metal => "metal".to_string(),
188
+ }
189
+ }
190
+
191
+ pub fn __str__(&self) -> String {
192
+ self.__repr__()
193
+ }
194
+
195
+ pub fn __eq__(&self, other: &Device) -> bool {
196
+ self == other
197
+ }
198
+ }
199
+
200
+ impl magnus::TryConvert for Device {
201
+ fn try_convert(val: magnus::Value) -> Result<Self> {
202
+ // First check if it's already a wrapped Device object
203
+ if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
204
+ return Ok(*device);
205
+ }
206
+
207
+ // Otherwise try to convert from string
208
+ let device = magnus::RString::try_convert(val)?;
209
+ let device = unsafe { device.as_str() }.unwrap();
210
+ let device = match device {
211
+ "cpu" => Device::Cpu,
212
+ "cuda" => Device::Cuda,
213
+ "metal" => Device::Metal,
214
+ _ => return Err(Error::new(Ruby::get().unwrap().exception_arg_error(), "invalid device")),
215
+ };
216
+ Ok(device)
217
+ }
218
+ }
219
+
220
+ pub fn init(rb_candle: RModule) -> Result<()> {
221
+ let ruby = Ruby::get().unwrap();
222
+ let rb_device = rb_candle.define_class("Device", ruby.class_object())?;
223
+ rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
224
+ rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
225
+ rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
226
+ rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
227
+ rb_device.define_singleton_method("default", function!(default_device, 0))?;
228
+ rb_device.define_singleton_method("best", function!(best_device, 0))?;
229
+ rb_device.define_method("to_s", method!(Device::__str__, 0))?;
230
+ rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
231
+ rb_device.define_method("==", method!(Device::__eq__, 1))?;
232
+ Ok(())
233
+ }
234
+
@@ -0,0 +1,39 @@
1
+ use magnus::value::ReprValue;
2
+ use magnus::{method, RModule, Module, Ruby};
3
+
4
+ use ::candle_core::DType as CoreDType;
5
+ use crate::ruby::Result;
6
+
7
/// A `candle` dtype, exposed to Ruby as `Candle::DType`.
///
/// Thin newtype over `candle_core::DType`; the wrapped value is public.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[magnus::wrap(class = "Candle::DType", free_immediately, size)]
pub struct DType(pub CoreDType);
12
+
13
+ impl DType {
14
+ pub fn __repr__(&self) -> String {
15
+ format!("{:?}", self.0)
16
+ }
17
+
18
+ pub fn __str__(&self) -> String {
19
+ self.__repr__()
20
+ }
21
+ }
22
+
23
+ impl DType {
24
+ pub fn from_rbobject(dtype: magnus::Symbol) -> Result<Self> {
25
+ let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
26
+ use std::str::FromStr;
27
+ let dtype = CoreDType::from_str(&dtype).unwrap();
28
+ Ok(Self(dtype))
29
+ }
30
+ }
31
+
32
+ pub fn init(rb_candle: RModule) -> Result<()> {
33
+ let ruby = Ruby::get().unwrap();
34
+ let rb_dtype = rb_candle.define_class("DType", ruby.class_object())?;
35
+ rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
36
+ rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
37
+ Ok(())
38
+ }
39
+