torch-rb 0.8.2 → 0.9.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +5 -4
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +30 -33
- data/ext/torch/tensor.cpp +33 -33
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +9 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +22 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +14 -48
- metadata +3 -3
data/lib/torch.rb
CHANGED
@@ -267,7 +267,7 @@ module Torch
|
|
267
267
|
args.first.send(dtype).to(device)
|
268
268
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
269
269
|
bytes = args.first.bytes
|
270
|
-
Torch.
|
270
|
+
Torch._from_blob_ref(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
271
271
|
elsif args.size == 1 && args.first.is_a?(Array)
|
272
272
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
273
273
|
elsif args.size == 0
|
@@ -320,12 +320,17 @@ module Torch
|
|
320
320
|
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
|
321
321
|
options = tensor_options(device: "cpu", dtype: dtype[0])
|
322
322
|
# TODO pass pointer to array instead of creating string
|
323
|
-
|
324
|
-
|
323
|
+
_from_blob_ref(ndarray.to_string, ndarray.shape, options)
|
324
|
+
end
|
325
|
+
|
326
|
+
# private
|
327
|
+
# TODO use keepAlive in Rice (currently segfaults)
|
328
|
+
def _from_blob_ref(data, size, options)
|
329
|
+
tensor = _from_blob(data, size, options)
|
325
330
|
# from_blob does not own the data, so we need to keep
|
326
331
|
# a reference to it for duration of tensor
|
327
332
|
# can remove when passing pointer directly
|
328
|
-
tensor.instance_variable_set("@
|
333
|
+
tensor.instance_variable_set("@_numo_data", data)
|
329
334
|
tensor
|
330
335
|
end
|
331
336
|
|
@@ -377,8 +382,6 @@ module Torch
|
|
377
382
|
to_ruby(_load(File.binread(f)))
|
378
383
|
end
|
379
384
|
|
380
|
-
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
381
|
-
|
382
385
|
def tensor(data, **options)
|
383
386
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
384
387
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
@@ -408,42 +411,13 @@ module Torch
|
|
408
411
|
end
|
409
412
|
end
|
410
413
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
# --- begin like ---
|
415
|
-
|
416
|
-
def ones_like(input, **options)
|
417
|
-
ones(input.size, **like_options(input, options))
|
418
|
-
end
|
419
|
-
|
420
|
-
def empty_like(input, **options)
|
421
|
-
empty(input.size, **like_options(input, options))
|
422
|
-
end
|
414
|
+
# TODO check each dimensions for consistency in future
|
415
|
+
raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
|
423
416
|
|
424
|
-
|
425
|
-
|
426
|
-
end
|
417
|
+
# TOOD move to C++
|
418
|
+
data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
|
427
419
|
|
428
|
-
|
429
|
-
rand(input.size, **like_options(input, options))
|
430
|
-
end
|
431
|
-
|
432
|
-
def randint_like(input, low, high = nil, **options)
|
433
|
-
# ruby doesn't support input, low = 0, high, ...
|
434
|
-
if high.nil?
|
435
|
-
high = low
|
436
|
-
low = 0
|
437
|
-
end
|
438
|
-
randint(low, high, input.size, **like_options(input, options))
|
439
|
-
end
|
440
|
-
|
441
|
-
def randn_like(input, **options)
|
442
|
-
randn(input.size, **like_options(input, options))
|
443
|
-
end
|
444
|
-
|
445
|
-
def zeros_like(input, **options)
|
446
|
-
zeros(input.size, **like_options(input, options))
|
420
|
+
_tensor(data, size, tensor_options(**options))
|
447
421
|
end
|
448
422
|
|
449
423
|
# center option
|
@@ -572,13 +546,5 @@ module Torch
|
|
572
546
|
end
|
573
547
|
options
|
574
548
|
end
|
575
|
-
|
576
|
-
def like_options(input, options)
|
577
|
-
options = options.dup
|
578
|
-
options[:dtype] ||= input.dtype
|
579
|
-
options[:layout] ||= input.layout
|
580
|
-
options[:device] ||= input.device
|
581
|
-
options
|
582
|
-
end
|
583
549
|
end
|
584
550
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2022-02-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -227,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
227
|
- !ruby/object:Gem::Version
|
228
228
|
version: '0'
|
229
229
|
requirements: []
|
230
|
-
rubygems_version: 3.
|
230
|
+
rubygems_version: 3.3.3
|
231
231
|
signing_key:
|
232
232
|
specification_version: 4
|
233
233
|
summary: Deep learning for Ruby, powered by LibTorch
|