red-candle 1.0.0.pre.1 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,654 @@
+ use magnus::prelude::*;
+ use magnus::{function, method, class, RModule, Error, Module, Object};
+
+ use crate::ruby::{
+     errors::wrap_candle_err,
+     utils::{actual_dim, actual_index},
+ };
+ use crate::ruby::{DType, Device, Result as RbResult};
+ use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+
+ #[derive(Clone, Debug)]
+ #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
+ /// A `candle` tensor.
+ pub struct Tensor(pub CoreTensor);
+
+ impl std::ops::Deref for Tensor {
+     type Target = CoreTensor;
+
+     fn deref(&self) -> &Self::Target {
+         &self.0
+     }
+ }
+
+ impl Tensor {
+     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> RbResult<Self> {
+         let dtype = dtype
+             .map(|dtype| DType::from_rbobject(dtype))
+             .unwrap_or(Ok(DType(CoreDType::F32)))?;
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         // FIXME: Do not use `to_f64` here.
+         let array = array
+             .into_iter()
+             .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+             .collect::<RbResult<Vec<_>>>()?;
+         Ok(Self(
+             CoreTensor::new(array.as_slice(), &device)
+                 .map_err(wrap_candle_err)?
+                 .to_dtype(dtype.0)
+                 .map_err(wrap_candle_err)?,
+         ))
+     }
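The constructor above converts the Ruby array element-by-element to `f64`, builds the tensor on the requested device, then casts it to the requested dtype. A minimal candle_core-only sketch of that sequence (illustrative, not part of the packaged file; assumes a 1-D input):

use candle_core::{DType, Device, Result, Tensor};

fn new_from_f64(values: &[f64]) -> Result<Tensor> {
    // Build an F64 tensor on the CPU, then cast to the F32 default,
    // mirroring the two-step construction in Tensor::new above.
    Tensor::new(values, &Device::Cpu)?.to_dtype(DType::F32)
}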
+
+     pub fn values(&self) -> RbResult<Vec<f64>> {
+         let values = self
+             .0
+             .to_dtype(CoreDType::F64)
+             .map_err(wrap_candle_err)?
+             .flatten_all()
+             .map_err(wrap_candle_err)?
+             .to_vec1()
+             .map_err(wrap_candle_err)?;
+         Ok(values)
+     }
+
+     /// Get values as f32 without dtype conversion
+     pub fn values_f32(&self) -> RbResult<Vec<f32>> {
+         match self.0.dtype() {
+             CoreDType::F32 => {
+                 let values = self
+                     .0
+                     .flatten_all()
+                     .map_err(wrap_candle_err)?
+                     .to_vec1()
+                     .map_err(wrap_candle_err)?;
+                 Ok(values)
+             }
+             _ => Err(magnus::Error::new(
+                 magnus::exception::runtime_error(),
+                 "Tensor must be F32 dtype for values_f32",
+             )),
+         }
+     }
+
+     /// Get a single scalar value from a rank-0 tensor
+     pub fn item(&self) -> RbResult<f64> {
+         if self.0.rank() != 0 {
+             return Err(magnus::Error::new(
+                 magnus::exception::runtime_error(),
+                 format!("item() can only be called on scalar tensors (rank 0), but tensor has rank {}", self.0.rank()),
+             ));
+         }
+
+         // Try to get the value based on dtype
+         match self.0.dtype() {
+             CoreDType::F32 => {
+                 let val: f32 = self.0.to_vec0().map_err(wrap_candle_err)?;
+                 Ok(val as f64)
+             }
+             CoreDType::F64 => {
+                 let val: f64 = self.0.to_vec0().map_err(wrap_candle_err)?;
+                 Ok(val)
+             }
+             _ => {
+                 // For other dtypes, convert to F64 first
+                 let val: f64 = self.0
+                     .to_dtype(CoreDType::F64)
+                     .map_err(wrap_candle_err)?
+                     .to_vec0()
+                     .map_err(wrap_candle_err)?;
+                 Ok(val)
+             }
+         }
+     }
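`item` only succeeds on rank-0 tensors; anything else is reported as an error, and the value itself comes from `to_vec0`. A candle_core sketch of the happy path (illustrative; assumes an F32 scalar produced by a reduction):

use candle_core::{Device, Result, Tensor};

fn scalar_value() -> Result<f64> {
    // sum_all reduces to a rank-0 tensor, which to_vec0 turns into a single value.
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?.sum_all()?;
    let v: f32 = t.to_vec0()?;
    Ok(v as f64)
}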
+
+     /// Gets the tensor's shape.
+     /// &RETURNS&: Tuple[int]
+     pub fn shape(&self) -> Vec<usize> {
+         self.0.dims().to_vec()
+     }
+
+     /// Gets the tensor's strides.
+     /// &RETURNS&: Tuple[int]
+     pub fn stride(&self) -> Vec<usize> {
+         self.0.stride().to_vec()
+     }
+
+     /// Gets the tensor's dtype.
+     /// &RETURNS&: DType
+     pub fn dtype(&self) -> DType {
+         DType(self.0.dtype())
+     }
+
+     /// Gets the tensor's device.
+     /// &RETURNS&: Device
+     pub fn device(&self) -> Device {
+         Device::from_device(self.0.device())
+     }
+
+     /// Gets the tensor's rank.
+     /// &RETURNS&: int
+     pub fn rank(&self) -> usize {
+         self.0.rank()
+     }
+
+     /// The number of elements stored in this tensor.
+     /// &RETURNS&: int
+     pub fn elem_count(&self) -> usize {
+         self.0.elem_count()
+     }
+
+     pub fn __repr__(&self) -> String {
+         format!("{}", self.0)
+     }
+
+     pub fn __str__(&self) -> String {
+         self.__repr__()
+     }
+
+     /// Performs the `sin` operation on the tensor.
+     /// &RETURNS&: Tensor
+     pub fn sin(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.sin().map_err(wrap_candle_err)?))
+     }
+
+     /// Performs the `cos` operation on the tensor.
+     /// &RETURNS&: Tensor
+     pub fn cos(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.cos().map_err(wrap_candle_err)?))
+     }
+
+     /// Performs the `log` operation on the tensor.
+     /// &RETURNS&: Tensor
+     pub fn log(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.log().map_err(wrap_candle_err)?))
+     }
+
+     /// Squares the tensor.
+     /// &RETURNS&: Tensor
+     pub fn sqr(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.sqr().map_err(wrap_candle_err)?))
+     }
+
+     /// Returns the mean along the specified axis.
+     /// @param axis [Integer, optional] The axis to reduce over (default: 0)
+     /// @return [Candle::Tensor]
+     pub fn mean(&self, axis: Option<i64>) -> RbResult<Self> {
+         let axis = axis.unwrap_or(0) as usize;
+         Ok(Tensor(self.0.mean(axis).map_err(wrap_candle_err)?))
+     }
+
+     /// Returns the sum along the specified axis.
+     /// @param axis [Integer, optional] The axis to reduce over (default: 0)
+     /// @return [Candle::Tensor]
+     pub fn sum(&self, axis: Option<i64>) -> RbResult<Self> {
+         let axis = axis.unwrap_or(0) as usize;
+         Ok(Tensor(self.0.sum(axis).map_err(wrap_candle_err)?))
+     }
+
+     /// Calculates the square root of the tensor.
+     /// &RETURNS&: Tensor
+     pub fn sqrt(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.sqrt().map_err(wrap_candle_err)?))
+     }
+
+     /// Get the `recip` of the tensor.
+     /// &RETURNS&: Tensor
+     pub fn recip(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.recip().map_err(wrap_candle_err)?))
+     }
+
+     /// Performs the `exp` operation on the tensor.
+     /// &RETURNS&: Tensor
+     pub fn exp(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.exp().map_err(wrap_candle_err)?))
+     }
+
+     /// Performs the `pow` operation on the tensor with the given exponent.
+     /// &RETURNS&: Tensor
+     pub fn powf(&self, p: f64) -> RbResult<Self> {
+         Ok(Tensor(self.0.powf(p).map_err(wrap_candle_err)?))
+     }
+
+     /// Select values for the input tensor at the target indexes across the specified dimension.
+     ///
+     /// The `indexes` argument is an int tensor with a single dimension.
+     /// The output has the same number of dimensions as the `self` input. The target dimension of
+     /// the output has the same length as `indexes`, and the values are taken from `self` using
+     /// the index from `indexes`. Other dimensions have the same number of elements as the input
+     /// tensor.
+     /// &RETURNS&: Tensor
+     pub fn index_select(&self, rhs: &Self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(
+             self.0.index_select(rhs, dim).map_err(wrap_candle_err)?,
+         ))
+     }
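To make the doc comment above concrete: selecting rows 0 and 2 of a 3x2 matrix along dimension 0 yields a 2x2 matrix. A candle_core sketch (illustrative only):

use candle_core::{Device, Result, Tensor};

fn pick_rows() -> Result<Tensor> {
    let dev = Device::Cpu;
    let m = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?; // shape (3, 2)
    let idx = Tensor::new(&[0u32, 2u32], &dev)?; // 1-D index tensor
    // Rows 0 and 2 are gathered along dim 0, giving shape (2, 2).
    m.index_select(&idx, 0)
}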
+
+     /// Performs a matrix multiplication between the two tensors.
+     /// &RETURNS&: Tensor
+     pub fn matmul(&self, rhs: &Self) -> RbResult<Self> {
+         Ok(Tensor(self.0.matmul(rhs).map_err(wrap_candle_err)?))
+     }
+
+     /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_add(&self, rhs: &Self) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_add(rhs).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_sub(&self, rhs: &Self) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_sub(rhs).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_mul(&self, rhs: &Self) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_mul(rhs).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_div(&self, rhs: &Self) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_div(rhs).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Returns a tensor with the same shape as the input tensor; the values are taken from
+     /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
+     /// input tensor is equal to zero.
+     /// &RETURNS&: Tensor
+     pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0
+                 .where_cond(on_true, on_false)
+                 .map_err(wrap_candle_err)?,
+         ))
+     }
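Here the receiver acts as the mask: positions where it is non-zero take values from `on_true`, the rest from `on_false`. A small candle_core sketch (illustrative only):

use candle_core::{Device, Result, Tensor};

fn masked_pick() -> Result<Tensor> {
    let dev = Device::Cpu;
    let mask = Tensor::new(&[1u8, 0, 1], &dev)?;
    let on_true = Tensor::new(&[10f32, 20., 30.], &dev)?;
    let on_false = Tensor::new(&[0f32, 0., 0.], &dev)?;
    // Expected result: [10.0, 0.0, 30.0]
    mask.where_cond(&on_true, &on_false)
}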
+
+     /// Add two tensors.
+     /// &RETURNS&: Tensor
+     pub fn __add__(&self, rhs: &Tensor) -> RbResult<Self> {
+         Ok(Self(self.0.add(&rhs.0).map_err(wrap_candle_err)?))
+     }
+
+     /// Multiply two tensors.
+     /// &RETURNS&: Tensor
+     pub fn __mul__(&self, rhs: &Tensor) -> RbResult<Self> {
+         Ok(Self(self.0.mul(&rhs.0).map_err(wrap_candle_err)?))
+     }
+
+     /// Subtract two tensors.
+     /// &RETURNS&: Tensor
+     pub fn __sub__(&self, rhs: &Tensor) -> RbResult<Self> {
+         Ok(Self(self.0.sub(&rhs.0).map_err(wrap_candle_err)?))
+     }
+
+     /// Divides this tensor by another tensor or a scalar (Float/Integer), broadcasting as needed.
+     /// &RETURNS&: Tensor
+     /// @param rhs [Candle::Tensor, Float, or Integer]
+     /// @return [Candle::Tensor]
+     pub fn __truediv__(&self, rhs: magnus::Value) -> RbResult<Self> {
+         use magnus::TryConvert;
+         if let Ok(tensor) = <&Tensor>::try_convert(rhs) {
+             Ok(Self(self.0.broadcast_div(&tensor.0).map_err(wrap_candle_err)?))
+         } else if let Ok(f) = <f64>::try_convert(rhs) {
+             let scalar = CoreTensor::from_vec(vec![f as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
+             Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
+         } else if let Ok(i) = <i64>::try_convert(rhs) {
+             let scalar = CoreTensor::from_vec(vec![i as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
+             Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
+         } else {
+             Err(magnus::Error::new(magnus::exception::type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
+         }
+     }
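When the right-hand side is a Ruby Float or Integer, the branch above wraps it in a one-element F32 tensor on the same device and reuses `broadcast_div`. The scalar path in candle_core terms (a sketch, not the gem's code):

use candle_core::{Result, Tensor};

fn divide_by_scalar(t: &Tensor, rhs: f64) -> Result<Tensor> {
    // Wrap the scalar in a shape-(1,) tensor, then broadcast-divide.
    let scalar = Tensor::from_vec(vec![rhs as f32], (1,), t.device())?;
    t.broadcast_div(&scalar)
}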
+
+     /// Reshapes the tensor to the given shape.
+     /// &RETURNS&: Tensor
+     pub fn reshape(&self, shape: Vec<usize>) -> RbResult<Self> {
+         Ok(Tensor(self.0.reshape(shape).map_err(wrap_candle_err)?))
+     }
+
+     /// Broadcasts the tensor to the given shape.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_as(&self, shape: Vec<usize>) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_as(shape).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
+     /// &RETURNS&: Tensor
+     pub fn broadcast_left(&self, shape: Vec<usize>) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.broadcast_left(shape).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Creates a new tensor with the specified dimension removed if its size was one.
+     /// &RETURNS&: Tensor
+     pub fn squeeze(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.squeeze(dim).map_err(wrap_candle_err)?))
+     }
+
+     /// Creates a new tensor with a dimension of size one inserted at the specified position.
+     /// &RETURNS&: Tensor
+     pub fn unsqueeze(&self, dim: usize) -> RbResult<Self> {
+         Ok(Tensor(self.0.unsqueeze(dim).map_err(wrap_candle_err)?))
+     }
+
+     /// Gets the value at the specified index.
+     /// &RETURNS&: Tensor
+     pub fn get(&self, index: i64) -> RbResult<Self> {
+         let index = actual_index(self, 0, index).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.get(index).map_err(wrap_candle_err)?))
+     }
+
+     /// Returns a tensor that is a transposed version of the input; the given dimensions are swapped.
+     /// &RETURNS&: Tensor
+     pub fn transpose(&self, dim1: usize, dim2: usize) -> RbResult<Self> {
+         Ok(Tensor(
+             self.0.transpose(dim1, dim2).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
+     /// ranges from `start` to `start + len`.
+     /// &RETURNS&: Tensor
+     pub fn narrow(&self, dim: i64, start: i64, len: usize) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         let start = actual_index(self, dim, start).map_err(wrap_candle_err)?;
+         Ok(Tensor(
+             self.0.narrow(dim, start, len).map_err(wrap_candle_err)?,
+         ))
+     }
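As described above, `narrow` keeps `len` elements of dimension `dim` starting at `start`. For example (candle_core sketch, illustrative only):

use candle_core::{Device, Result, Tensor};

fn narrow_example() -> Result<Tensor> {
    let t = Tensor::new(&[10f32, 11., 12., 13.], &Device::Cpu)?;
    // Keep 2 elements of dim 0 starting at index 1: [11.0, 12.0].
    t.narrow(0, 1, 2)
}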
+
+     /// Returns the indices of the maximum value(s) across the selected dimension.
+     /// &RETURNS&: Tensor
+     pub fn argmax_keepdim(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(
+             self.0.argmax_keepdim(dim).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Returns the indices of the minimum value(s) across the selected dimension.
+     /// &RETURNS&: Tensor
+     pub fn argmin_keepdim(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(
+             self.0.argmin_keepdim(dim).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Gathers the maximum value across the selected dimension.
+     /// &RETURNS&: Tensor
+     pub fn max_keepdim(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.max_keepdim(dim).map_err(wrap_candle_err)?))
+     }
+
+     /// Gathers the minimum value across the selected dimension.
+     /// &RETURNS&: Tensor
+     pub fn min_keepdim(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.min_keepdim(dim).map_err(wrap_candle_err)?))
+     }
+
+     // fn eq(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.eq(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     // fn ne(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.ne(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     // fn lt(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.lt(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     // fn gt(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.gt(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     // fn ge(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.ge(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     // fn le(&self, rhs: &Self) -> RbResult<Self> {
+     //     Ok(Tensor(self.0.le(rhs).map_err(wrap_candle_err)?))
+     // }
+
+     /// Returns the sum of the tensor.
+     /// &RETURNS&: Tensor
+     pub fn sum_all(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.sum_all().map_err(wrap_candle_err)?))
+     }
+
+     /// Returns the mean of the tensor.
+     /// &RETURNS&: Tensor
+     pub fn mean_all(&self) -> RbResult<Self> {
+         let elements = self.0.elem_count();
+         let sum = self.0.sum_all().map_err(wrap_candle_err)?;
+         let mean = (sum / elements as f64).map_err(wrap_candle_err)?;
+         Ok(Tensor(mean))
+     }
+
+     /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+     /// &RETURNS&: Tensor
+     pub fn flatten_from(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.flatten_from(dim).map_err(wrap_candle_err)?))
+     }
+
+     /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+     /// &RETURNS&: Tensor
+     pub fn flatten_to(&self, dim: i64) -> RbResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
+         Ok(Tensor(self.0.flatten_to(dim).map_err(wrap_candle_err)?))
+     }
+
+     /// Flattens the tensor into a 1D tensor.
+     /// &RETURNS&: Tensor
+     pub fn flatten_all(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.flatten_all().map_err(wrap_candle_err)?))
+     }
+
+     /// Transposes the tensor.
+     /// &RETURNS&: Tensor
+     pub fn t(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.t().map_err(wrap_candle_err)?))
+     }
+
+     /// Makes the tensor contiguous in memory.
+     /// &RETURNS&: Tensor
+     pub fn contiguous(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.contiguous().map_err(wrap_candle_err)?))
+     }
+
+     /// Returns true if the tensor is contiguous in C order.
+     /// &RETURNS&: bool
+     pub fn is_contiguous(&self) -> bool {
+         self.0.is_contiguous()
+     }
+
+     /// Returns true if the tensor is contiguous in Fortran order.
+     /// &RETURNS&: bool
+     pub fn is_fortran_contiguous(&self) -> bool {
+         self.0.is_fortran_contiguous()
+     }
+
+     /// Detach the tensor from the computation graph.
+     /// &RETURNS&: Tensor
+     pub fn detach(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.detach()))
+     }
+
+     /// Returns a copy of the tensor.
+     /// &RETURNS&: Tensor
+     pub fn copy(&self) -> RbResult<Self> {
+         Ok(Tensor(self.0.copy().map_err(wrap_candle_err)?))
+     }
+
+     /// Convert the tensor to a new dtype.
+     /// &RETURNS&: Tensor
+     pub fn to_dtype(&self, dtype: magnus::Symbol) -> RbResult<Self> {
+         let dtype = DType::from_rbobject(dtype)?;
+         Ok(Tensor(self.0.to_dtype(dtype.0).map_err(wrap_candle_err)?))
+     }
+
+     /// Move the tensor to a new device.
+     /// &RETURNS&: Tensor
+     pub fn to_device(&self, device: Device) -> RbResult<Self> {
+         let device = device.as_device()?;
+         Ok(Tensor(
+             self.0.to_device(&device).map_err(wrap_candle_err)?,
+         ))
+     }
+ }
+
+ impl Tensor {
+     // fn cat(tensors: Vec<RbTensor>, dim: i64) -> RbResult<RbTensor> {
+     //     if tensors.is_empty() {
+     //         return Err(Error::new(
+     //             magnus::exception::arg_error(),
+     //             "empty input to cat",
+     //         ));
+     //     }
+     //     let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_candle_err)?;
+     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+     //     let tensor = Tensor::cat(&tensors, dim).map_err(wrap_candle_err)?;
+     //     Ok(Tensor(tensor))
+     // }
+
+     // fn stack(tensors: Vec<RbTensor>, dim: usize) -> RbResult<Self> {
+     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+     //     let tensor = Tensor::stack(&tensors, dim).map_err(wrap_candle_err)?;
+     //     Ok(Self(tensor))
+     // }
+
+     /// Creates a new tensor with random values.
+     /// &RETURNS&: Tensor
+     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         Ok(Self(
+             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Creates a new tensor with random values from a normal distribution.
+     /// &RETURNS&: Tensor
+     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         Ok(Self(
+             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Creates a new tensor filled with ones.
+     /// &RETURNS&: Tensor
+     pub fn ones(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         Ok(Self(
+             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
+         ))
+     }
+
+     /// Creates a new tensor filled with zeros.
+     /// &RETURNS&: Tensor
+     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
+         let device = device.unwrap_or(Device::Cpu).as_device()?;
+         Ok(Self(
+             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
+         ))
+     }
+ }
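The class-level factories above (`rand`, `randn`, `ones`, `zeros`) all default to F32 on the CPU and map directly onto candle_core constructors, roughly as in this sketch (illustrative only):

use candle_core::{DType, Device, Result, Tensor};

fn factories() -> Result<(Tensor, Tensor)> {
    let dev = Device::Cpu;
    // Uniform samples in [0, 1) and an all-zeros tensor, both F32, matching the defaults above.
    let r = Tensor::rand(0f32, 1f32, (2, 3), &dev)?;
    let z = Tensor::zeros((2, 3), DType::F32, &dev)?;
    Ok((r, z))
}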
+
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
+     let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
+     rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
+     // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
+     // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
+     rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
+     rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
+     rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
+     rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
+     rb_tensor.define_method("values", method!(Tensor::values, 0))?;
+     rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
+     rb_tensor.define_method("item", method!(Tensor::item, 0))?;
+     rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
+     rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
+     rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
+     rb_tensor.define_method("device", method!(Tensor::device, 0))?;
+     rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
+     rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
+     rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
+     rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
+     rb_tensor.define_method("log", method!(Tensor::log, 0))?;
+     rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
+     rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
+     rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
+     rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
+     rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
+     rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
+     rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
+     rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
+     rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
+     rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
+     rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
+     rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
+     rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
+     rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
+     rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
+     rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
+     rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
+     rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
+     rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
+     rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
+     rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
+     rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
+     rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
+     rb_tensor.define_method("get", method!(Tensor::get, 1))?;
+     rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
+     rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
+     rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
+     rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
+     rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
+     rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
+     rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
+     // rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
+     // rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
+     // rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
+     // rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
+     // rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
+     // rb_tensor.define_method("le", method!(Tensor::le, 1))?;
+     rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
+     rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
+     rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
+     rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
+     rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
+     rb_tensor.define_method("t", method!(Tensor::t, 0))?;
+     rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
+     rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
+     rb_tensor.define_method(
+         "is_fortran_contiguous",
+         method!(Tensor::is_fortran_contiguous, 0),
+     )?;
+     rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
+     rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
+     rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
+     rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
+     rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
+     rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
+     Ok(())
+ }
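`init` expects to be handed the `Candle` module by the extension's entry point. A minimal sketch of such an entry point (assuming the file above lives at `crate::ruby::tensor`; the gem's actual wiring may differ):

use magnus::{define_module, Error};

#[magnus::init]
fn init() -> Result<(), Error> {
    // Define (or reopen) the Candle module and register the Tensor class on it.
    let rb_candle = define_module("Candle")?;
    crate::ruby::tensor::init(rb_candle)?;
    Ok(())
}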