red-candle 1.0.0.pre.1 → 1.0.0.pre.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/candle/build.rs +116 -0
- data/ext/candle/src/lib.rs +6 -96
- data/ext/candle/src/llm/generation_config.rs +49 -0
- data/ext/candle/src/llm/mistral.rs +325 -0
- data/ext/candle/src/llm/mod.rs +68 -0
- data/ext/candle/src/llm/text_generation.rs +141 -0
- data/ext/candle/src/reranker.rs +267 -0
- data/ext/candle/src/ruby/device.rs +197 -0
- data/ext/candle/src/ruby/dtype.rs +37 -0
- data/ext/candle/src/ruby/embedding_model.rs +410 -0
- data/ext/candle/src/ruby/errors.rs +13 -0
- data/ext/candle/src/ruby/llm.rs +295 -0
- data/ext/candle/src/ruby/mod.rs +21 -0
- data/ext/candle/src/ruby/qtensor.rs +69 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/tensor.rs +654 -0
- data/ext/candle/src/ruby/utils.rs +88 -0
- data/lib/candle/version.rb +1 -1
- metadata +17 -1
@@ -0,0 +1,654 @@
|
|
1
|
+
use magnus::prelude::*;
|
2
|
+
use magnus::{function, method, class, RModule, Error, Module, Object};
|
3
|
+
|
4
|
+
use crate::ruby::{
|
5
|
+
errors::wrap_candle_err,
|
6
|
+
utils::{actual_dim, actual_index},
|
7
|
+
};
|
8
|
+
use crate::ruby::{DType, Device, Result as RbResult};
|
9
|
+
use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
|
10
|
+
|
11
|
+
#[derive(Clone, Debug)]
|
12
|
+
#[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
|
13
|
+
/// A `candle` tensor.
|
14
|
+
pub struct Tensor(pub CoreTensor);
|
15
|
+
|
16
|
+
impl std::ops::Deref for Tensor {
|
17
|
+
type Target = CoreTensor;
|
18
|
+
|
19
|
+
fn deref(&self) -> &Self::Target {
|
20
|
+
&self.0
|
21
|
+
}
|
22
|
+
}
|
23
|
+
|
24
|
+
impl Tensor {
|
25
|
+
pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> RbResult<Self> {
|
26
|
+
let dtype = dtype
|
27
|
+
.map(|dtype| DType::from_rbobject(dtype))
|
28
|
+
.unwrap_or(Ok(DType(CoreDType::F32)))?;
|
29
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
30
|
+
// FIXME: Do not use `to_f64` here.
|
31
|
+
let array = array
|
32
|
+
.into_iter()
|
33
|
+
.map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
|
34
|
+
.collect::<RbResult<Vec<_>>>()?;
|
35
|
+
Ok(Self(
|
36
|
+
CoreTensor::new(array.as_slice(), &device)
|
37
|
+
.map_err(wrap_candle_err)?
|
38
|
+
.to_dtype(dtype.0)
|
39
|
+
.map_err(wrap_candle_err)?,
|
40
|
+
))
|
41
|
+
}
|
42
|
+
|
43
|
+
pub fn values(&self) -> RbResult<Vec<f64>> {
|
44
|
+
let values = self
|
45
|
+
.0
|
46
|
+
.to_dtype(CoreDType::F64)
|
47
|
+
.map_err(wrap_candle_err)?
|
48
|
+
.flatten_all()
|
49
|
+
.map_err(wrap_candle_err)?
|
50
|
+
.to_vec1()
|
51
|
+
.map_err(wrap_candle_err)?;
|
52
|
+
Ok(values)
|
53
|
+
}
|
54
|
+
|
55
|
+
/// Get values as f32 without dtype conversion
|
56
|
+
pub fn values_f32(&self) -> RbResult<Vec<f32>> {
|
57
|
+
match self.0.dtype() {
|
58
|
+
CoreDType::F32 => {
|
59
|
+
let values = self
|
60
|
+
.0
|
61
|
+
.flatten_all()
|
62
|
+
.map_err(wrap_candle_err)?
|
63
|
+
.to_vec1()
|
64
|
+
.map_err(wrap_candle_err)?;
|
65
|
+
Ok(values)
|
66
|
+
}
|
67
|
+
_ => Err(magnus::Error::new(
|
68
|
+
magnus::exception::runtime_error(),
|
69
|
+
"Tensor must be F32 dtype for values_f32",
|
70
|
+
)),
|
71
|
+
}
|
72
|
+
}
|
73
|
+
|
74
|
+
/// Get a single scalar value from a rank-0 tensor
|
75
|
+
pub fn item(&self) -> RbResult<f64> {
|
76
|
+
if self.0.rank() != 0 {
|
77
|
+
return Err(magnus::Error::new(
|
78
|
+
magnus::exception::runtime_error(),
|
79
|
+
format!("item() can only be called on scalar tensors (rank 0), but tensor has rank {}", self.0.rank()),
|
80
|
+
));
|
81
|
+
}
|
82
|
+
|
83
|
+
// Try to get the value based on dtype
|
84
|
+
match self.0.dtype() {
|
85
|
+
CoreDType::F32 => {
|
86
|
+
let val: f32 = self.0.to_vec0().map_err(wrap_candle_err)?;
|
87
|
+
Ok(val as f64)
|
88
|
+
}
|
89
|
+
CoreDType::F64 => {
|
90
|
+
let val: f64 = self.0.to_vec0().map_err(wrap_candle_err)?;
|
91
|
+
Ok(val)
|
92
|
+
}
|
93
|
+
_ => {
|
94
|
+
// For other dtypes, convert to F64 first
|
95
|
+
let val: f64 = self.0
|
96
|
+
.to_dtype(CoreDType::F64)
|
97
|
+
.map_err(wrap_candle_err)?
|
98
|
+
.to_vec0()
|
99
|
+
.map_err(wrap_candle_err)?;
|
100
|
+
Ok(val)
|
101
|
+
}
|
102
|
+
}
|
103
|
+
}
|
104
|
+
|
105
|
+
/// Gets the tensor's shape.
|
106
|
+
/// &RETURNS&: Tuple[int]
|
107
|
+
pub fn shape(&self) -> Vec<usize> {
|
108
|
+
self.0.dims().to_vec()
|
109
|
+
}
|
110
|
+
|
111
|
+
/// Gets the tensor's strides.
|
112
|
+
/// &RETURNS&: Tuple[int]
|
113
|
+
pub fn stride(&self) -> Vec<usize> {
|
114
|
+
self.0.stride().to_vec()
|
115
|
+
}
|
116
|
+
|
117
|
+
/// Gets the tensor's dtype.
|
118
|
+
/// &RETURNS&: DType
|
119
|
+
pub fn dtype(&self) -> DType {
|
120
|
+
DType(self.0.dtype())
|
121
|
+
}
|
122
|
+
|
123
|
+
/// Gets the tensor's device.
|
124
|
+
/// &RETURNS&: Device
|
125
|
+
pub fn device(&self) -> Device {
|
126
|
+
Device::from_device(self.0.device())
|
127
|
+
}
|
128
|
+
|
129
|
+
/// Gets the tensor's rank.
|
130
|
+
/// &RETURNS&: int
|
131
|
+
pub fn rank(&self) -> usize {
|
132
|
+
self.0.rank()
|
133
|
+
}
|
134
|
+
|
135
|
+
/// The number of elements stored in this tensor.
|
136
|
+
/// &RETURNS&: int
|
137
|
+
pub fn elem_count(&self) -> usize {
|
138
|
+
self.0.elem_count()
|
139
|
+
}
|
140
|
+
|
141
|
+
pub fn __repr__(&self) -> String {
|
142
|
+
format!("{}", self.0)
|
143
|
+
}
|
144
|
+
|
145
|
+
pub fn __str__(&self) -> String {
|
146
|
+
self.__repr__()
|
147
|
+
}
|
148
|
+
|
149
|
+
/// Performs the `sin` operation on the tensor.
|
150
|
+
/// &RETURNS&: Tensor
|
151
|
+
pub fn sin(&self) -> RbResult<Self> {
|
152
|
+
Ok(Tensor(self.0.sin().map_err(wrap_candle_err)?))
|
153
|
+
}
|
154
|
+
|
155
|
+
/// Performs the `cos` operation on the tensor.
|
156
|
+
/// &RETURNS&: Tensor
|
157
|
+
pub fn cos(&self) -> RbResult<Self> {
|
158
|
+
Ok(Tensor(self.0.cos().map_err(wrap_candle_err)?))
|
159
|
+
}
|
160
|
+
|
161
|
+
/// Performs the `log` operation on the tensor.
|
162
|
+
/// &RETURNS&: Tensor
|
163
|
+
pub fn log(&self) -> RbResult<Self> {
|
164
|
+
Ok(Tensor(self.0.log().map_err(wrap_candle_err)?))
|
165
|
+
}
|
166
|
+
|
167
|
+
/// Squares the tensor.
|
168
|
+
/// &RETURNS&: Tensor
|
169
|
+
pub fn sqr(&self) -> RbResult<Self> {
|
170
|
+
Ok(Tensor(self.0.sqr().map_err(wrap_candle_err)?))
|
171
|
+
}
|
172
|
+
|
173
|
+
/// Returns the mean along the specified axis.
|
174
|
+
/// @param axis [Integer, optional] The axis to reduce over (default: 0)
|
175
|
+
/// @return [Candle::Tensor]
|
176
|
+
pub fn mean(&self, axis: Option<i64>) -> RbResult<Self> {
|
177
|
+
let axis = axis.unwrap_or(0) as usize;
|
178
|
+
Ok(Tensor(self.0.mean(axis).map_err(wrap_candle_err)?))
|
179
|
+
}
|
180
|
+
|
181
|
+
/// Returns the sum along the specified axis.
|
182
|
+
/// @param axis [Integer, optional] The axis to reduce over (default: 0)
|
183
|
+
/// @return [Candle::Tensor]
|
184
|
+
pub fn sum(&self, axis: Option<i64>) -> RbResult<Self> {
|
185
|
+
let axis = axis.unwrap_or(0) as usize;
|
186
|
+
Ok(Tensor(self.0.sum(axis).map_err(wrap_candle_err)?))
|
187
|
+
}
|
188
|
+
|
189
|
+
/// Calculates the square root of the tensor.
|
190
|
+
/// &RETURNS&: Tensor
|
191
|
+
pub fn sqrt(&self) -> RbResult<Self> {
|
192
|
+
Ok(Tensor(self.0.sqrt().map_err(wrap_candle_err)?))
|
193
|
+
}
|
194
|
+
|
195
|
+
/// Get the `recip` of the tensor.
|
196
|
+
/// &RETURNS&: Tensor
|
197
|
+
pub fn recip(&self) -> RbResult<Self> {
|
198
|
+
Ok(Tensor(self.0.recip().map_err(wrap_candle_err)?))
|
199
|
+
}
|
200
|
+
|
201
|
+
/// Performs the `exp` operation on the tensor.
|
202
|
+
/// &RETURNS&: Tensor
|
203
|
+
pub fn exp(&self) -> RbResult<Self> {
|
204
|
+
Ok(Tensor(self.0.exp().map_err(wrap_candle_err)?))
|
205
|
+
}
|
206
|
+
|
207
|
+
/// Performs the `pow` operation on the tensor with the given exponent.
|
208
|
+
/// &RETURNS&: Tensor
|
209
|
+
pub fn powf(&self, p: f64) -> RbResult<Self> {
|
210
|
+
Ok(Tensor(self.0.powf(p).map_err(wrap_candle_err)?))
|
211
|
+
}
|
212
|
+
|
213
|
+
/// Select values for the input tensor at the target indexes across the specified dimension.
|
214
|
+
///
|
215
|
+
/// The `indexes` is argument is an int tensor with a single dimension.
|
216
|
+
/// The output has the same number of dimension as the `self` input. The target dimension of
|
217
|
+
/// the output has length the length of `indexes` and the values are taken from `self` using
|
218
|
+
/// the index from `indexes`. Other dimensions have the same number of elements as the input
|
219
|
+
/// tensor.
|
220
|
+
/// &RETURNS&: Tensor
|
221
|
+
pub fn index_select(&self, rhs: &Self, dim: i64) -> RbResult<Self> {
|
222
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
223
|
+
Ok(Tensor(
|
224
|
+
self.0.index_select(rhs, dim).map_err(wrap_candle_err)?,
|
225
|
+
))
|
226
|
+
}
|
227
|
+
|
228
|
+
/// Performs a matrix multiplication between the two tensors.
|
229
|
+
/// &RETURNS&: Tensor
|
230
|
+
pub fn matmul(&self, rhs: &Self) -> RbResult<Self> {
|
231
|
+
Ok(Tensor(self.0.matmul(rhs).map_err(wrap_candle_err)?))
|
232
|
+
}
|
233
|
+
|
234
|
+
/// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
235
|
+
/// &RETURNS&: Tensor
|
236
|
+
pub fn broadcast_add(&self, rhs: &Self) -> RbResult<Self> {
|
237
|
+
Ok(Tensor(
|
238
|
+
self.0.broadcast_add(rhs).map_err(wrap_candle_err)?,
|
239
|
+
))
|
240
|
+
}
|
241
|
+
|
242
|
+
/// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
243
|
+
/// &RETURNS&: Tensor
|
244
|
+
pub fn broadcast_sub(&self, rhs: &Self) -> RbResult<Self> {
|
245
|
+
Ok(Tensor(
|
246
|
+
self.0.broadcast_sub(rhs).map_err(wrap_candle_err)?,
|
247
|
+
))
|
248
|
+
}
|
249
|
+
|
250
|
+
/// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
251
|
+
/// &RETURNS&: Tensor
|
252
|
+
pub fn broadcast_mul(&self, rhs: &Self) -> RbResult<Self> {
|
253
|
+
Ok(Tensor(
|
254
|
+
self.0.broadcast_mul(rhs).map_err(wrap_candle_err)?,
|
255
|
+
))
|
256
|
+
}
|
257
|
+
|
258
|
+
/// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
259
|
+
/// &RETURNS&: Tensor
|
260
|
+
pub fn broadcast_div(&self, rhs: &Self) -> RbResult<Self> {
|
261
|
+
Ok(Tensor(
|
262
|
+
self.0.broadcast_div(rhs).map_err(wrap_candle_err)?,
|
263
|
+
))
|
264
|
+
}
|
265
|
+
|
266
|
+
/// Returns a tensor with the same shape as the input tensor, the values are taken from
|
267
|
+
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
|
268
|
+
/// input tensor is equal to zero.
|
269
|
+
/// &RETURNS&: Tensor
|
270
|
+
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> RbResult<Self> {
|
271
|
+
Ok(Tensor(
|
272
|
+
self.0
|
273
|
+
.where_cond(on_true, on_false)
|
274
|
+
.map_err(wrap_candle_err)?,
|
275
|
+
))
|
276
|
+
}
|
277
|
+
|
278
|
+
/// Add two tensors.
|
279
|
+
/// &RETURNS&: Tensor
|
280
|
+
pub fn __add__(&self, rhs: &Tensor) -> RbResult<Self> {
|
281
|
+
Ok(Self(self.0.add(&rhs.0).map_err(wrap_candle_err)?))
|
282
|
+
}
|
283
|
+
|
284
|
+
/// Multiply two tensors.
|
285
|
+
/// &RETURNS&: Tensor
|
286
|
+
pub fn __mul__(&self, rhs: &Tensor) -> RbResult<Self> {
|
287
|
+
Ok(Self(self.0.mul(&rhs.0).map_err(wrap_candle_err)?))
|
288
|
+
}
|
289
|
+
|
290
|
+
/// Subtract two tensors.
|
291
|
+
/// &RETURNS&: Tensor
|
292
|
+
pub fn __sub__(&self, rhs: &Tensor) -> RbResult<Self> {
|
293
|
+
Ok(Self(self.0.sub(&rhs.0).map_err(wrap_candle_err)?))
|
294
|
+
}
|
295
|
+
|
296
|
+
/// Divide two tensors.
|
297
|
+
/// &RETURNS&: Tensor
|
298
|
+
/// Divides this tensor by another tensor or a scalar (Float/Integer).
|
299
|
+
/// @param rhs [Candle::Tensor, Float, or Integer]
|
300
|
+
/// @return [Candle::Tensor]
|
301
|
+
pub fn __truediv__(&self, rhs: magnus::Value) -> RbResult<Self> {
|
302
|
+
use magnus::TryConvert;
|
303
|
+
if let Ok(tensor) = <&Tensor>::try_convert(rhs) {
|
304
|
+
Ok(Self(self.0.broadcast_div(&tensor.0).map_err(wrap_candle_err)?))
|
305
|
+
} else if let Ok(f) = <f64>::try_convert(rhs) {
|
306
|
+
let scalar = CoreTensor::from_vec(vec![f as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
|
307
|
+
Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
|
308
|
+
} else if let Ok(i) = <i64>::try_convert(rhs) {
|
309
|
+
let scalar = CoreTensor::from_vec(vec![i as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
|
310
|
+
Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
|
311
|
+
} else {
|
312
|
+
Err(magnus::Error::new(magnus::exception::type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
|
313
|
+
}
|
314
|
+
}
|
315
|
+
|
316
|
+
/// Reshapes the tensor to the given shape.
|
317
|
+
/// &RETURNS&: Tensor
|
318
|
+
pub fn reshape(&self, shape: Vec<usize>) -> RbResult<Self> {
|
319
|
+
Ok(Tensor(self.0.reshape(shape).map_err(wrap_candle_err)?))
|
320
|
+
}
|
321
|
+
|
322
|
+
/// Broadcasts the tensor to the given shape.
|
323
|
+
/// &RETURNS&: Tensor
|
324
|
+
pub fn broadcast_as(&self, shape: Vec<usize>) -> RbResult<Self> {
|
325
|
+
Ok(Tensor(
|
326
|
+
self.0.broadcast_as(shape).map_err(wrap_candle_err)?,
|
327
|
+
))
|
328
|
+
}
|
329
|
+
|
330
|
+
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
|
331
|
+
/// &RETURNS&: Tensor
|
332
|
+
pub fn broadcast_left(&self, shape: Vec<usize>) -> RbResult<Self> {
|
333
|
+
Ok(Tensor(
|
334
|
+
self.0.broadcast_left(shape).map_err(wrap_candle_err)?,
|
335
|
+
))
|
336
|
+
}
|
337
|
+
|
338
|
+
/// Creates a new tensor with the specified dimension removed if its size was one.
|
339
|
+
/// &RETURNS&: Tensor
|
340
|
+
pub fn squeeze(&self, dim: i64) -> RbResult<Self> {
|
341
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
342
|
+
Ok(Tensor(self.0.squeeze(dim).map_err(wrap_candle_err)?))
|
343
|
+
}
|
344
|
+
|
345
|
+
/// Creates a new tensor with a dimension of size one inserted at the specified position.
|
346
|
+
/// &RETURNS&: Tensor
|
347
|
+
pub fn unsqueeze(&self, dim: usize) -> RbResult<Self> {
|
348
|
+
Ok(Tensor(self.0.unsqueeze(dim).map_err(wrap_candle_err)?))
|
349
|
+
}
|
350
|
+
|
351
|
+
/// Gets the value at the specified index.
|
352
|
+
/// &RETURNS&: Tensor
|
353
|
+
pub fn get(&self, index: i64) -> RbResult<Self> {
|
354
|
+
let index = actual_index(self, 0, index).map_err(wrap_candle_err)?;
|
355
|
+
Ok(Tensor(self.0.get(index).map_err(wrap_candle_err)?))
|
356
|
+
}
|
357
|
+
|
358
|
+
/// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
|
359
|
+
/// &RETURNS&: Tensor
|
360
|
+
pub fn transpose(&self, dim1: usize, dim2: usize) -> RbResult<Self> {
|
361
|
+
Ok(Tensor(
|
362
|
+
self.0.transpose(dim1, dim2).map_err(wrap_candle_err)?,
|
363
|
+
))
|
364
|
+
}
|
365
|
+
|
366
|
+
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
367
|
+
/// ranges from `start` to `start + len`.
|
368
|
+
/// &RETURNS&: Tensor
|
369
|
+
pub fn narrow(&self, dim: i64, start: i64, len: usize) -> RbResult<Self> {
|
370
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
371
|
+
let start = actual_index(self, dim, start).map_err(wrap_candle_err)?;
|
372
|
+
Ok(Tensor(
|
373
|
+
self.0.narrow(dim, start, len).map_err(wrap_candle_err)?,
|
374
|
+
))
|
375
|
+
}
|
376
|
+
|
377
|
+
/// Returns the indices of the maximum value(s) across the selected dimension.
|
378
|
+
/// &RETURNS&: Tensor
|
379
|
+
pub fn argmax_keepdim(&self, dim: i64) -> RbResult<Self> {
|
380
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
381
|
+
Ok(Tensor(
|
382
|
+
self.0.argmax_keepdim(dim).map_err(wrap_candle_err)?,
|
383
|
+
))
|
384
|
+
}
|
385
|
+
|
386
|
+
/// Returns the indices of the minimum value(s) across the selected dimension.
|
387
|
+
/// &RETURNS&: Tensor
|
388
|
+
pub fn argmin_keepdim(&self, dim: i64) -> RbResult<Self> {
|
389
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
390
|
+
Ok(Tensor(
|
391
|
+
self.0.argmin_keepdim(dim).map_err(wrap_candle_err)?,
|
392
|
+
))
|
393
|
+
}
|
394
|
+
|
395
|
+
/// Gathers the maximum value across the selected dimension.
|
396
|
+
/// &RETURNS&: Tensor
|
397
|
+
pub fn max_keepdim(&self, dim: i64) -> RbResult<Self> {
|
398
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
399
|
+
Ok(Tensor(self.0.max_keepdim(dim).map_err(wrap_candle_err)?))
|
400
|
+
}
|
401
|
+
|
402
|
+
/// Gathers the minimum value across the selected dimension.
|
403
|
+
/// &RETURNS&: Tensor
|
404
|
+
pub fn min_keepdim(&self, dim: i64) -> RbResult<Self> {
|
405
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
406
|
+
Ok(Tensor(self.0.min_keepdim(dim).map_err(wrap_candle_err)?))
|
407
|
+
}
|
408
|
+
|
409
|
+
// fn eq(&self, rhs: &Self) -> RbResult<Self> {
|
410
|
+
// Ok(Tensor(self.0.eq(rhs).map_err(wrap_candle_err)?))
|
411
|
+
// }
|
412
|
+
|
413
|
+
// fn ne(&self, rhs: &Self) -> RbResult<Self> {
|
414
|
+
// Ok(Tensor(self.0.ne(rhs).map_err(wrap_candle_err)?))
|
415
|
+
// }
|
416
|
+
|
417
|
+
// fn lt(&self, rhs: &Self) -> RbResult<Self> {
|
418
|
+
// Ok(Tensor(self.0.lt(rhs).map_err(wrap_candle_err)?))
|
419
|
+
// }
|
420
|
+
|
421
|
+
// fn gt(&self, rhs: &Self) -> RbResult<Self> {
|
422
|
+
// Ok(Tensor(self.0.gt(rhs).map_err(wrap_candle_err)?))
|
423
|
+
// }
|
424
|
+
|
425
|
+
// fn ge(&self, rhs: &Self) -> RbResult<Self> {
|
426
|
+
// Ok(Tensor(self.0.ge(rhs).map_err(wrap_candle_err)?))
|
427
|
+
// }
|
428
|
+
|
429
|
+
// fn le(&self, rhs: &Self) -> RbResult<Self> {
|
430
|
+
// Ok(Tensor(self.0.le(rhs).map_err(wrap_candle_err)?))
|
431
|
+
// }
|
432
|
+
|
433
|
+
/// Returns the sum of the tensor.
|
434
|
+
/// &RETURNS&: Tensor
|
435
|
+
pub fn sum_all(&self) -> RbResult<Self> {
|
436
|
+
Ok(Tensor(self.0.sum_all().map_err(wrap_candle_err)?))
|
437
|
+
}
|
438
|
+
|
439
|
+
/// Returns the mean of the tensor.
|
440
|
+
/// &RETURNS&: Tensor
|
441
|
+
pub fn mean_all(&self) -> RbResult<Self> {
|
442
|
+
let elements = self.0.elem_count();
|
443
|
+
let sum = self.0.sum_all().map_err(wrap_candle_err)?;
|
444
|
+
let mean = (sum / elements as f64).map_err(wrap_candle_err)?;
|
445
|
+
Ok(Tensor(mean))
|
446
|
+
}
|
447
|
+
|
448
|
+
/// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
|
449
|
+
/// &RETURNS&: Tensor
|
450
|
+
pub fn flatten_from(&self, dim: i64) -> RbResult<Self> {
|
451
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
452
|
+
Ok(Tensor(self.0.flatten_from(dim).map_err(wrap_candle_err)?))
|
453
|
+
}
|
454
|
+
|
455
|
+
///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
|
456
|
+
/// &RETURNS&: Tensor
|
457
|
+
pub fn flatten_to(&self, dim: i64) -> RbResult<Self> {
|
458
|
+
let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
|
459
|
+
Ok(Tensor(self.0.flatten_to(dim).map_err(wrap_candle_err)?))
|
460
|
+
}
|
461
|
+
|
462
|
+
/// Flattens the tensor into a 1D tensor.
|
463
|
+
/// &RETURNS&: Tensor
|
464
|
+
pub fn flatten_all(&self) -> RbResult<Self> {
|
465
|
+
Ok(Tensor(self.0.flatten_all().map_err(wrap_candle_err)?))
|
466
|
+
}
|
467
|
+
|
468
|
+
/// Transposes the tensor.
|
469
|
+
/// &RETURNS&: Tensor
|
470
|
+
pub fn t(&self) -> RbResult<Self> {
|
471
|
+
Ok(Tensor(self.0.t().map_err(wrap_candle_err)?))
|
472
|
+
}
|
473
|
+
|
474
|
+
/// Makes the tensor contiguous in memory.
|
475
|
+
/// &RETURNS&: Tensor
|
476
|
+
pub fn contiguous(&self) -> RbResult<Self> {
|
477
|
+
Ok(Tensor(self.0.contiguous().map_err(wrap_candle_err)?))
|
478
|
+
}
|
479
|
+
|
480
|
+
/// Returns true if the tensor is contiguous in C order.
|
481
|
+
/// &RETURNS&: bool
|
482
|
+
pub fn is_contiguous(&self) -> bool {
|
483
|
+
self.0.is_contiguous()
|
484
|
+
}
|
485
|
+
|
486
|
+
/// Returns true if the tensor is contiguous in Fortran order.
|
487
|
+
/// &RETURNS&: bool
|
488
|
+
pub fn is_fortran_contiguous(&self) -> bool {
|
489
|
+
self.0.is_fortran_contiguous()
|
490
|
+
}
|
491
|
+
|
492
|
+
/// Detach the tensor from the computation graph.
|
493
|
+
/// &RETURNS&: Tensor
|
494
|
+
pub fn detach(&self) -> RbResult<Self> {
|
495
|
+
Ok(Tensor(self.0.detach()))
|
496
|
+
}
|
497
|
+
|
498
|
+
/// Returns a copy of the tensor.
|
499
|
+
/// &RETURNS&: Tensor
|
500
|
+
pub fn copy(&self) -> RbResult<Self> {
|
501
|
+
Ok(Tensor(self.0.copy().map_err(wrap_candle_err)?))
|
502
|
+
}
|
503
|
+
|
504
|
+
/// Convert the tensor to a new dtype.
|
505
|
+
/// &RETURNS&: Tensor
|
506
|
+
pub fn to_dtype(&self, dtype: magnus::Symbol) -> RbResult<Self> {
|
507
|
+
let dtype = DType::from_rbobject(dtype)?;
|
508
|
+
Ok(Tensor(self.0.to_dtype(dtype.0).map_err(wrap_candle_err)?))
|
509
|
+
}
|
510
|
+
|
511
|
+
/// Move the tensor to a new device.
|
512
|
+
/// &RETURNS&: Tensor
|
513
|
+
pub fn to_device(&self, device: Device) -> RbResult<Self> {
|
514
|
+
let device = device.as_device()?;
|
515
|
+
Ok(Tensor(
|
516
|
+
self.0.to_device(&device).map_err(wrap_candle_err)?,
|
517
|
+
))
|
518
|
+
}
|
519
|
+
}
|
520
|
+
|
521
|
+
impl Tensor {
|
522
|
+
// fn cat(tensors: Vec<RbTensor>, dim: i64) -> RbResult<RbTensor> {
|
523
|
+
// if tensors.is_empty() {
|
524
|
+
// return Err(Error::new(
|
525
|
+
// magnus::exception::arg_error(),
|
526
|
+
// "empty input to cat",
|
527
|
+
// ));
|
528
|
+
// }
|
529
|
+
// let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_candle_err)?;
|
530
|
+
// let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
531
|
+
// let tensor = Tensor::cat(&tensors, dim).map_err(wrap_candle_err)?;
|
532
|
+
// Ok(Tensor(tensor))
|
533
|
+
// }
|
534
|
+
|
535
|
+
// fn stack(tensors: Vec<RbTensor>, dim: usize) -> RbResult<Self> {
|
536
|
+
// let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
537
|
+
// let tensor = Tensor::stack(&tensors, dim).map_err(wrap_candle_err)?;
|
538
|
+
// Ok(Self(tensor))
|
539
|
+
// }
|
540
|
+
|
541
|
+
/// Creates a new tensor with random values.
|
542
|
+
/// &RETURNS&: Tensor
|
543
|
+
pub fn rand(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
|
544
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
545
|
+
Ok(Self(
|
546
|
+
CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
|
547
|
+
))
|
548
|
+
}
|
549
|
+
|
550
|
+
/// Creates a new tensor with random values from a normal distribution.
|
551
|
+
/// &RETURNS&: Tensor
|
552
|
+
pub fn randn(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
|
553
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
554
|
+
Ok(Self(
|
555
|
+
CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
|
556
|
+
))
|
557
|
+
}
|
558
|
+
|
559
|
+
/// Creates a new tensor filled with ones.
|
560
|
+
/// &RETURNS&: Tensor
|
561
|
+
pub fn ones(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
|
562
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
563
|
+
Ok(Self(
|
564
|
+
CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
|
565
|
+
))
|
566
|
+
}
|
567
|
+
/// Creates a new tensor filled with zeros.
|
568
|
+
/// &RETURNS&: Tensor
|
569
|
+
pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> RbResult<Self> {
|
570
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
571
|
+
Ok(Self(
|
572
|
+
CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
|
573
|
+
))
|
574
|
+
}
|
575
|
+
}
|
576
|
+
|
577
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
578
|
+
let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
|
579
|
+
rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
|
580
|
+
// rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
|
581
|
+
// rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
|
582
|
+
rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
|
583
|
+
rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
|
584
|
+
rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
|
585
|
+
rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
|
586
|
+
rb_tensor.define_method("values", method!(Tensor::values, 0))?;
|
587
|
+
rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
|
588
|
+
rb_tensor.define_method("item", method!(Tensor::item, 0))?;
|
589
|
+
rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
|
590
|
+
rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
|
591
|
+
rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
|
592
|
+
rb_tensor.define_method("device", method!(Tensor::device, 0))?;
|
593
|
+
rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
|
594
|
+
rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
|
595
|
+
rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
|
596
|
+
rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
|
597
|
+
rb_tensor.define_method("log", method!(Tensor::log, 0))?;
|
598
|
+
rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
|
599
|
+
rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
|
600
|
+
rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
|
601
|
+
rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
|
602
|
+
rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
|
603
|
+
rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
|
604
|
+
rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
|
605
|
+
rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
|
606
|
+
rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
|
607
|
+
rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
|
608
|
+
rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
|
609
|
+
rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
|
610
|
+
rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
|
611
|
+
rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
|
612
|
+
rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
|
613
|
+
rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
|
614
|
+
rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
|
615
|
+
rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
|
616
|
+
rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
|
617
|
+
rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
|
618
|
+
rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
|
619
|
+
rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
|
620
|
+
rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
|
621
|
+
rb_tensor.define_method("get", method!(Tensor::get, 1))?;
|
622
|
+
rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
|
623
|
+
rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
|
624
|
+
rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
|
625
|
+
rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
|
626
|
+
rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
|
627
|
+
rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
|
628
|
+
rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
|
629
|
+
// rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
|
630
|
+
// rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
|
631
|
+
// rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
|
632
|
+
// rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
|
633
|
+
// rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
|
634
|
+
// rb_tensor.define_method("le", method!(Tensor::le, 1))?;
|
635
|
+
rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
|
636
|
+
rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
|
637
|
+
rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
|
638
|
+
rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
|
639
|
+
rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
|
640
|
+
rb_tensor.define_method("t", method!(Tensor::t, 0))?;
|
641
|
+
rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
|
642
|
+
rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
|
643
|
+
rb_tensor.define_method(
|
644
|
+
"is_fortran_contiguous",
|
645
|
+
method!(Tensor::is_fortran_contiguous, 0),
|
646
|
+
)?;
|
647
|
+
rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
|
648
|
+
rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
|
649
|
+
rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
|
650
|
+
rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
|
651
|
+
rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
|
652
|
+
rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
|
653
|
+
Ok(())
|
654
|
+
}
|