torch-rb 0.2.5 → 0.3.2

This diff shows the changes between publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
@@ -75,7 +75,7 @@ module Torch
           v.is_a?(Tensor)
         when "Tensor?"
           v.nil? || v.is_a?(Tensor)
-        when "Tensor[]"
+        when "Tensor[]", "Tensor?[]"
           v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
         when "int"
           if k == "reduction"
@@ -18,7 +18,6 @@ module Torch
         F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end
 
-      # TODO add more parameters
       def extra_inspect
         s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
         s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -373,7 +373,8 @@ module Torch
         end
 
         # weight and input swapped
-        Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        ret
       end
 
       # distance functions
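The functional embedding_bag binding previously returned the raw result of the native call, which is a multi-element array; it now unpacks that result and returns only the pooled output tensor. A minimal sketch of the new behavior; the exact Ruby signature is not shown in this diff, so the keyword names below are assumptions modeled on PyTorch's `F.embedding_bag`:

```ruby
require "torch"

F = Torch::NN::Functional

weight = Torch.randn(10, 3)                     # 10 embeddings of size 3
input = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flat index list
offsets = Torch.tensor([0, 4])                  # two bags: indices [0, 4) and [4, 8)

# 0.3.x returns just the pooled output tensor, not the 4-element native result
out = F.embedding_bag(input, weight, offsets: offsets)
p out.size # => [2, 3]
```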
@@ -426,6 +427,9 @@ module Torch
       end
 
       def mse_loss(input, target, reduction: "mean")
+        if target.size != input.size
+          warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+        end
         NN.mse_loss(input, target, reduction)
       end
 
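`mse_loss` now warns when the input and target shapes differ, since silent broadcasting is almost always a bug. A quick sketch of when the new warning fires, assuming `Torch::NN::Functional` exposes `mse_loss` as the `NN.mse_loss` call above suggests:

```ruby
require "torch"

input = Torch.randn(10, 1)
target = Torch.randn(10)

# [10, 1] vs [10] broadcasts to [10, 10]; 0.3.x now prints the warning above.
Torch::NN::Functional.mse_loss(input, target)

# Matching shapes keep it quiet.
Torch::NN::Functional.mse_loss(input, target.unsqueeze(1))
```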
@@ -145,7 +145,7 @@ module Torch
       params = {}
       if recurse
         named_children.each do |name, mod|
-          params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         end
       end
       instance_variables.each do |name|
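The recursion bug fixed above meant that parameters of modules nested more than one level deep lost their outer prefix; with the fix, `named_parameters` returns the full dotted path. A small sketch (the class names here are illustrative):

```ruby
require "torch"

class Block < Torch::NN::Module
  def initialize
    super()
    @linear = Torch::NN::Linear.new(4, 4)
  end
end

class Net < Torch::NN::Module
  def initialize
    super()
    @block = Block.new
  end
end

# 0.2.x returned "linear.weight" / "linear.bias" here, dropping the outer
# "block." prefix; 0.3.x keeps the full path.
p Net.new.named_parameters.keys
# => ["block.linear.weight", "block.linear.bias"]
```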
@@ -286,8 +286,11 @@ module Torch
         str % vars
       end
 
+      # used for format
+      # remove tensors for performance
+      # so we can skip call to inspect
       def dict
-        instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
     end
   end
@@ -32,9 +32,11 @@ module Torch
       end
 
       def state_dict
+        raise NotImplementedYet
+
         pack_group = lambda do |group|
-          packed = group.select { |k, _| k != :params }.to_h
-          packed[:params] = group[:params].map { |p| p.object_id }
+          packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
+          packed["params"] = group[:params].map { |p| p.object_id }
           packed
         end
 
@@ -42,8 +44,8 @@ module Torch
         packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
 
         {
-          state: packed_state,
-          param_groups: param_groups
+          "state" => packed_state,
+          "param_groups" => param_groups
         }
       end
 
@@ -11,9 +11,6 @@ module Torch
       end
 
       def step(closure = nil)
-        # TODO implement []=
-        raise NotImplementedYet
-
         loss = nil
         if closure
           loss = closure.call
@@ -1,6 +1,7 @@
 module Torch
   class Tensor
     include Comparable
+    include Enumerable
     include Inspector
 
     alias_method :requires_grad?, :requires_grad
@@ -25,6 +26,14 @@ module Torch
       inspect
     end
 
+    def each
+      return enum_for(:each) unless block_given?
+
+      size(0).times do |i|
+        yield self[i]
+      end
+    end
+
     # TODO make more performant
     def to_a
       arr = _flat_data
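Together with the `include Enumerable` added earlier in this file, the new `each` method (which yields successive slices along the first dimension) makes tensors iterable like other Ruby collections. A brief sketch:

```ruby
require "torch"

x = Torch.tensor([[1, 2], [3, 4], [5, 6]])

x.each { |row| p row.to_a }     # [1, 2], then [3, 4], then [5, 6]

# Enumerable methods come along for free.
p x.map { |row| row.sum.item }  # => [3, 7, 11]
p x.first.to_a                  # => [1, 2]
```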
@@ -38,10 +47,15 @@ module Torch
       end
     end
 
-    # TODO support dtype
-    def to(device, non_blocking: false, copy: false)
+    def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      device ||= self.device
       device = Device.new(device) if device.is_a?(String)
-      _to(device, _dtype, non_blocking, copy)
+
+      dtype ||= self.dtype
+      enum = DTYPE_TO_ENUM[dtype]
+      raise Error, "Unknown type: #{dtype}" unless enum
+
+      _to(device, enum, non_blocking, copy)
     end
 
     def cpu
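`to` now accepts an optional dtype in addition to (or instead of) a device, closer to PyTorch's `Tensor#to`. A sketch, assuming the usual torch-rb dtype symbols such as `:float64` and `:int64`:

```ruby
require "torch"

x = Torch.rand(2, 3)

y = x.to(dtype: :float64)       # dtype-only conversion, device unchanged
p y.dtype                       # => :float64

z = x.to("cpu", dtype: :int64)  # device and dtype together still work
```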
@@ -89,8 +103,9 @@ module Torch
       Torch.empty(0, dtype: dtype)
     end
 
-    def backward(gradient = nil)
-      _backward(gradient)
+    def backward(gradient = nil, retain_graph: nil, create_graph: false)
+      retain_graph = create_graph if retain_graph.nil?
+      _backward(gradient, retain_graph, create_graph)
     end
 
     # TODO read directly from memory
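`backward` now forwards `retain_graph` and `create_graph` to the native call, so a graph can be reused for a second backward pass. A minimal sketch:

```ruby
require "torch"

x = Torch.tensor([2.0, 3.0], requires_grad: true)
y = (x * x).sum

y.backward(retain_graph: true)  # keep the graph alive for another pass
y.backward                      # would raise without retain_graph above

p x.grad.to_a # => [8.0, 12.0] (gradients accumulate across the two calls)
```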
@@ -153,12 +168,25 @@ module Torch
       neg
     end
 
+    def &(other)
+      logical_and(other)
+    end
+
+    def |(other)
+      logical_or(other)
+    end
+
+    def ^(other)
+      logical_xor(other)
+    end
+
     # TODO better compare?
     def <=>(other)
       item <=> other
     end
 
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
       result = self
       dim = 0
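The new `&`, `|`, and `^` operators map to `logical_and`, `logical_or`, and `logical_xor`, which is handy for combining boolean masks. For example:

```ruby
require "torch"

x = Torch.tensor([1, -2, 3, -4])
a = x.gt(0)      # [true, false, true, false]
b = x.abs.gt(2)  # [false, false, true, true]

p (a & b).to_a   # => [false, false, true, false]
p (a | b).to_a   # => [true, false, true, true]
p (a ^ b).to_a   # => [true, false, false, true]
```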
@@ -170,6 +198,8 @@ module Torch
           finish += 1 unless index.exclude_end?
           result = result._slice_tensor(dim, index.begin, finish, 1)
           dim += 1
+        elsif index.is_a?(Tensor)
+          result = result.index([index])
         elsif index.nil?
           result = result.unsqueeze(dim)
           dim += 1
@@ -183,19 +213,21 @@ module Torch
       result
     end
 
-    # TODO
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(index, value)
      raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
-      value = Torch.tensor(value) unless value.is_a?(Tensor)
+      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
      if index.is_a?(Numeric)
-        copy_to(_select_int(0, index), value)
+        index_put!([Torch.tensor(index)], value)
      elsif index.is_a?(Range)
        finish = index.end
        finish += 1 unless index.exclude_end?
-        copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+        _slice_tensor(0, index.begin, finish, 1).copy!(value)
+      elsif index.is_a?(Tensor)
+        index_put!([index], value)
      else
        raise Error, "Unsupported index type: #{index.class.name}"
      end
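Element assignment is now implemented with `index_put!` and `copy!` rather than the removed `copy_to` helper, assigned values are cast to the tensor's dtype, and both `[]` and `[]=` accept an index tensor. A short sketch of the resulting behavior:

```ruby
require "torch"

x = Torch.tensor([10, 20, 30, 40, 50])

x[0] = 99                        # Numeric index, routed through index_put!
x[1..2] = Torch.tensor([21, 31]) # Range index, routed through copy!

idx = Torch.tensor([3, 4])
p x[idx].to_a                    # => [40, 50] (new: reading with an index tensor)
x[idx] = Torch.tensor([41, 51])  # new: writing with an index tensor
p x.to_a                         # => [99, 21, 31, 41, 51]
```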
@@ -224,10 +256,9 @@ module Torch
       end
     end
 
-    private
-
-    def copy_to(dst, src)
-      dst.copy!(src)
+    def clamp!(min, max)
+      _clamp_min_(min)
+      _clamp_max_(max)
     end
   end
 end
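The private `copy_to` helper is gone (its callers now use `copy!` directly) and an in-place `clamp!` is added, implemented as `_clamp_min_` followed by `_clamp_max_`. For example:

```ruby
require "torch"

x = Torch.tensor([-1.5, 0.25, 0.75, 3.0])
x.clamp!(0, 1)  # in-place: values are clipped to the [0, 1] range
p x.to_a        # => [0.0, 0.25, 0.75, 1.0]
```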
@@ -0,0 +1,23 @@
+module Torch
+  module Utils
+    module Data
+      class << self
+        def random_split(dataset, lengths)
+          if lengths.sum != dataset.length
+            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
+          end
+
+          indices = Torch.randperm(lengths.sum).to_a
+          _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
+        end
+
+        private
+
+        def _accumulate(iterable)
+          sum = 0
+          iterable.map { |x| sum += x }
+        end
+      end
+    end
+  end
+end
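The new `Torch::Utils::Data.random_split` shuffles indices with `Torch.randperm` and wraps each split in a `Subset`. Anything that responds to `#length` and `#[]` works as the dataset; a plain array keeps this sketch self-contained (in practice this would typically be a `TensorDataset` or a custom `Dataset`):

```ruby
require "torch"

dataset = (1..100).to_a

train_set, test_set = Torch::Utils::Data.random_split(dataset, [80, 20])

p train_set.length # => 80
p test_set.length  # => 20
p test_set[0]      # a randomly chosen element of the original dataset
```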
@@ -6,10 +6,22 @@ module Torch
 
       attr_reader :dataset
 
-      def initialize(dataset, batch_size: 1, shuffle: false)
+      def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
         @dataset = dataset
         @batch_size = batch_size
         @shuffle = shuffle
+
+        @batch_sampler = nil
+
+        if collate_fn.nil?
+          if auto_collation?
+            collate_fn = method(:default_collate)
+          else
+            collate_fn = method(:default_convert)
+          end
+        end
+
+        @collate_fn = collate_fn
       end
 
       def each
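`DataLoader` now accepts a `collate_fn`, defaulting to the renamed `default_convert` (or to `default_collate` when a batch sampler is set; since `@batch_sampler` is always `nil` here, that branch is effectively unused in this release). A sketch of passing a custom collate function; the array-of-pairs dataset is only for illustration and assumes, as with `random_split`, that any object responding to `#size` and `#[]` can serve as the dataset:

```ruby
require "torch"

data = [
  [Torch.tensor([1.0, 2.0]), 0],
  [Torch.tensor([3.0, 4.0]), 1]
]

# Stack the feature tensors, keep the labels as a plain array.
collate = lambda do |batch|
  xs, ys = batch.transpose
  [Torch.stack(xs), ys]
end

loader = Torch::Utils::Data::DataLoader.new(data, batch_size: 2, collate_fn: collate)

loader.each do |x, y|
  p x.size # => [2, 2]
  p y      # => [0, 1]
end
```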
@@ -25,8 +37,8 @@ module Torch
         end
 
         indexes.each_slice(@batch_size) do |idx|
-          batch = idx.map { |i| @dataset[i] }
-          yield collate(batch)
+          # TODO improve performance
+          yield @collate_fn.call(idx.map { |i| @dataset[i] })
         end
       end
 
@@ -36,7 +48,7 @@ module Torch
 
       private
 
-      def collate(batch)
+      def default_convert(batch)
         elem = batch[0]
         case elem
         when Tensor
@@ -44,11 +56,15 @@ module Torch
         when Integer
           Torch.tensor(batch)
         when Array
-          batch.transpose.map { |v| collate(v) }
+          batch.transpose.map { |v| default_convert(v) }
         else
-          raise NotImpelmentYet
+          raise NotImplementedYet
         end
       end
+
+      def auto_collation?
+        !@batch_sampler.nil?
+      end
     end
   end
 end
@@ -0,0 +1,25 @@
+module Torch
+  module Utils
+    module Data
+      class Subset < Dataset
+        def initialize(dataset, indices)
+          @dataset = dataset
+          @indices = indices
+        end
+
+        def [](idx)
+          @dataset[@indices[idx]]
+        end
+
+        def length
+          @indices.length
+        end
+        alias_method :size, :length
+
+        def to_a
+          @indices.map { |i| @dataset[i] }
+        end
+      end
+    end
+  end
+end
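`Subset` is a thin view over another dataset: indexing goes through the stored indices, and `length`/`size`/`to_a` behave as expected. It can also be used directly:

```ruby
require "torch"

dataset = ("a".."j").to_a
subset = Torch::Utils::Data::Subset.new(dataset, [0, 2, 4])

p subset.length # => 3
p subset[1]     # => "c"
p subset.to_a   # => ["a", "c", "e"]
```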
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.2.5"
+  VERSION = "0.3.2"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.5
+  version: 0.3.2
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-06-07 00:00:00.000000000 Z
+date: 2020-08-24 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -259,8 +259,10 @@ files:
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
 - lib/torch/tensor.rb
+- lib/torch/utils/data.rb
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/dataset.rb
+- lib/torch/utils/data/subset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb