torch-rb 0.23.0 → 0.24.0

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
data/lib/torch/hub.rb CHANGED
@@ -6,37 +6,17 @@ module Torch
6
6
  end
7
7
 
8
8
  def download_url_to_file(url, dst)
9
- uri = URI(url)
10
- tmp = nil
11
- location = nil
9
+ require "open-uri"
12
10
 
13
- puts "Downloading #{url}..."
14
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
15
- request = Net::HTTP::Get.new(uri)
16
-
17
- http.request(request) do |response|
18
- case response
19
- when Net::HTTPRedirection
20
- location = response["location"]
21
- when Net::HTTPSuccess
22
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
- File.open(tmp, "wb") do |f|
24
- response.read_body do |chunk|
25
- f.write(chunk)
26
- end
27
- end
28
- else
29
- raise Error, "Bad response"
30
- end
31
- end
32
- end
11
+ uri = URI.parse(url)
12
+ raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https
33
13
 
34
- if location
35
- download_url_to_file(location, dst)
36
- else
37
- FileUtils.mv(tmp, dst)
38
- nil
14
+ puts "Downloading #{url}..."
15
+ uri.open(max_redirects: 10) do |download|
16
+ # TODO move file when possible
17
+ IO.copy_stream(download, dst.to_str)
39
18
  end
19
+ nil
40
20
  end
41
21
 
42
22
  def load_state_dict_from_url(url, model_dir: nil)
@@ -197,7 +197,7 @@ module Torch
197
197
  named_buffers.values
198
198
  end
199
199
 
200
- # TODO set recurse: true in 0.18.0
200
+ # TODO set recurse: true in future
201
201
  def named_buffers(prefix: "", recurse: false)
202
202
  buffers = {}
203
203
  if recurse
@@ -161,7 +161,7 @@ module Torch
161
161
  private
162
162
 
163
163
  def _flat_weights
164
- @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
164
+ @all_weights.flatten.filter_map { |v| instance_variable_get("@#{v}") }
165
165
  end
166
166
 
167
167
  def _get_flat_weights
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.23.0"
2
+ VERSION = "0.24.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -279,7 +279,7 @@ module Torch
279
279
  def self._make_tensor_class(dtype, cuda = false)
280
280
  cls = Class.new
281
281
  device = cuda ? "cuda" : "cpu"
282
- cls.define_singleton_method("new") do |*args|
282
+ cls.define_singleton_method(:new) do |*args|
283
283
  if args.size == 1 && args.first.is_a?(Tensor)
284
284
  args.first.send(dtype).to(device)
285
285
  elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
@@ -347,7 +347,7 @@ module Torch
347
347
  # from_blob does not own the data, so we need to keep
348
348
  # a reference to it for duration of tensor
349
349
  # can remove when passing pointer directly
350
- tensor.instance_variable_set("@_numo_data", data)
350
+ tensor.instance_variable_set(:@_numo_data, data)
351
351
  tensor
352
352
  end
353
353
 
@@ -426,11 +426,11 @@ module Torch
426
426
  end
427
427
 
428
428
  if options[:dtype].nil?
429
- if data.all? { |v| v.is_a?(Integer) }
429
+ if data.all?(Integer)
430
430
  options[:dtype] = :int64
431
431
  elsif data.all? { |v| v == true || v == false }
432
432
  options[:dtype] = :bool
433
- elsif data.any? { |v| v.is_a?(Complex) }
433
+ elsif data.any?(Complex)
434
434
  options[:dtype] = :complex64
435
435
  end
436
436
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.23.0
4
+ version: 0.24.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -234,14 +234,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
234
234
  requirements:
235
235
  - - ">="
236
236
  - !ruby/object:Gem::Version
237
- version: '3.2'
237
+ version: '3.3'
238
238
  required_rubygems_version: !ruby/object:Gem::Requirement
239
239
  requirements:
240
240
  - - ">="
241
241
  - !ruby/object:Gem::Version
242
242
  version: '0'
243
243
  requirements: []
244
- rubygems_version: 4.0.3
244
+ rubygems_version: 4.0.6
245
245
  specification_version: 4
246
246
  summary: Deep learning for Ruby, powered by LibTorch
247
247
  test_files: []