red-candle 1.0.0.pre.7 → 1.0.0
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.

Files changed:
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +322 -4
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +5 -0
- data/ext/candle/src/llm/llama.rs +5 -0
- data/ext/candle/src/llm/mistral.rs +5 -0
- data/ext/candle/src/llm/mod.rs +1 -89
- data/ext/candle/src/llm/quantized_gguf.rs +5 -0
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -33
- data/ext/candle/src/ruby/llm.rs +31 -13
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +126 -3
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/ruby/tensor.rs (the one file expanded below; long pre-change lines are truncated in the source view):

@@ -1,11 +1,11 @@
 use magnus::prelude::*;
-use magnus::{function, method, class, RModule,
+use magnus::{function, method, class, RModule, Module, Object};
 
 use crate::ruby::{
     errors::wrap_candle_err,
     utils::{actual_dim, actual_index},
 };
-use crate::ruby::{DType, Device, Result
+use crate::ruby::{DType, Device, Result};
 use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
 
 #[derive(Clone, Debug)]
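The new `Module` and `Object` imports are the magnus traits that provide `define_class` and the `define_*_method` calls used in `init` at the bottom of this file. The `Result` re-exported from `crate::ruby` is the crate-local alias behind every `Result<Self>` signature below; a minimal sketch of such an alias (an assumption — only the name and its use alongside `magnus::Error` are visible in this diff):

```rust
// Hypothetical sketch of the crate-local alias re-exported from crate::ruby.
// Inferred from `use crate::ruby::{DType, Device, Result};` and the
// `magnus::Error` values constructed below; not the gem's verbatim source.
pub type Result<T> = std::result::Result<T, magnus::Error>;
```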
@@ -22,7 +22,7 @@ impl std::ops::Deref for Tensor {
 }
 
 impl Tensor {
-    pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) ->
+    pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
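The hunk header places these methods right after an `impl std::ops::Deref for Tensor` block, which explains two idioms that recur through the rest of the diff: field access through `self.0`, and passing `rhs: &Self` straight into candle methods that expect `&candle_core::Tensor`. A sketch of that newtype-plus-Deref pattern, reconstructed from the visible pieces (`#[derive(Clone, Debug)]`, the `Tensor(...)` constructors, and the Deref header):

```rust
use candle_core::Tensor as CoreTensor;

// Newtype wrapper around the candle-core tensor; the `.0` seen throughout
// this diff is this single field.
#[derive(Clone, Debug)]
pub struct Tensor(pub CoreTensor);

// Deref coerces &Tensor to &CoreTensor, which is why calls such as
// `self.0.matmul(rhs)` can take `rhs: &Self` without writing `&rhs.0`.
impl std::ops::Deref for Tensor {
    type Target = CoreTensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
```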
@@ -31,7 +31,7 @@ impl Tensor {
         let array = array
             .into_iter()
             .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
-            .collect::<
+            .collect::<Result<Vec<_>>>()?;
         Ok(Self(
             CoreTensor::new(array.as_slice(), &device)
                 .map_err(wrap_candle_err)?
@@ -40,7 +40,7 @@ impl Tensor {
         ))
     }
 
-    pub fn values(&self) ->
+    pub fn values(&self) -> Result<Vec<f64>> {
         let values = self
             .0
             .to_dtype(CoreDType::F64)
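`wrap_candle_err` shows up on nearly every fallible candle call in this file, turning a `candle_core::Error` into a `magnus::Error` so it surfaces as a Ruby exception. Its definition lives in `ruby/errors.rs` and is not part of this diff; a plausible sketch, modeled on the `magnus::Error::new(magnus::exception::runtime_error(), ...)` pattern visible in `item` below:

```rust
use magnus::{exception, Error};

// Assumed shape of the adapter from ruby/errors.rs; not confirmed by this diff.
pub fn wrap_candle_err(err: candle_core::Error) -> Error {
    Error::new(exception::runtime_error(), err.to_string())
}
```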
@@ -53,7 +53,7 @@ impl Tensor {
     }
 
     /// Get values as f32 without dtype conversion
-    pub fn values_f32(&self) ->
+    pub fn values_f32(&self) -> Result<Vec<f32>> {
         match self.0.dtype() {
             CoreDType::F32 => {
                 let values = self
@@ -72,7 +72,7 @@ impl Tensor {
     }
 
     /// Get a single scalar value from a rank-0 tensor
-    pub fn item(&self) ->
+    pub fn item(&self) -> Result<f64> {
         if self.0.rank() != 0 {
             return Err(magnus::Error::new(
                 magnus::exception::runtime_error(),
@@ -148,32 +148,32 @@ impl Tensor {
 
     /// Performs the `sin` operation on the tensor.
     /// &RETURNS&: Tensor
-    pub fn sin(&self) ->
+    pub fn sin(&self) -> Result<Self> {
         Ok(Tensor(self.0.sin().map_err(wrap_candle_err)?))
     }
 
     /// Performs the `cos` operation on the tensor.
     /// &RETURNS&: Tensor
-    pub fn cos(&self) ->
+    pub fn cos(&self) -> Result<Self> {
         Ok(Tensor(self.0.cos().map_err(wrap_candle_err)?))
     }
 
     /// Performs the `log` operation on the tensor.
     /// &RETURNS&: Tensor
-    pub fn log(&self) ->
+    pub fn log(&self) -> Result<Self> {
         Ok(Tensor(self.0.log().map_err(wrap_candle_err)?))
     }
 
     /// Squares the tensor.
     /// &RETURNS&: Tensor
-    pub fn sqr(&self) ->
+    pub fn sqr(&self) -> Result<Self> {
         Ok(Tensor(self.0.sqr().map_err(wrap_candle_err)?))
     }
 
     /// Returns the mean along the specified axis.
     /// @param axis [Integer, optional] The axis to reduce over (default: 0)
     /// @return [Candle::Tensor]
-    pub fn mean(&self, axis: Option<i64>) ->
+    pub fn mean(&self, axis: Option<i64>) -> Result<Self> {
         let axis = axis.unwrap_or(0) as usize;
         Ok(Tensor(self.0.mean(axis).map_err(wrap_candle_err)?))
     }
@@ -181,32 +181,32 @@ impl Tensor {
     /// Returns the sum along the specified axis.
     /// @param axis [Integer, optional] The axis to reduce over (default: 0)
     /// @return [Candle::Tensor]
-    pub fn sum(&self, axis: Option<i64>) ->
+    pub fn sum(&self, axis: Option<i64>) -> Result<Self> {
         let axis = axis.unwrap_or(0) as usize;
         Ok(Tensor(self.0.sum(axis).map_err(wrap_candle_err)?))
     }
 
     /// Calculates the square root of the tensor.
     /// &RETURNS&: Tensor
-    pub fn sqrt(&self) ->
+    pub fn sqrt(&self) -> Result<Self> {
         Ok(Tensor(self.0.sqrt().map_err(wrap_candle_err)?))
     }
 
     /// Get the `recip` of the tensor.
     /// &RETURNS&: Tensor
-    pub fn recip(&self) ->
+    pub fn recip(&self) -> Result<Self> {
         Ok(Tensor(self.0.recip().map_err(wrap_candle_err)?))
     }
 
     /// Performs the `exp` operation on the tensor.
     /// &RETURNS&: Tensor
-    pub fn exp(&self) ->
+    pub fn exp(&self) -> Result<Self> {
         Ok(Tensor(self.0.exp().map_err(wrap_candle_err)?))
     }
 
     /// Performs the `pow` operation on the tensor with the given exponent.
     /// &RETURNS&: Tensor
-    pub fn powf(&self, p: f64) ->
+    pub fn powf(&self, p: f64) -> Result<Self> {
         Ok(Tensor(self.0.powf(p).map_err(wrap_candle_err)?))
     }
 
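Each of these unary wrappers is the same three-line body around a different candle method, and the diff touches every one of them just to restore the `Result<Self>` return type. Purely as a design sketch (the gem hand-writes each method, and nothing in this release changes that), a declarative macro could stamp them out, building on the `Tensor`, `Result`, and `wrap_candle_err` sketches above:

```rust
// Illustrative alternative, not the gem's code: generate the unary wrappers
// with a macro instead of repeating the body per method.
macro_rules! unary_op {
    ($($name:ident),* $(,)?) => {
        $(
            pub fn $name(&self) -> Result<Self> {
                Ok(Tensor(self.0.$name().map_err(wrap_candle_err)?))
            }
        )*
    };
}

impl Tensor {
    unary_op!(sin, cos, log, sqr, sqrt, recip, exp);
}
```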
@@ -218,7 +218,7 @@ impl Tensor {
     /// the index from `indexes`. Other dimensions have the same number of elements as the input
     /// tensor.
     /// &RETURNS&: Tensor
-    pub fn index_select(&self, rhs: &Self, dim: i64) ->
+    pub fn index_select(&self, rhs: &Self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(
             self.0.index_select(rhs, dim).map_err(wrap_candle_err)?,
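`actual_dim` and `actual_index` (imported from `ruby::utils` in the first hunk) are what let Ruby callers pass negative dimensions and indices, Python/NumPy style; their results feed `.map_err(wrap_candle_err)`, so they evidently return `Result<_, candle_core::Error>`. Their bodies are not in this diff; a sketch of `actual_dim` under those assumptions:

```rust
use candle_core::{Error as CandleError, Tensor as CoreTensor};

// Assumed behavior, not the gem's verbatim utils.rs: normalize a possibly
// negative dimension against the tensor's rank, or report it out of range.
pub fn actual_dim(t: &CoreTensor, dim: i64) -> Result<usize, CandleError> {
    let rank = t.rank() as i64;
    let d = if dim < 0 { dim + rank } else { dim };
    if (0..rank).contains(&d) {
        Ok(d as usize)
    } else {
        Err(CandleError::Msg(format!(
            "dimension {dim} out of range for tensor of rank {rank}"
        )))
    }
}
```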
@@ -227,13 +227,13 @@ impl Tensor {
 
     /// Performs a matrix multiplication between the two tensors.
     /// &RETURNS&: Tensor
-    pub fn matmul(&self, rhs: &Self) ->
+    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         Ok(Tensor(self.0.matmul(rhs).map_err(wrap_candle_err)?))
     }
 
     /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
     /// &RETURNS&: Tensor
-    pub fn broadcast_add(&self, rhs: &Self) ->
+    pub fn broadcast_add(&self, rhs: &Self) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_add(rhs).map_err(wrap_candle_err)?,
         ))
@@ -241,7 +241,7 @@ impl Tensor {
 
     /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
     /// &RETURNS&: Tensor
-    pub fn broadcast_sub(&self, rhs: &Self) ->
+    pub fn broadcast_sub(&self, rhs: &Self) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_sub(rhs).map_err(wrap_candle_err)?,
         ))
@@ -249,7 +249,7 @@ impl Tensor {
 
     /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
     /// &RETURNS&: Tensor
-    pub fn broadcast_mul(&self, rhs: &Self) ->
+    pub fn broadcast_mul(&self, rhs: &Self) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_mul(rhs).map_err(wrap_candle_err)?,
         ))
@@ -257,7 +257,7 @@ impl Tensor {
 
     /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
     /// &RETURNS&: Tensor
-    pub fn broadcast_div(&self, rhs: &Self) ->
+    pub fn broadcast_div(&self, rhs: &Self) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_div(rhs).map_err(wrap_candle_err)?,
         ))
@@ -267,7 +267,7 @@ impl Tensor {
     /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
     /// input tensor is equal to zero.
     /// &RETURNS&: Tensor
-    pub fn where_cond(&self, on_true: &Self, on_false: &Self) ->
+    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
         Ok(Tensor(
             self.0
                 .where_cond(on_true, on_false)
@@ -277,19 +277,19 @@ impl Tensor {
 
     /// Add two tensors.
     /// &RETURNS&: Tensor
-    pub fn __add__(&self, rhs: &Tensor) ->
+    pub fn __add__(&self, rhs: &Tensor) -> Result<Self> {
         Ok(Self(self.0.add(&rhs.0).map_err(wrap_candle_err)?))
     }
 
     /// Multiply two tensors.
     /// &RETURNS&: Tensor
-    pub fn __mul__(&self, rhs: &Tensor) ->
+    pub fn __mul__(&self, rhs: &Tensor) -> Result<Self> {
         Ok(Self(self.0.mul(&rhs.0).map_err(wrap_candle_err)?))
     }
 
     /// Subtract two tensors.
     /// &RETURNS&: Tensor
-    pub fn __sub__(&self, rhs: &Tensor) ->
+    pub fn __sub__(&self, rhs: &Tensor) -> Result<Self> {
         Ok(Self(self.0.sub(&rhs.0).map_err(wrap_candle_err)?))
     }
 
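The Python-flavored `__add__`/`__mul__`/`__sub__` names strongly suggest these are bound to Ruby's `+`, `*`, and `-` operators during `init`; only the `new` registration is visible in the final hunk, so the exact lines are an assumption. A hedged sketch of how magnus would wire them up:

```rust
use magnus::{method, Error, Module, RClass};

// Hypothetical registration lines; the real init body is cut off in this diff.
fn register_operators(rb_tensor: RClass) -> Result<(), Error> {
    rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
    rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
    rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
    rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?;
    Ok(())
}
```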
@@ -298,7 +298,7 @@ impl Tensor {
     /// Divides this tensor by another tensor or a scalar (Float/Integer).
     /// @param rhs [Candle::Tensor, Float, or Integer]
     /// @return [Candle::Tensor]
-    pub fn __truediv__(&self, rhs: magnus::Value) ->
+    pub fn __truediv__(&self, rhs: magnus::Value) -> Result<Self> {
         use magnus::TryConvert;
         if let Ok(tensor) = <&Tensor>::try_convert(rhs) {
             Ok(Self(self.0.broadcast_div(&tensor.0).map_err(wrap_candle_err)?))
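The view truncates `__truediv__` after its tensor branch. Given the doc comment ("Float/Integer") and the `magnus::Value` parameter, the elided remainder presumably falls back to converting the argument as a number; a sketch of that dispatch, with the scalar branch and its `affine` trick being assumptions rather than the gem's confirmed code:

```rust
// Assumed completion of __truediv__; only the first branch appears in the diff.
pub fn __truediv__(&self, rhs: magnus::Value) -> Result<Self> {
    use magnus::TryConvert;
    if let Ok(tensor) = <&Tensor>::try_convert(rhs) {
        Ok(Self(self.0.broadcast_div(&tensor.0).map_err(wrap_candle_err)?))
    } else if let Ok(scalar) = f64::try_convert(rhs) {
        // Hypothetical scalar path: x / s == x * (1/s) + 0 via candle's affine.
        Ok(Self(self.0.affine(1.0 / scalar, 0.0).map_err(wrap_candle_err)?))
    } else {
        Err(magnus::Error::new(
            magnus::exception::type_error(),
            "rhs must be a Candle::Tensor, Float, or Integer",
        ))
    }
}
```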
@@ -315,13 +315,13 @@ impl Tensor {
 
     /// Reshapes the tensor to the given shape.
     /// &RETURNS&: Tensor
-    pub fn reshape(&self, shape: Vec<usize>) ->
+    pub fn reshape(&self, shape: Vec<usize>) -> Result<Self> {
         Ok(Tensor(self.0.reshape(shape).map_err(wrap_candle_err)?))
     }
 
     /// Broadcasts the tensor to the given shape.
     /// &RETURNS&: Tensor
-    pub fn broadcast_as(&self, shape: Vec<usize>) ->
+    pub fn broadcast_as(&self, shape: Vec<usize>) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_as(shape).map_err(wrap_candle_err)?,
         ))
@@ -329,7 +329,7 @@ impl Tensor {
 
     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
     /// &RETURNS&: Tensor
-    pub fn broadcast_left(&self, shape: Vec<usize>) ->
+    pub fn broadcast_left(&self, shape: Vec<usize>) -> Result<Self> {
         Ok(Tensor(
             self.0.broadcast_left(shape).map_err(wrap_candle_err)?,
         ))
@@ -337,27 +337,27 @@ impl Tensor {
 
     /// Creates a new tensor with the specified dimension removed if its size was one.
     /// &RETURNS&: Tensor
-    pub fn squeeze(&self, dim: i64) ->
+    pub fn squeeze(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.squeeze(dim).map_err(wrap_candle_err)?))
     }
 
     /// Creates a new tensor with a dimension of size one inserted at the specified position.
     /// &RETURNS&: Tensor
-    pub fn unsqueeze(&self, dim: usize) ->
+    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
         Ok(Tensor(self.0.unsqueeze(dim).map_err(wrap_candle_err)?))
     }
 
     /// Gets the value at the specified index.
     /// &RETURNS&: Tensor
-    pub fn get(&self, index: i64) ->
+    pub fn get(&self, index: i64) -> Result<Self> {
         let index = actual_index(self, 0, index).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.get(index).map_err(wrap_candle_err)?))
     }
 
     /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
     /// &RETURNS&: Tensor
-    pub fn transpose(&self, dim1: usize, dim2: usize) ->
+    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         Ok(Tensor(
             self.0.transpose(dim1, dim2).map_err(wrap_candle_err)?,
         ))
@@ -366,7 +366,7 @@ impl Tensor {
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
    /// &RETURNS&: Tensor
-    pub fn narrow(&self, dim: i64, start: i64, len: usize) ->
+    pub fn narrow(&self, dim: i64, start: i64, len: usize) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         let start = actual_index(self, dim, start).map_err(wrap_candle_err)?;
         Ok(Tensor(
@@ -376,7 +376,7 @@ impl Tensor {
 
     /// Returns the indices of the maximum value(s) across the selected dimension.
     /// &RETURNS&: Tensor
-    pub fn argmax_keepdim(&self, dim: i64) ->
+    pub fn argmax_keepdim(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(
             self.0.argmax_keepdim(dim).map_err(wrap_candle_err)?,
@@ -385,7 +385,7 @@ impl Tensor {
 
     /// Returns the indices of the minimum value(s) across the selected dimension.
     /// &RETURNS&: Tensor
-    pub fn argmin_keepdim(&self, dim: i64) ->
+    pub fn argmin_keepdim(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(
             self.0.argmin_keepdim(dim).map_err(wrap_candle_err)?,
@@ -394,51 +394,51 @@ impl Tensor {
 
     /// Gathers the maximum value across the selected dimension.
     /// &RETURNS&: Tensor
-    pub fn max_keepdim(&self, dim: i64) ->
+    pub fn max_keepdim(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.max_keepdim(dim).map_err(wrap_candle_err)?))
     }
 
     /// Gathers the minimum value across the selected dimension.
     /// &RETURNS&: Tensor
-    pub fn min_keepdim(&self, dim: i64) ->
+    pub fn min_keepdim(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.min_keepdim(dim).map_err(wrap_candle_err)?))
     }
 
-    // fn eq(&self, rhs: &Self) ->
+    // fn eq(&self, rhs: &Self) -> Result<Self> {
     //     Ok(Tensor(self.0.eq(rhs).map_err(wrap_candle_err)?))
     // }
 
-    // fn ne(&self, rhs: &Self) ->
+    // fn ne(&self, rhs: &Self) -> Result<Self> {
    //     Ok(Tensor(self.0.ne(rhs).map_err(wrap_candle_err)?))
     // }
 
-    // fn lt(&self, rhs: &Self) ->
+    // fn lt(&self, rhs: &Self) -> Result<Self> {
     //     Ok(Tensor(self.0.lt(rhs).map_err(wrap_candle_err)?))
     // }
 
-    // fn gt(&self, rhs: &Self) ->
+    // fn gt(&self, rhs: &Self) -> Result<Self> {
     //     Ok(Tensor(self.0.gt(rhs).map_err(wrap_candle_err)?))
     // }
 
-    // fn ge(&self, rhs: &Self) ->
+    // fn ge(&self, rhs: &Self) -> Result<Self> {
     //     Ok(Tensor(self.0.ge(rhs).map_err(wrap_candle_err)?))
     // }
 
-    // fn le(&self, rhs: &Self) ->
+    // fn le(&self, rhs: &Self) -> Result<Self> {
     //     Ok(Tensor(self.0.le(rhs).map_err(wrap_candle_err)?))
     // }
 
     /// Returns the sum of the tensor.
     /// &RETURNS&: Tensor
-    pub fn sum_all(&self) ->
+    pub fn sum_all(&self) -> Result<Self> {
         Ok(Tensor(self.0.sum_all().map_err(wrap_candle_err)?))
     }
 
     /// Returns the mean of the tensor.
     /// &RETURNS&: Tensor
-    pub fn mean_all(&self) ->
+    pub fn mean_all(&self) -> Result<Self> {
         let elements = self.0.elem_count();
         let sum = self.0.sum_all().map_err(wrap_candle_err)?;
         let mean = (sum / elements as f64).map_err(wrap_candle_err)?;
@@ -447,33 +447,33 @@ impl Tensor {
 
     /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
     /// &RETURNS&: Tensor
-    pub fn flatten_from(&self, dim: i64) ->
+    pub fn flatten_from(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.flatten_from(dim).map_err(wrap_candle_err)?))
     }
 
     ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
     /// &RETURNS&: Tensor
-    pub fn flatten_to(&self, dim: i64) ->
+    pub fn flatten_to(&self, dim: i64) -> Result<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_candle_err)?;
         Ok(Tensor(self.0.flatten_to(dim).map_err(wrap_candle_err)?))
     }
 
     /// Flattens the tensor into a 1D tensor.
     /// &RETURNS&: Tensor
-    pub fn flatten_all(&self) ->
+    pub fn flatten_all(&self) -> Result<Self> {
         Ok(Tensor(self.0.flatten_all().map_err(wrap_candle_err)?))
     }
 
     /// Transposes the tensor.
     /// &RETURNS&: Tensor
-    pub fn t(&self) ->
+    pub fn t(&self) -> Result<Self> {
         Ok(Tensor(self.0.t().map_err(wrap_candle_err)?))
     }
 
     /// Makes the tensor contiguous in memory.
     /// &RETURNS&: Tensor
-    pub fn contiguous(&self) ->
+    pub fn contiguous(&self) -> Result<Self> {
         Ok(Tensor(self.0.contiguous().map_err(wrap_candle_err)?))
     }
 
@@ -491,26 +491,26 @@ impl Tensor {
 
     /// Detach the tensor from the computation graph.
     /// &RETURNS&: Tensor
-    pub fn detach(&self) ->
+    pub fn detach(&self) -> Result<Self> {
         Ok(Tensor(self.0.detach()))
     }
 
     /// Returns a copy of the tensor.
     /// &RETURNS&: Tensor
-    pub fn copy(&self) ->
+    pub fn copy(&self) -> Result<Self> {
         Ok(Tensor(self.0.copy().map_err(wrap_candle_err)?))
     }
 
     /// Convert the tensor to a new dtype.
     /// &RETURNS&: Tensor
-    pub fn to_dtype(&self, dtype: magnus::Symbol) ->
+    pub fn to_dtype(&self, dtype: magnus::Symbol) -> Result<Self> {
         let dtype = DType::from_rbobject(dtype)?;
         Ok(Tensor(self.0.to_dtype(dtype.0).map_err(wrap_candle_err)?))
     }
 
     /// Move the tensor to a new device.
     /// &RETURNS&: Tensor
-    pub fn to_device(&self, device: Device) ->
+    pub fn to_device(&self, device: Device) -> Result<Self> {
         let device = device.as_device()?;
         Ok(Tensor(
             self.0.to_device(&device).map_err(wrap_candle_err)?,
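`to_dtype` leans on `DType::from_rbobject`, the same helper `new` used earlier to turn a Ruby symbol such as `:f32` into a wrapped `candle_core::DType` (the `DType(CoreDType::F32)` default and the `dtype.0` field access pin down the wrapper's shape). The mapping itself lives in `ruby/dtype.rs`, which this page only lists as `+4 -4`; a sketch under that assumption:

```rust
use candle_core::DType as CoreDType;
use magnus::{exception, Error, Symbol};

// Wrapper shape confirmed by `DType(CoreDType::F32)` and `dtype.0` in the diff.
pub struct DType(pub CoreDType);

impl DType {
    // Assumed symbol-to-dtype table; the real one is in dtype.rs, not shown here.
    pub fn from_rbobject(sym: Symbol) -> Result<Self, Error> {
        let name = sym.name()?;
        match name.as_ref() {
            "f32" => Ok(DType(CoreDType::F32)),
            "f64" => Ok(DType(CoreDType::F64)),
            "u8" => Ok(DType(CoreDType::U8)),
            "u32" => Ok(DType(CoreDType::U32)),
            "i64" => Ok(DType(CoreDType::I64)),
            other => Err(Error::new(
                exception::arg_error(),
                format!("unsupported dtype: {other}"),
            )),
        }
    }
}
```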
@@ -519,7 +519,7 @@ impl Tensor {
 }
 
 impl Tensor {
-    // fn cat(tensors: Vec<
+    // fn cat(tensors: Vec<Tensor>, dim: i64) -> Result<Tensor> {
     //     if tensors.is_empty() {
     //         return Err(Error::new(
     //             magnus::exception::arg_error(),
@@ -528,19 +528,19 @@ impl Tensor {
     //     }
     //     let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_candle_err)?;
     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-    //     let tensor =
+    //     let tensor = CoreTensor::cat(&tensors, dim).map_err(wrap_candle_err)?;
     //     Ok(Tensor(tensor))
     // }
 
-    // fn stack(tensors: Vec<
+    // fn stack(tensors: Vec<Tensor>, dim: usize) -> Result<Self> {
     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-    //     let tensor =
+    //     let tensor = CoreTensor::stack(&tensors, dim).map_err(wrap_candle_err)?;
     //     Ok(Self(tensor))
     // }
 
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
-    pub fn rand(shape: Vec<usize>, device: Option<Device>) ->
+    pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
@@ -549,7 +549,7 @@ impl Tensor {
 
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
-    pub fn randn(shape: Vec<usize>, device: Option<Device>) ->
+    pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
@@ -558,7 +558,7 @@ impl Tensor {
 
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
-    pub fn ones(shape: Vec<usize>, device: Option<Device>) ->
+    pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
@@ -566,7 +566,7 @@ impl Tensor {
     }
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
-    pub fn zeros(shape: Vec<usize>, device: Option<Device>) ->
+    pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
@@ -574,7 +574,7 @@ impl Tensor {
     }
 }
 
-pub fn init(rb_candle: RModule) -> Result<()
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
     rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
     // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
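The listing ends three lines into `init`, right after the class is created and `new` is registered; this is where the `Module` and `Object` trait imports added in the first hunk come into play, since they provide `define_class`, `define_method`, and `define_singleton_method`. A hedged sketch of how the rest of the function presumably continues; the Ruby-side names and arities below are inferred from the signatures above, not confirmed by the diff:

```rust
use magnus::{class, function, method, Error, Module, Object, RModule};

// Assumed continuation of init(); only "new" and the commented-out "cat"
// registration are actually visible above.
pub fn init(rb_candle: RModule) -> Result<(), Error> {
    let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
    rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
    rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
    rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
    rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
    rb_tensor.define_method("values", method!(Tensor::values, 0))?;
    rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
    rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
    Ok(())
}
```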
|