torch-rb 0.2.3 → 0.3.0

@@ -75,7 +75,7 @@ module Torch
             v.is_a?(Tensor)
           when "Tensor?"
             v.nil? || v.is_a?(Tensor)
-          when "Tensor[]"
+          when "Tensor[]", "Tensor?[]"
             v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
           when "int"
             if k == "reduction"
@@ -70,6 +70,11 @@ module Torch
           momentum: exponential_average_factor, eps: @eps
         )
       end
+
+      def extra_inspect
+        s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
+        format(s, **dict)
+      end
     end
   end
 end
@@ -18,9 +18,15 @@ module Torch
         F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end

-      # TODO add more parameters
       def extra_inspect
-        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
+        s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
+        s += ", padding: %{padding}" if @padding != [0] * @padding.size
+        s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
+        s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
+        s += ", groups: %{groups}" if @groups != 1
+        s += ", bias: false" unless @bias
+        s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
+        format(s, **dict)
       end
     end
   end
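With the expanded `extra_inspect` above, convolution layers now report any non-default options when printed. A rough sketch of the kind of output to expect (the exact rendering of the kernel size and stride arrays is an assumption, not taken from the diff):

    conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
    puts conv.inspect
    # => something like: Conv2d(3, 16, kernel_size: [3, 3], stride: [2, 2])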
@@ -23,7 +23,7 @@ module Torch
         if bias
           @bias = Parameter.new(Tensor.new(out_channels))
         else
-          raise NotImplementedError
+          register_parameter("bias", nil)
         end
         reset_parameters
       end
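Constructing a convolution with `bias: false` previously hit the `raise NotImplementedError` branch; registering a nil bias parameter makes it work. A minimal sketch:

    conv = Torch::NN::Conv2d.new(3, 16, 3, bias: false)
    x = Torch.randn(1, 3, 8, 8)
    conv.call(x).shape   # e.g. [1, 16, 6, 6], forward pass with no bias term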
@@ -12,7 +12,8 @@ module Torch
       end

       def extra_inspect
-        format("kernel_size: %s", @kernel_size)
+        s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
+        format(s, **dict)
       end
     end
   end
@@ -145,7 +145,7 @@ module Torch
        params = {}
        if recurse
          named_children.each do |name, mod|
-           params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+           params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
          end
        end
        instance_variables.each do |name|
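The `prefix` fix above means parameters of modules nested more than one level deep keep their full dotted path. A hedged sketch (model structure invented for illustration, and the key names assume `Sequential` numbers its children):

    model = Torch::NN::Sequential.new(
      Torch::NN::Sequential.new(Torch::NN::Linear.new(4, 8)),
      Torch::NN::Linear.new(8, 2)
    )
    model.named_parameters.keys
    # expected to include fully qualified names such as "0.0.weight" and "0.0.bias",
    # where the nested prefix was previously dropped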
@@ -186,8 +186,22 @@ module Torch
         named_modules.values
       end

-      def named_modules
-        {"" => self}.merge(named_children)
+      # TODO return enumerator?
+      def named_modules(memo: nil, prefix: "")
+        ret = {}
+        memo ||= Set.new
+        unless memo.include?(self)
+          memo << self
+          ret[prefix] = self
+          named_children.each do |name, mod|
+            next unless mod.is_a?(Module)
+            submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+            mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+              ret[m[0]] = m[1]
+            end
+          end
+        end
+        ret
       end

       def train(mode = true)
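`named_modules` now walks children recursively, using a `Set` memo to avoid revisiting shared submodules, so every nested module appears under its dotted path. Continuing the hypothetical model above:

    model.named_modules.keys
    # e.g. ["", "0", "0.0", "1"], the root module plus each nested child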
@@ -230,7 +244,9 @@ module Torch
         str = String.new
         str << "#{name}(\n"
         named_children.each do |name, mod|
-          str << "  (#{name}): #{mod.inspect}\n"
+          mod_str = mod.inspect
+          mod_str = mod_str.lines.join("  ")
+          str << "  (#{name}): #{mod_str}\n"
         end
         str << ")"
       end
@@ -270,8 +286,11 @@ module Torch
         str % vars
       end

+      # used for format
+      # remove tensors for performance
+      # so we can skip call to inspect
       def dict
-        instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
     end
   end
@@ -32,9 +32,11 @@ module Torch
     end

     def state_dict
+      raise NotImplementedYet
+
       pack_group = lambda do |group|
-        packed = group.select { |k, _| k != :params }.to_h
-        packed[:params] = group[:params].map { |p| p.object_id }
+        packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
+        packed["params"] = group[:params].map { |p| p.object_id }
         packed
       end

@@ -42,8 +44,8 @@ module Torch
       packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h

       {
-        state: packed_state,
-        param_groups: param_groups
+        "state" => packed_state,
+        "param_groups" => param_groups
       }
     end

@@ -11,9 +11,6 @@ module Torch
     end

     def step(closure = nil)
-      # TODO implement []=
-      raise NotImplementedYet
-
       loss = nil
       if closure
         loss = closure.call
@@ -1,6 +1,7 @@
 module Torch
   class Tensor
     include Comparable
+    include Enumerable
     include Inspector

     alias_method :requires_grad?, :requires_grad
@@ -25,14 +26,36 @@ module Torch
       inspect
     end

+    def each
+      return enum_for(:each) unless block_given?
+
+      size(0).times do |i|
+        yield self[i]
+      end
+    end
+
+    # TODO make more performant
     def to_a
-      reshape_arr(_flat_data, shape)
+      arr = _flat_data
+      if shape.empty?
+        arr
+      else
+        shape[1..-1].reverse.each do |dim|
+          arr = arr.each_slice(dim)
+        end
+        arr.to_a
+      end
     end

-    # TODO support dtype
-    def to(device, non_blocking: false, copy: false)
+    def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      device ||= self.device
       device = Device.new(device) if device.is_a?(String)
-      _to(device, _dtype, non_blocking, copy)
+
+      dtype ||= self.dtype
+      enum = DTYPE_TO_ENUM[dtype]
+      raise Error, "Unknown type: #{dtype}" unless enum
+
+      _to(device, enum, non_blocking, copy)
     end

     def cpu
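With `each` defined over the first dimension and `Enumerable` included, tensors can be iterated like other Ruby collections, and `to` now takes a `dtype:` keyword. A minimal sketch:

    t = Torch.tensor([[1, 2], [3, 4]])
    t.each { |row| p row.to_a }      # [1, 2] then [3, 4]
    t.map { |row| row.sum.item }     # Enumerable methods come for free
    t.to(dtype: :float32)            # cast without changing device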
@@ -64,7 +87,15 @@ module Torch
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-      _flat_data.first
+      to_a.first
+    end
+
+    def to_i
+      item.to_i
+    end
+
+    def to_f
+      item.to_f
     end

     # unsure if this is correct
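`to_i` and `to_f` delegate to `item`, so single-element tensors coerce to Ruby numbers (multi-element tensors still raise). Sketch:

    Torch.tensor(3.7).to_i   # => 3
    Torch.tensor(2).to_f     # => 2.0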
@@ -80,7 +111,7 @@ module Torch
     def numo
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.cast(_flat_data).reshape(*shape)
+      cls.from_string(_data_str).reshape(*shape)
     end

     def new_ones(*size, **options)
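The Numo bridge now reads the tensor's raw data string instead of casting a flat Ruby array, which should be considerably faster for large tensors. Usage is unchanged; a sketch assuming the numo-narray gem is installed:

    require "numo/narray"

    n = Torch.ones(2, 3).numo   # a Numo array with shape [2, 3]
    Torch.from_numo(n)          # and back to a tensor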
@@ -108,15 +139,6 @@ module Torch
       _view(size)
     end

-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
-    end
-
     def +(other)
       add(other)
     end
@@ -145,12 +167,25 @@ module Torch
       neg
     end

+    def &(other)
+      logical_and(other)
+    end
+
+    def |(other)
+      logical_or(other)
+    end
+
+    def ^(other)
+      logical_xor(other)
+    end
+
     # TODO better compare?
     def <=>(other)
       item <=> other
     end

-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
       result = self
       dim = 0
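The new `&`, `|`, and `^` operators map to `logical_and`, `logical_or`, and `logical_xor`, which is convenient for combining masks. A sketch, assuming boolean tensors can be built directly from Ruby true/false values:

    a = Torch.tensor([true, true, false])
    b = Torch.tensor([true, false, false])
    (a & b).to_a   # => [true, false, false]
    (a | b).to_a   # => [true, true, false]
    (a ^ b).to_a   # => [false, true, false]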
@@ -162,6 +197,8 @@ module Torch
         finish += 1 unless index.exclude_end?
         result = result._slice_tensor(dim, index.begin, finish, 1)
         dim += 1
+      elsif index.is_a?(Tensor)
+        result = result.index([index])
       elsif index.nil?
         result = result.unsqueeze(dim)
         dim += 1
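Passing a tensor inside `[]` is now routed through `index`, which enables mask-style selection. A sketch, assuming boolean mask indexing behaves as in PyTorch:

    x = Torch.tensor([10, 20, 30, 40])
    mask = Torch.tensor([true, false, true, false])
    x[mask].to_a   # => [10, 30]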
@@ -175,24 +212,37 @@ module Torch
       result
     end

-    # TODO
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(index, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?

-      value = Torch.tensor(value) unless value.is_a?(Tensor)
+      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)

       if index.is_a?(Numeric)
-        copy_to(_select_int(0, index), value)
+        index_put!([Torch.tensor(index)], value)
       elsif index.is_a?(Range)
         finish = index.end
         finish += 1 unless index.exclude_end?
-        copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+        _slice_tensor(0, index.begin, finish, 1).copy!(value)
+      elsif index.is_a?(Tensor)
+        index_put!([index], value)
       else
         raise Error, "Unsupported index type: #{index.class.name}"
       end
     end

+    # native functions that need manually defined
+
+    # value and other are swapped for some methods
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add__scalar(other, value)
+      else
+        _add__tensor(other, value)
+      end
+    end
+
     # native functions overlap, so need to handle manually
     def random!(*args)
       case args.size
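Element assignment now goes through `index_put!` and also accepts tensor indexes, and scalar values are converted with the receiver's dtype. Sketch:

    x = Torch.zeros(4)
    x[0] = 5                      # single position
    x[Torch.tensor([1, 2])] = 7   # several positions at once
    x.to_a                        # => [5.0, 7.0, 7.0, 0.0]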
@@ -205,22 +255,9 @@ module Torch
       end
     end

-    private
-
-    def copy_to(dst, src)
-      dst.copy!(src)
-    end
-
-    def reshape_arr(arr, dims)
-      if dims.empty?
-        arr
-      else
-        arr = arr.flatten
-        dims[1..-1].reverse.each do |dim|
-          arr = arr.each_slice(dim)
-        end
-        arr.to_a
-      end
+    def clamp!(min, max)
+      _clamp_min_(min)
+      _clamp_max_(max)
     end
   end
 end
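`clamp!` is defined by hand because the native in-place variants are split into min and max halves; it simply applies `_clamp_min_` then `_clamp_max_`. Sketch:

    x = Torch.tensor([-2.0, 0.5, 3.0])
    x.clamp!(0, 1)
    x.to_a   # => [0.0, 0.5, 1.0]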
@@ -6,9 +6,10 @@ module Torch

        attr_reader :dataset

-        def initialize(dataset, batch_size: 1)
+        def initialize(dataset, batch_size: 1, shuffle: false)
          @dataset = dataset
          @batch_size = batch_size
+          @shuffle = shuffle
        end

        def each
@@ -16,11 +17,15 @@ module Torch
          # this makes it easy to compare results
          base_seed = Torch.empty([], dtype: :int64).random!.item

-          max_size = @dataset.size
-          size.times do |i|
-            start_index = i * @batch_size
-            end_index = [start_index + @batch_size, max_size].min
-            batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+          indexes =
+            if @shuffle
+              Torch.randperm(@dataset.size).to_a
+            else
+              @dataset.size.times
+            end
+
+          indexes.each_slice(@batch_size) do |idx|
+            batch = idx.map { |i| @dataset[i] }
            yield collate(batch)
          end
        end
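`DataLoader` now accepts `shuffle: true` and draws a fresh random permutation of dataset indexes (via `Torch.randperm`) on each pass instead of iterating in order. A minimal sketch:

    x = Torch.randn(10, 3)
    y = Torch.randn(10)
    dataset = Torch::Utils::Data::TensorDataset.new(x, y)
    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4, shuffle: true)
    loader.each do |xb, yb|
      # batches arrive in a different order every epoch
    end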
@@ -0,0 +1,8 @@
+module Torch
+  module Utils
+    module Data
+      class Dataset
+      end
+    end
+  end
+end
@@ -1,7 +1,7 @@
 module Torch
   module Utils
     module Data
-      class TensorDataset
+      class TensorDataset < Dataset
        def initialize(*tensors)
          unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
            raise Error, "Tensors must all have same dim 0 size"
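A new (currently empty) `Dataset` base class is added and `TensorDataset` now inherits from it. Since `DataLoader` only calls `size` and `[]` on its dataset, a custom dataset can be sketched like this (class name and method bodies are illustrative, not from the gem):

    class MyDataset < Torch::Utils::Data::Dataset
      def initialize(inputs, targets)
        @inputs = inputs
        @targets = targets
      end

      def size
        @inputs.size(0)
      end

      def [](index)
        [@inputs[index], @targets[index]]
      end
    end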
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.2.3"
+  VERSION = "0.3.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.3
+  version: 0.3.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-04-28 00:00:00.000000000 Z
+date: 2020-07-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -95,19 +95,19 @@ dependencies:
     - !ruby/object:Gem::Version
       version: '0'
 - !ruby/object:Gem::Dependency
-  name: npy
+  name: torchvision
   requirement: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: '0'
+        version: 0.1.1
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: '0'
+        version: 0.1.1
 description:
 email: andrew@chartkick.com
 executables: []
@@ -260,6 +260,7 @@ files:
 - lib/torch/optim/sgd.rb
 - lib/torch/tensor.rb
 - lib/torch/utils/data/data_loader.rb
+- lib/torch/utils/data/dataset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb