red-candle 1.8.0.pre3-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,731 @@
1
+ use magnus::prelude::*;
2
+ use magnus::{function, method, RModule, Module, Object, Ruby};
3
+
4
+ use crate::ruby::{
5
+ errors::wrap_candle_err,
6
+ utils::{actual_dim, actual_index},
7
+ };
8
+ use crate::ruby::{DType, Device, Result};
9
+ use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
10
+
11
+ #[derive(Clone, Debug)]
12
+ #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
13
+ /// A `candle` tensor.
14
+ pub struct Tensor(pub CoreTensor);
15
+
16
+ impl std::ops::Deref for Tensor {
17
+ type Target = CoreTensor;
18
+
19
+ fn deref(&self) -> &Self::Target {
20
+ &self.0
21
+ }
22
+ }
23
+
24
+ // Helper functions for tensor operations
25
+ impl Tensor {
26
+ /// Check if device is Metal
27
+ fn is_metal_device(device: &CoreDevice) -> bool {
28
+ matches!(device.location(), DeviceLocation::Metal { .. })
29
+ }
30
+
31
+ /// Convert tensor to target dtype, handling Metal limitations
32
+ fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
33
+ if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
34
+ // Move to CPU first to avoid Metal conversion limitations
35
+ self.0
36
+ .to_device(&CoreDevice::Cpu)
37
+ .map_err(wrap_candle_err)?
38
+ .to_dtype(target_dtype)
39
+ .map_err(wrap_candle_err)
40
+ } else {
41
+ // Direct conversion for CPU or when dtype matches
42
+ self.0
43
+ .to_dtype(target_dtype)
44
+ .map_err(wrap_candle_err)
45
+ }
46
+ }
47
+ }
48
+
49
+ impl Tensor {
50
+ pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
51
+ let dtype = dtype
52
+ .map(|dtype| DType::from_rbobject(dtype))
53
+ .unwrap_or(Ok(DType(CoreDType::F32)))?;
54
+ let device = device.unwrap_or(Device::best()).as_device()?;
55
+
56
+ // Create tensor based on target dtype to avoid conversion issues on Metal
57
+ let tensor = match dtype.0 {
58
+ CoreDType::F32 => {
59
+ // Convert to f32 directly to avoid F64->F32 conversion on Metal
60
+ let array: Vec<f32> = array
61
+ .into_iter()
62
+ .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
63
+ .collect::<Result<Vec<_>>>()?;
64
+ let len = array.len();
65
+ CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
66
+ }
67
+ CoreDType::F64 => {
68
+ let array: Vec<f64> = array
69
+ .into_iter()
70
+ .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
71
+ .collect::<Result<Vec<_>>>()?;
72
+ let len = array.len();
73
+ CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
74
+ }
75
+ CoreDType::I64 => {
76
+ // Convert to i64 directly to avoid conversion issues on Metal
77
+ let array: Vec<i64> = array
78
+ .into_iter()
79
+ .map(|v| {
80
+ // Try integer first, then float
81
+ if let Ok(i) = <i64>::try_convert(v) {
82
+ Ok(i)
83
+ } else if let Ok(f) = magnus::Float::try_convert(v) {
84
+ Ok(f.to_f64() as i64)
85
+ } else {
86
+ Err(magnus::Error::new(
87
+ Ruby::get().unwrap().exception_type_error(),
88
+ "Cannot convert to i64"
89
+ ))
90
+ }
91
+ })
92
+ .collect::<Result<Vec<_>>>()?;
93
+ let len = array.len();
94
+ CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
95
+ }
96
+ _ => {
97
+ // For other dtypes, create on CPU first if on Metal, then convert
98
+ let cpu_device = CoreDevice::Cpu;
99
+ let use_cpu = Self::is_metal_device(&device);
100
+ let target_device = if use_cpu { &cpu_device } else { &device };
101
+
102
+ let array: Vec<f64> = array
103
+ .into_iter()
104
+ .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
105
+ .collect::<Result<Vec<_>>>()?;
106
+ let tensor = CoreTensor::new(array.as_slice(), target_device)
107
+ .map_err(wrap_candle_err)?
108
+ .to_dtype(dtype.0)
109
+ .map_err(wrap_candle_err)?;
110
+
111
+ // Move to target device if needed
112
+ if use_cpu {
113
+ tensor.to_device(&device).map_err(wrap_candle_err)?
114
+ } else {
115
+ tensor
116
+ }
117
+ }
118
+ };
119
+
120
+ Ok(Self(tensor))
121
+ }
122
+
123
+ pub fn values(&self) -> Result<Vec<f64>> {
124
+ let tensor = self.safe_to_dtype(CoreDType::F64)?;
125
+ let values = tensor
126
+ .flatten_all()
127
+ .map_err(wrap_candle_err)?
128
+ .to_vec1()
129
+ .map_err(wrap_candle_err)?;
130
+ Ok(values)
131
+ }
132
+
133
+ /// Get values as f32 without dtype conversion
134
+ pub fn values_f32(&self) -> Result<Vec<f32>> {
135
+ match self.0.dtype() {
136
+ CoreDType::F32 => {
137
+ let values = self
138
+ .0
139
+ .flatten_all()
140
+ .map_err(wrap_candle_err)?
141
+ .to_vec1()
142
+ .map_err(wrap_candle_err)?;
143
+ Ok(values)
144
+ }
145
+ _ => Err(magnus::Error::new(
146
+ Ruby::get().unwrap().exception_runtime_error(),
147
+ "Tensor must be F32 dtype for values_f32",
148
+ )),
149
+ }
150
+ }
151
+
152
+ /// Get a single scalar value from a rank-0 tensor
153
+ pub fn item(&self) -> Result<f64> {
154
+ if self.0.rank() != 0 {
155
+ return Err(magnus::Error::new(
156
+ Ruby::get().unwrap().exception_runtime_error(),
157
+ format!("item() can only be called on scalar tensors (rank 0), but tensor has rank {}", self.0.rank()),
158
+ ));
159
+ }
160
+
161
+ // Try to get the value based on dtype
162
+ match self.0.dtype() {
163
+ CoreDType::F32 => {
164
+ let val: f32 = self.0.to_vec0().map_err(wrap_candle_err)?;
165
+ Ok(val as f64)
166
+ }
167
+ CoreDType::F64 => {
168
+ let val: f64 = self.0.to_vec0().map_err(wrap_candle_err)?;
169
+ Ok(val)
170
+ }
171
+ _ => {
172
+ // For other dtypes, convert to F64 first
173
+ let tensor = self.safe_to_dtype(CoreDType::F64)?;
174
+ let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
175
+ Ok(val)
176
+ }
177
+ }
178
+ }
179
+
180
+ /// Gets the tensor's shape.
181
+ /// &RETURNS&: Tuple[int]
182
+ pub fn shape(&self) -> Vec<usize> {
183
+ self.0.dims().to_vec()
184
+ }
185
+
186
+ /// Gets the tensor's strides.
187
+ /// &RETURNS&: Tuple[int]
188
+ pub fn stride(&self) -> Vec<usize> {
189
+ self.0.stride().to_vec()
190
+ }
191
+
192
+ /// Gets the tensor's dtype.
193
+ /// &RETURNS&: DType
194
+ pub fn dtype(&self) -> DType {
195
+ DType(self.0.dtype())
196
+ }
197
+
198
+ /// Gets the tensor's device.
199
+ /// &RETURNS&: Device
200
+ pub fn device(&self) -> Device {
201
+ Device::from_device(self.0.device())
202
+ }
203
+
204
+ /// Gets the tensor's rank.
205
+ /// &RETURNS&: int
206
+ pub fn rank(&self) -> usize {
207
+ self.0.rank()
208
+ }
209
+
210
+ /// The number of elements stored in this tensor.
211
+ /// &RETURNS&: int
212
+ pub fn elem_count(&self) -> usize {
213
+ self.0.elem_count()
214
+ }
215
+
216
+ pub fn __repr__(&self) -> String {
217
+ format!("{}", self.0)
218
+ }
219
+
220
+ pub fn __str__(&self) -> String {
221
+ self.__repr__()
222
+ }
223
+
224
+ /// Performs the `sin` operation on the tensor.
225
+ /// &RETURNS&: Tensor
226
+ pub fn sin(&self) -> Result<Self> {
227
+ Ok(Tensor(self.0.sin().map_err(wrap_candle_err)?))
228
+ }
229
+
230
+ /// Performs the `cos` operation on the tensor.
231
+ /// &RETURNS&: Tensor
232
+ pub fn cos(&self) -> Result<Self> {
233
+ Ok(Tensor(self.0.cos().map_err(wrap_candle_err)?))
234
+ }
235
+
236
+ /// Performs the `log` operation on the tensor.
237
+ /// &RETURNS&: Tensor
238
+ pub fn log(&self) -> Result<Self> {
239
+ Ok(Tensor(self.0.log().map_err(wrap_candle_err)?))
240
+ }
241
+
242
+ /// Squares the tensor.
243
+ /// &RETURNS&: Tensor
244
+ pub fn sqr(&self) -> Result<Self> {
245
+ Ok(Tensor(self.0.sqr().map_err(wrap_candle_err)?))
246
+ }
247
+
248
+ /// Returns the mean along the specified axis.
249
+ /// @param axis [Integer, optional] The axis to reduce over (default: 0)
250
+ /// @return [Candle::Tensor]
251
+ pub fn mean(&self, axis: Option<i64>) -> Result<Self> {
252
+ let axis = axis.unwrap_or(0) as usize;
253
+ Ok(Tensor(self.0.mean(axis).map_err(wrap_candle_err)?))
254
+ }
255
+
256
+ /// Returns the sum along the specified axis.
257
+ /// @param axis [Integer, optional] The axis to reduce over (default: 0)
258
+ /// @return [Candle::Tensor]
259
+ pub fn sum(&self, axis: Option<i64>) -> Result<Self> {
260
+ let axis = axis.unwrap_or(0) as usize;
261
+ Ok(Tensor(self.0.sum(axis).map_err(wrap_candle_err)?))
262
+ }
263
+
264
+ /// Calculates the square root of the tensor.
265
+ /// &RETURNS&: Tensor
266
+ pub fn sqrt(&self) -> Result<Self> {
267
+ Ok(Tensor(self.0.sqrt().map_err(wrap_candle_err)?))
268
+ }
269
+
270
+ /// Get the `recip` of the tensor.
271
+ /// &RETURNS&: Tensor
272
+ pub fn recip(&self) -> Result<Self> {
273
+ Ok(Tensor(self.0.recip().map_err(wrap_candle_err)?))
274
+ }
275
+
276
+ /// Performs the `exp` operation on the tensor.
277
+ /// &RETURNS&: Tensor
278
+ pub fn exp(&self) -> Result<Self> {
279
+ Ok(Tensor(self.0.exp().map_err(wrap_candle_err)?))
280
+ }
281
+
282
+ /// Performs the `pow` operation on the tensor with the given exponent.
283
+ /// &RETURNS&: Tensor
284
+ pub fn powf(&self, p: f64) -> Result<Self> {
285
+ Ok(Tensor(self.0.powf(p).map_err(wrap_candle_err)?))
286
+ }
287
+
288
+ /// Select values for the input tensor at the target indexes across the specified dimension.
289
+ ///
290
+ /// The `indexes` is argument is an int tensor with a single dimension.
291
+ /// The output has the same number of dimension as the `self` input. The target dimension of
292
+ /// the output has length the length of `indexes` and the values are taken from `self` using
293
+ /// the index from `indexes`. Other dimensions have the same number of elements as the input
294
+ /// tensor.
295
+ /// &RETURNS&: Tensor
296
+ pub fn index_select(&self, rhs: &Self, dim: i64) -> Result<Self> {
297
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
298
+ Ok(Tensor(
299
+ self.0.index_select(rhs, dim).map_err(wrap_candle_err)?,
300
+ ))
301
+ }
302
+
303
+ /// Performs a matrix multiplication between the two tensors.
304
+ /// &RETURNS&: Tensor
305
+ pub fn matmul(&self, rhs: &Self) -> Result<Self> {
306
+ Ok(Tensor(self.0.matmul(rhs).map_err(wrap_candle_err)?))
307
+ }
308
+
309
+ /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
310
+ /// &RETURNS&: Tensor
311
+ pub fn broadcast_add(&self, rhs: &Self) -> Result<Self> {
312
+ Ok(Tensor(
313
+ self.0.broadcast_add(rhs).map_err(wrap_candle_err)?,
314
+ ))
315
+ }
316
+
317
+ /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
318
+ /// &RETURNS&: Tensor
319
+ pub fn broadcast_sub(&self, rhs: &Self) -> Result<Self> {
320
+ Ok(Tensor(
321
+ self.0.broadcast_sub(rhs).map_err(wrap_candle_err)?,
322
+ ))
323
+ }
324
+
325
+ /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
326
+ /// &RETURNS&: Tensor
327
+ pub fn broadcast_mul(&self, rhs: &Self) -> Result<Self> {
328
+ Ok(Tensor(
329
+ self.0.broadcast_mul(rhs).map_err(wrap_candle_err)?,
330
+ ))
331
+ }
332
+
333
+ /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
334
+ /// &RETURNS&: Tensor
335
+ pub fn broadcast_div(&self, rhs: &Self) -> Result<Self> {
336
+ Ok(Tensor(
337
+ self.0.broadcast_div(rhs).map_err(wrap_candle_err)?,
338
+ ))
339
+ }
340
+
341
+ /// Returns a tensor with the same shape as the input tensor, the values are taken from
342
+ /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
343
+ /// input tensor is equal to zero.
344
+ /// &RETURNS&: Tensor
345
+ pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
346
+ Ok(Tensor(
347
+ self.0
348
+ .where_cond(on_true, on_false)
349
+ .map_err(wrap_candle_err)?,
350
+ ))
351
+ }
352
+
353
+ /// Add two tensors.
354
+ /// &RETURNS&: Tensor
355
+ pub fn __add__(&self, rhs: &Tensor) -> Result<Self> {
356
+ Ok(Self(self.0.add(&rhs.0).map_err(wrap_candle_err)?))
357
+ }
358
+
359
+ /// Multiply two tensors.
360
+ /// &RETURNS&: Tensor
361
+ pub fn __mul__(&self, rhs: &Tensor) -> Result<Self> {
362
+ Ok(Self(self.0.mul(&rhs.0).map_err(wrap_candle_err)?))
363
+ }
364
+
365
+ /// Subtract two tensors.
366
+ /// &RETURNS&: Tensor
367
+ pub fn __sub__(&self, rhs: &Tensor) -> Result<Self> {
368
+ Ok(Self(self.0.sub(&rhs.0).map_err(wrap_candle_err)?))
369
+ }
370
+
371
+ /// Divide two tensors.
372
+ /// &RETURNS&: Tensor
373
+ /// Divides this tensor by another tensor or a scalar (Float/Integer).
374
+ /// @param rhs [Candle::Tensor, Float, or Integer]
375
+ /// @return [Candle::Tensor]
376
+ pub fn __truediv__(&self, rhs: magnus::Value) -> Result<Self> {
377
+ use magnus::TryConvert;
378
+ if let Ok(tensor) = <&Tensor>::try_convert(rhs) {
379
+ Ok(Self(self.0.broadcast_div(&tensor.0).map_err(wrap_candle_err)?))
380
+ } else if let Ok(f) = <f64>::try_convert(rhs) {
381
+ let scalar = CoreTensor::from_vec(vec![f as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
382
+ Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
383
+ } else if let Ok(i) = <i64>::try_convert(rhs) {
384
+ let scalar = CoreTensor::from_vec(vec![i as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
385
+ Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
386
+ } else {
387
+ Err(magnus::Error::new(Ruby::get().unwrap().exception_type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
388
+ }
389
+ }
390
+
391
+ /// Reshapes the tensor to the given shape.
392
+ /// &RETURNS&: Tensor
393
+ pub fn reshape(&self, shape: Vec<usize>) -> Result<Self> {
394
+ Ok(Tensor(self.0.reshape(shape).map_err(wrap_candle_err)?))
395
+ }
396
+
397
+ /// Broadcasts the tensor to the given shape.
398
+ /// &RETURNS&: Tensor
399
+ pub fn broadcast_as(&self, shape: Vec<usize>) -> Result<Self> {
400
+ Ok(Tensor(
401
+ self.0.broadcast_as(shape).map_err(wrap_candle_err)?,
402
+ ))
403
+ }
404
+
405
+ /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
406
+ /// &RETURNS&: Tensor
407
+ pub fn broadcast_left(&self, shape: Vec<usize>) -> Result<Self> {
408
+ Ok(Tensor(
409
+ self.0.broadcast_left(shape).map_err(wrap_candle_err)?,
410
+ ))
411
+ }
412
+
413
+ /// Creates a new tensor with the specified dimension removed if its size was one.
414
+ /// &RETURNS&: Tensor
415
+ pub fn squeeze(&self, dim: i64) -> Result<Self> {
416
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
417
+ Ok(Tensor(self.0.squeeze(dim).map_err(wrap_candle_err)?))
418
+ }
419
+
420
+ /// Creates a new tensor with a dimension of size one inserted at the specified position.
421
+ /// &RETURNS&: Tensor
422
+ pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
423
+ Ok(Tensor(self.0.unsqueeze(dim).map_err(wrap_candle_err)?))
424
+ }
425
+
426
+ /// Gets the value at the specified index.
427
+ /// &RETURNS&: Tensor
428
+ pub fn get(&self, index: i64) -> Result<Self> {
429
+ let index = actual_index(self, 0, index).map_err(wrap_candle_err)?;
430
+ Ok(Tensor(self.0.get(index).map_err(wrap_candle_err)?))
431
+ }
432
+
433
+ /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
434
+ /// &RETURNS&: Tensor
435
+ pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
436
+ Ok(Tensor(
437
+ self.0.transpose(dim1, dim2).map_err(wrap_candle_err)?,
438
+ ))
439
+ }
440
+
441
+ /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
442
+ /// ranges from `start` to `start + len`.
443
+ /// &RETURNS&: Tensor
444
+ pub fn narrow(&self, dim: i64, start: i64, len: usize) -> Result<Self> {
445
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
446
+ let start = actual_index(self, dim, start).map_err(wrap_candle_err)?;
447
+ Ok(Tensor(
448
+ self.0.narrow(dim, start, len).map_err(wrap_candle_err)?,
449
+ ))
450
+ }
451
+
452
+ /// Returns the indices of the maximum value(s) across the selected dimension.
453
+ /// &RETURNS&: Tensor
454
+ pub fn argmax_keepdim(&self, dim: i64) -> Result<Self> {
455
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
456
+ Ok(Tensor(
457
+ self.0.argmax_keepdim(dim).map_err(wrap_candle_err)?,
458
+ ))
459
+ }
460
+
461
+ /// Returns the indices of the minimum value(s) across the selected dimension.
462
+ /// &RETURNS&: Tensor
463
+ pub fn argmin_keepdim(&self, dim: i64) -> Result<Self> {
464
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
465
+ Ok(Tensor(
466
+ self.0.argmin_keepdim(dim).map_err(wrap_candle_err)?,
467
+ ))
468
+ }
469
+
470
+ /// Gathers the maximum value across the selected dimension.
471
+ /// &RETURNS&: Tensor
472
+ pub fn max_keepdim(&self, dim: i64) -> Result<Self> {
473
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
474
+ Ok(Tensor(self.0.max_keepdim(dim).map_err(wrap_candle_err)?))
475
+ }
476
+
477
+ /// Gathers the minimum value across the selected dimension.
478
+ /// &RETURNS&: Tensor
479
+ pub fn min_keepdim(&self, dim: i64) -> Result<Self> {
480
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
481
+ Ok(Tensor(self.0.min_keepdim(dim).map_err(wrap_candle_err)?))
482
+ }
483
+
484
+ // fn eq(&self, rhs: &Self) -> Result<Self> {
485
+ // Ok(Tensor(self.0.eq(rhs).map_err(wrap_candle_err)?))
486
+ // }
487
+
488
+ // fn ne(&self, rhs: &Self) -> Result<Self> {
489
+ // Ok(Tensor(self.0.ne(rhs).map_err(wrap_candle_err)?))
490
+ // }
491
+
492
+ // fn lt(&self, rhs: &Self) -> Result<Self> {
493
+ // Ok(Tensor(self.0.lt(rhs).map_err(wrap_candle_err)?))
494
+ // }
495
+
496
+ // fn gt(&self, rhs: &Self) -> Result<Self> {
497
+ // Ok(Tensor(self.0.gt(rhs).map_err(wrap_candle_err)?))
498
+ // }
499
+
500
+ // fn ge(&self, rhs: &Self) -> Result<Self> {
501
+ // Ok(Tensor(self.0.ge(rhs).map_err(wrap_candle_err)?))
502
+ // }
503
+
504
+ // fn le(&self, rhs: &Self) -> Result<Self> {
505
+ // Ok(Tensor(self.0.le(rhs).map_err(wrap_candle_err)?))
506
+ // }
507
+
508
+ /// Returns the sum of the tensor.
509
+ /// &RETURNS&: Tensor
510
+ pub fn sum_all(&self) -> Result<Self> {
511
+ Ok(Tensor(self.0.sum_all().map_err(wrap_candle_err)?))
512
+ }
513
+
514
+ /// Returns the mean of the tensor.
515
+ /// &RETURNS&: Tensor
516
+ pub fn mean_all(&self) -> Result<Self> {
517
+ let elements = self.0.elem_count();
518
+ let sum = self.0.sum_all().map_err(wrap_candle_err)?;
519
+ let mean = (sum / elements as f64).map_err(wrap_candle_err)?;
520
+ Ok(Tensor(mean))
521
+ }
522
+
523
+ /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
524
+ /// &RETURNS&: Tensor
525
+ pub fn flatten_from(&self, dim: i64) -> Result<Self> {
526
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
527
+ Ok(Tensor(self.0.flatten_from(dim).map_err(wrap_candle_err)?))
528
+ }
529
+
530
+ ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
531
+ /// &RETURNS&: Tensor
532
+ pub fn flatten_to(&self, dim: i64) -> Result<Self> {
533
+ let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
534
+ Ok(Tensor(self.0.flatten_to(dim).map_err(wrap_candle_err)?))
535
+ }
536
+
537
+ /// Flattens the tensor into a 1D tensor.
538
+ /// &RETURNS&: Tensor
539
+ pub fn flatten_all(&self) -> Result<Self> {
540
+ Ok(Tensor(self.0.flatten_all().map_err(wrap_candle_err)?))
541
+ }
542
+
543
+ /// Transposes the tensor.
544
+ /// &RETURNS&: Tensor
545
+ pub fn t(&self) -> Result<Self> {
546
+ Ok(Tensor(self.0.t().map_err(wrap_candle_err)?))
547
+ }
548
+
549
+ /// Makes the tensor contiguous in memory.
550
+ /// &RETURNS&: Tensor
551
+ pub fn contiguous(&self) -> Result<Self> {
552
+ Ok(Tensor(self.0.contiguous().map_err(wrap_candle_err)?))
553
+ }
554
+
555
+ /// Returns true if the tensor is contiguous in C order.
556
+ /// &RETURNS&: bool
557
+ pub fn is_contiguous(&self) -> bool {
558
+ self.0.is_contiguous()
559
+ }
560
+
561
+ /// Returns true if the tensor is contiguous in Fortran order.
562
+ /// &RETURNS&: bool
563
+ pub fn is_fortran_contiguous(&self) -> bool {
564
+ self.0.is_fortran_contiguous()
565
+ }
566
+
567
+ /// Detach the tensor from the computation graph.
568
+ /// &RETURNS&: Tensor
569
+ pub fn detach(&self) -> Result<Self> {
570
+ Ok(Tensor(self.0.detach()))
571
+ }
572
+
573
+ /// Returns a copy of the tensor.
574
+ /// &RETURNS&: Tensor
575
+ pub fn copy(&self) -> Result<Self> {
576
+ Ok(Tensor(self.0.copy().map_err(wrap_candle_err)?))
577
+ }
578
+
579
+ /// Convert the tensor to a new dtype.
580
+ /// &RETURNS&: Tensor
581
+ pub fn to_dtype(&self, dtype: magnus::Symbol) -> Result<Self> {
582
+ let dtype = DType::from_rbobject(dtype)?;
583
+ Ok(Tensor(self.0.to_dtype(dtype.0).map_err(wrap_candle_err)?))
584
+ }
585
+
586
+ /// Move the tensor to a new device.
587
+ /// &RETURNS&: Tensor
588
+ pub fn to_device(&self, device: Device) -> Result<Self> {
589
+ let device = device.as_device()?;
590
+ Ok(Tensor(
591
+ self.0.to_device(&device).map_err(wrap_candle_err)?,
592
+ ))
593
+ }
594
+ }
595
+
596
+ impl Tensor {
597
+ // fn cat(tensors: Vec<Tensor>, dim: i64) -> Result<Tensor> {
598
+ // if tensors.is_empty() {
599
+ // return Err(Error::new(
600
+ // magnus::exception::arg_error(),
601
+ // "empty input to cat",
602
+ // ));
603
+ // }
604
+ // let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_candle_err)?;
605
+ // let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
606
+ // let tensor = CoreTensor::cat(&tensors, dim).map_err(wrap_candle_err)?;
607
+ // Ok(Tensor(tensor))
608
+ // }
609
+
610
+ // fn stack(tensors: Vec<Tensor>, dim: usize) -> Result<Self> {
611
+ // let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
612
+ // let tensor = CoreTensor::stack(&tensors, dim).map_err(wrap_candle_err)?;
613
+ // Ok(Self(tensor))
614
+ // }
615
+
616
+ /// Creates a new tensor with random values.
617
+ /// &RETURNS&: Tensor
618
+ pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
619
+ let device = device.unwrap_or(Device::best()).as_device()?;
620
+ Ok(Self(
621
+ CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
622
+ ))
623
+ }
624
+
625
+ /// Creates a new tensor with random values from a normal distribution.
626
+ /// &RETURNS&: Tensor
627
+ pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
628
+ let device = device.unwrap_or(Device::best()).as_device()?;
629
+ Ok(Self(
630
+ CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
631
+ ))
632
+ }
633
+
634
+ /// Creates a new tensor filled with ones.
635
+ /// &RETURNS&: Tensor
636
+ pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
637
+ let device = device.unwrap_or(Device::best()).as_device()?;
638
+ Ok(Self(
639
+ CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
640
+ ))
641
+ }
642
+ /// Creates a new tensor filled with zeros.
643
+ /// &RETURNS&: Tensor
644
+ pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
645
+ let device = device.unwrap_or(Device::best()).as_device()?;
646
+ Ok(Self(
647
+ CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
648
+ ))
649
+ }
650
+ }
651
+
652
+ pub fn init(rb_candle: RModule) -> Result<()> {
653
+ let ruby = Ruby::get().unwrap();
654
+ let rb_tensor = rb_candle.define_class("Tensor", ruby.class_object())?;
655
+ rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
656
+ // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
657
+ // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
658
+ rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
659
+ rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
660
+ rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
661
+ rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
662
+ rb_tensor.define_method("values", method!(Tensor::values, 0))?;
663
+ rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
664
+ rb_tensor.define_method("item", method!(Tensor::item, 0))?;
665
+ rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
666
+ rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
667
+ rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
668
+ rb_tensor.define_method("device", method!(Tensor::device, 0))?;
669
+ rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
670
+ rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
671
+ rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
672
+ rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
673
+ rb_tensor.define_method("log", method!(Tensor::log, 0))?;
674
+ rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
675
+ rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
676
+ rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
677
+ rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
678
+ rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
679
+ rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
680
+ rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
681
+ rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
682
+ rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
683
+ rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
684
+ rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
685
+ rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
686
+ rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
687
+ rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
688
+ rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
689
+ rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
690
+ rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
691
+ rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
692
+ rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
693
+ rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
694
+ rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
695
+ rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
696
+ rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
697
+ rb_tensor.define_method("get", method!(Tensor::get, 1))?;
698
+ rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
699
+ rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
700
+ rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
701
+ rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
702
+ rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
703
+ rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
704
+ rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
705
+ // rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
706
+ // rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
707
+ // rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
708
+ // rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
709
+ // rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
710
+ // rb_tensor.define_method("le", method!(Tensor::le, 1))?;
711
+ rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
712
+ rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
713
+ rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
714
+ rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
715
+ rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
716
+ rb_tensor.define_method("t", method!(Tensor::t, 0))?;
717
+ rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
718
+ rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
719
+ rb_tensor.define_method(
720
+ "is_fortran_contiguous",
721
+ method!(Tensor::is_fortran_contiguous, 0),
722
+ )?;
723
+ rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
724
+ rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
725
+ rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
726
+ rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
727
+ rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
728
+ rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
729
+ Ok(())
730
+ }
731
+