torch-rb 0.2.3 → 0.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/torch/hub.rb +8 -0
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +19 -3
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/utils/data/data_loader.rb +11 -6
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +6 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 67c5a0cf556399dc32d73e8793e3aa794c181150f0f42dfa810c4b98a5acf6f2
|
4
|
+
data.tar.gz: 0a23f6a42595fb9d599962e88438b964180583ead5b9cce934cc447951b4a389
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c0f8e9e3395d196d7ea6fa4b40d128284d768033e02f4ed7d2dc9adc985015fd0a80d601601dd97438b803b6a3bd7b81f5dbda353bb5dee4247503a24cd755d7
|
7
|
+
data.tar.gz: c32a22ebbe1b4dfd77324f62a72d6a128639aac0a99d4c5255b16c606e6f961ae2c8b0dbab5012a9b21faa7409511b79a50676bc8314f181c85f90433433fa8b
|
data/CHANGELOG.md
CHANGED
data/lib/torch/hub.rb
CHANGED
data/lib/torch/inspector.rb
CHANGED
data/lib/torch/nn/batch_norm.rb
CHANGED
@@ -70,6 +70,11 @@ module Torch
|
|
70
70
|
momentum: exponential_average_factor, eps: @eps
|
71
71
|
)
|
72
72
|
end
|
73
|
+
|
74
|
+
def extra_inspect
|
75
|
+
s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
|
76
|
+
format(s, **dict)
|
77
|
+
end
|
73
78
|
end
|
74
79
|
end
|
75
80
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -20,7 +20,14 @@ module Torch
|
|
20
20
|
|
21
21
|
# TODO add more parameters
|
22
22
|
def extra_inspect
|
23
|
-
|
23
|
+
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
|
+
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
25
|
+
s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
|
26
|
+
s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
|
27
|
+
s += ", groups: %{groups}" if @groups != 1
|
28
|
+
s += ", bias: false" unless @bias
|
29
|
+
s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
|
30
|
+
format(s, **dict)
|
24
31
|
end
|
25
32
|
end
|
26
33
|
end
|
data/lib/torch/nn/convnd.rb
CHANGED
data/lib/torch/nn/max_poolnd.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -186,8 +186,22 @@ module Torch
|
|
186
186
|
named_modules.values
|
187
187
|
end
|
188
188
|
|
189
|
-
|
190
|
-
|
189
|
+
# TODO return enumerator?
|
190
|
+
def named_modules(memo: nil, prefix: "")
|
191
|
+
ret = {}
|
192
|
+
memo ||= Set.new
|
193
|
+
unless memo.include?(self)
|
194
|
+
memo << self
|
195
|
+
ret[prefix] = self
|
196
|
+
named_children.each do |name, mod|
|
197
|
+
next unless mod.is_a?(Module)
|
198
|
+
submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
|
199
|
+
mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
|
200
|
+
ret[m[0]] = m[1]
|
201
|
+
end
|
202
|
+
end
|
203
|
+
end
|
204
|
+
ret
|
191
205
|
end
|
192
206
|
|
193
207
|
def train(mode = true)
|
@@ -230,7 +244,9 @@ module Torch
|
|
230
244
|
str = String.new
|
231
245
|
str << "#{name}(\n"
|
232
246
|
named_children.each do |name, mod|
|
233
|
-
|
247
|
+
mod_str = mod.inspect
|
248
|
+
mod_str = mod_str.lines.join(" ")
|
249
|
+
str << " (#{name}): #{mod_str}\n"
|
234
250
|
end
|
235
251
|
str << ")"
|
236
252
|
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -6,9 +6,10 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
|
+
@shuffle = shuffle
|
12
13
|
end
|
13
14
|
|
14
15
|
def each
|
@@ -16,11 +17,15 @@ module Torch
|
|
16
17
|
# this makes it easy to compare results
|
17
18
|
base_seed = Torch.empty([], dtype: :int64).random!.item
|
18
19
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
20
|
+
indexes =
|
21
|
+
if @shuffle
|
22
|
+
Torch.randperm(@dataset.size).to_a
|
23
|
+
else
|
24
|
+
@dataset.size.times
|
25
|
+
end
|
26
|
+
|
27
|
+
indexes.each_slice(@batch_size) do |idx|
|
28
|
+
batch = idx.map { |i| @dataset[i] }
|
24
29
|
yield collate(batch)
|
25
30
|
end
|
26
31
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-04-
|
11
|
+
date: 2020-04-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -95,19 +95,19 @@ dependencies:
|
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '0'
|
97
97
|
- !ruby/object:Gem::Dependency
|
98
|
-
name:
|
98
|
+
name: torchvision
|
99
99
|
requirement: !ruby/object:Gem::Requirement
|
100
100
|
requirements:
|
101
101
|
- - ">="
|
102
102
|
- !ruby/object:Gem::Version
|
103
|
-
version:
|
103
|
+
version: 0.1.1
|
104
104
|
type: :development
|
105
105
|
prerelease: false
|
106
106
|
version_requirements: !ruby/object:Gem::Requirement
|
107
107
|
requirements:
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
|
-
version:
|
110
|
+
version: 0.1.1
|
111
111
|
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
@@ -260,6 +260,7 @@ files:
|
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
262
|
- lib/torch/utils/data/data_loader.rb
|
263
|
+
- lib/torch/utils/data/dataset.rb
|
263
264
|
- lib/torch/utils/data/tensor_dataset.rb
|
264
265
|
- lib/torch/version.rb
|
265
266
|
homepage: https://github.com/ankane/torch.rb
|