torch-rb 0.23.0 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +3 -1
- data/codegen/generate_functions.rb +1 -6
- data/codegen/native_functions.yaml +229 -58
- data/ext/torch/ivalue.cpp +4 -4
- data/ext/torch/nn.cpp +2 -1
- data/ext/torch/ruby_arg_parser.cpp +23 -23
- data/ext/torch/ruby_arg_parser.h +16 -16
- data/ext/torch/templates.h +4 -4
- data/ext/torch/tensor.cpp +17 -24
- data/ext/torch/torch.cpp +6 -6
- data/ext/torch/utils.h +5 -5
- data/ext/torch/wrap_outputs.h +29 -22
- data/lib/torch/hub.rb +8 -28
- data/lib/torch/nn/module.rb +1 -1
- data/lib/torch/nn/rnn_base.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +4 -4
- metadata +3 -3
data/lib/torch/hub.rb
CHANGED
|
@@ -6,37 +6,17 @@ module Torch
|
|
|
6
6
|
end
|
|
7
7
|
|
|
8
8
|
def download_url_to_file(url, dst)
|
|
9
|
-
uri
|
|
10
|
-
tmp = nil
|
|
11
|
-
location = nil
|
|
9
|
+
require "open-uri"
|
|
12
10
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
request = Net::HTTP::Get.new(uri)
|
|
16
|
-
|
|
17
|
-
http.request(request) do |response|
|
|
18
|
-
case response
|
|
19
|
-
when Net::HTTPRedirection
|
|
20
|
-
location = response["location"]
|
|
21
|
-
when Net::HTTPSuccess
|
|
22
|
-
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
|
23
|
-
File.open(tmp, "wb") do |f|
|
|
24
|
-
response.read_body do |chunk|
|
|
25
|
-
f.write(chunk)
|
|
26
|
-
end
|
|
27
|
-
end
|
|
28
|
-
else
|
|
29
|
-
raise Error, "Bad response"
|
|
30
|
-
end
|
|
31
|
-
end
|
|
32
|
-
end
|
|
11
|
+
uri = URI.parse(url)
|
|
12
|
+
raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https
|
|
33
13
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
nil
|
|
14
|
+
puts "Downloading #{url}..."
|
|
15
|
+
uri.open(max_redirects: 10) do |download|
|
|
16
|
+
# TODO move file when possible
|
|
17
|
+
IO.copy_stream(download, dst.to_str)
|
|
39
18
|
end
|
|
19
|
+
nil
|
|
40
20
|
end
|
|
41
21
|
|
|
42
22
|
def load_state_dict_from_url(url, model_dir: nil)
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/nn/rnn_base.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
|
@@ -279,7 +279,7 @@ module Torch
|
|
|
279
279
|
def self._make_tensor_class(dtype, cuda = false)
|
|
280
280
|
cls = Class.new
|
|
281
281
|
device = cuda ? "cuda" : "cpu"
|
|
282
|
-
cls.define_singleton_method(
|
|
282
|
+
cls.define_singleton_method(:new) do |*args|
|
|
283
283
|
if args.size == 1 && args.first.is_a?(Tensor)
|
|
284
284
|
args.first.send(dtype).to(device)
|
|
285
285
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
|
@@ -347,7 +347,7 @@ module Torch
|
|
|
347
347
|
# from_blob does not own the data, so we need to keep
|
|
348
348
|
# a reference to it for duration of tensor
|
|
349
349
|
# can remove when passing pointer directly
|
|
350
|
-
tensor.instance_variable_set(
|
|
350
|
+
tensor.instance_variable_set(:@_numo_data, data)
|
|
351
351
|
tensor
|
|
352
352
|
end
|
|
353
353
|
|
|
@@ -426,11 +426,11 @@ module Torch
|
|
|
426
426
|
end
|
|
427
427
|
|
|
428
428
|
if options[:dtype].nil?
|
|
429
|
-
if data.all?
|
|
429
|
+
if data.all?(Integer)
|
|
430
430
|
options[:dtype] = :int64
|
|
431
431
|
elsif data.all? { |v| v == true || v == false }
|
|
432
432
|
options[:dtype] = :bool
|
|
433
|
-
elsif data.any?
|
|
433
|
+
elsif data.any?(Complex)
|
|
434
434
|
options[:dtype] = :complex64
|
|
435
435
|
end
|
|
436
436
|
end
|
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: torch-rb
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.23.0
|
|
4
|
+
version: 0.24.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -234,14 +234,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
|
234
234
|
requirements:
|
|
235
235
|
- - ">="
|
|
236
236
|
- !ruby/object:Gem::Version
|
|
237
|
-
version: '3.
|
|
237
|
+
version: '3.3'
|
|
238
238
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
239
239
|
requirements:
|
|
240
240
|
- - ">="
|
|
241
241
|
- !ruby/object:Gem::Version
|
|
242
242
|
version: '0'
|
|
243
243
|
requirements: []
|
|
244
|
-
rubygems_version: 4.0.
|
|
244
|
+
rubygems_version: 4.0.6
|
|
245
245
|
specification_version: 4
|
|
246
246
|
summary: Deep learning for Ruby, powered by LibTorch
|
|
247
247
|
test_files: []
|