torch-rb 0.2.4 → 0.3.1

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -75,7 +75,7 @@ module Torch
        v.is_a?(Tensor)
      when "Tensor?"
        v.nil? || v.is_a?(Tensor)
-     when "Tensor[]"
+     when "Tensor[]", "Tensor?[]"
        v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
      when "int"
        if k == "reduction"
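
The argument checker now accepts `Tensor?[]` (the native-schema type for an array of optional tensors) wherever `Tensor[]` was allowed, which is what lets calls like `index_put!` later in this diff pass validation. A minimal sketch of the check, with hypothetical values:

    indices = [Torch.tensor([0, 2])]
    # both "Tensor[]" and "Tensor?[]" arguments validate a Ruby array of tensors
    indices.is_a?(Array) && indices.all? { |v2| v2.is_a?(Tensor) } # => true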
@@ -18,7 +18,6 @@ module Torch
       F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
     end
 
-    # TODO add more parameters
     def extra_inspect
       s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
       s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -145,7 +145,7 @@ module Torch
       params = {}
       if recurse
         named_children.each do |name, mod|
-          params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         end
       end
       instance_variables.each do |name|
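
This fixes nested prefixes: the outer `prefix` was previously dropped when recursing, so parameters more than one module deep lost part of their dotted name. A hedged illustration with a hypothetical model:

    model = Torch::NN::Sequential.new(
      Torch::NN::Sequential.new(Torch::NN::Linear.new(4, 3)),
      Torch::NN::Linear.new(3, 1)
    )
    model.named_parameters.keys
    # before: inner parameters appeared as "0.weight", "0.bias"
    # after:  "0.0.weight", "0.0.bias", "1.weight", "1.bias"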
@@ -286,8 +286,11 @@ module Torch
       str % vars
     end
 
+    # used for format
+    # remove tensors for performance
+    # so we can skip call to inspect
     def dict
-      instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+      instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
     end
   end
 end
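
`dict` supplies the values interpolated into the `inspect` template, so rejecting tensor-valued instance variables means large weights are never stringified. A hedged illustration:

    layer = Torch::NN::Linear.new(1024, 1024)
    layer.inspect # fast: @weight and @bias are skipped when formatting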
@@ -32,9 +32,11 @@ module Torch
     end
 
     def state_dict
+      raise NotImplementedYet
+
       pack_group = lambda do |group|
-        packed = group.select { |k, _| k != :params }.to_h
-        packed[:params] = group[:params].map { |p| p.object_id }
+        packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
+        packed["params"] = group[:params].map { |p| p.object_id }
         packed
       end
 
@@ -42,8 +44,8 @@ module Torch
       packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
 
       {
-        state: packed_state,
-        param_groups: param_groups
+        "state" => packed_state,
+        "param_groups" => param_groups
       }
     end
 
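
Note the `raise NotImplementedYet` guard added at the top of `state_dict`: the string-keyed packing below it (mirroring PyTorch's serialization format) is staged but unreachable in this release. A hedged illustration:

    optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)
    optimizer.state_dict # assumed to raise NotImplementedYet in 0.3.1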
@@ -11,9 +11,6 @@ module Torch
     end
 
     def step(closure = nil)
-      # TODO implement []=
-      raise NotImplementedYet
-
       loss = nil
       if closure
         loss = closure.call
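
With the guard removed, `step` now runs; the removed TODO suggests it was blocked on the `[]=` support added later in this diff. A hedged sketch of a training step, with `model`, `x`, and `y` assumed:

    criterion = Torch::NN::MSELoss.new
    optimizer.zero_grad
    loss = criterion.call(model.call(x), y)
    loss.backward
    optimizer.step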
@@ -1,6 +1,7 @@
 module Torch
   class Tensor
     include Comparable
+    include Enumerable
     include Inspector
 
     alias_method :requires_grad?, :requires_grad
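
`include Enumerable` pairs with the `each` method added in the next hunk, giving tensors the full Enumerable API over their first dimension. A hedged example:

    x = Torch.tensor([[1, 2], [3, 4]])
    x.map { |row| row.sum.item } # => [3, 7]
    x.count                      # => 2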
@@ -25,14 +26,36 @@ module Torch
       inspect
     end
 
+    def each
+      return enum_for(:each) unless block_given?
+
+      size(0).times do |i|
+        yield self[i]
+      end
+    end
+
+    # TODO make more performant
     def to_a
-      reshape_arr(_flat_data, shape)
+      arr = _flat_data
+      if shape.empty?
+        arr
+      else
+        shape[1..-1].reverse.each do |dim|
+          arr = arr.each_slice(dim)
+        end
+        arr.to_a
+      end
     end
 
-    # TODO support dtype
-    def to(device, non_blocking: false, copy: false)
+    def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      device ||= self.device
       device = Device.new(device) if device.is_a?(String)
-      _to(device, _dtype, non_blocking, copy)
+
+      dtype ||= self.dtype
+      enum = DTYPE_TO_ENUM[dtype]
+      raise Error, "Unknown type: #{dtype}" unless enum
+
+      _to(device, enum, non_blocking, copy)
     end
 
     def cpu
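
`to` can now change dtype as well as device, with both defaulting to the tensor's current values. A hedged example:

    x = Torch.tensor([1, 2, 3])
    x.to(dtype: :float64)        # cast only
    x.to("cpu", dtype: :float32) # move and cast together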
@@ -64,7 +87,7 @@ module Torch
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-      _flat_data.first
+      to_a.first
     end
 
     def to_i
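
Routing `item` through `to_a` keeps it working for zero-dimensional tensors, whose empty shape is now handled explicitly in `to_a`. A hedged example:

    Torch.tensor(3.5).item # => 3.5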
@@ -80,15 +103,16 @@ module Torch
       Torch.empty(0, dtype: dtype)
     end
 
-    def backward(gradient = nil)
-      _backward(gradient)
+    def backward(gradient = nil, retain_graph: nil, create_graph: false)
+      retain_graph = create_graph if retain_graph.nil?
+      _backward(gradient, retain_graph, create_graph)
     end
 
     # TODO read directly from memory
     def numo
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.cast(_flat_data).reshape(*shape)
+      cls.from_string(_data_str).reshape(*shape)
     end
 
     def new_ones(*size, **options)
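
`backward` now exposes the graph-retention flags, and `numo` reads the raw buffer in one call instead of casting element by element. A hedged example of repeated backward passes:

    x = Torch.tensor([2.0, 3.0], requires_grad: true)
    loss = (x * x).sum
    loss.backward(retain_graph: true)
    loss.backward # second pass succeeds because the graph was kept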
@@ -116,15 +140,6 @@ module Torch
       _view(size)
     end
 
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
-    end
-
     def +(other)
       add(other)
     end
@@ -153,12 +168,25 @@ module Torch
       neg
     end
 
+    def &(other)
+      logical_and(other)
+    end
+
+    def |(other)
+      logical_or(other)
+    end
+
+    def ^(other)
+      logical_xor(other)
+    end
+
     # TODO better compare?
     def <=>(other)
       item <=> other
     end
 
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
       result = self
       dim = 0
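
These map Ruby's bitwise operators onto the tensor logical methods. A hedged example:

    a = Torch.tensor([true, false, true])
    b = Torch.tensor([true, true, false])
    a & b # logical_and => [true, false, false]
    a | b # logical_or  => [true, true, true]
    a ^ b # logical_xor => [false, true, true]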
@@ -170,6 +198,8 @@ module Torch
         finish += 1 unless index.exclude_end?
         result = result._slice_tensor(dim, index.begin, finish, 1)
         dim += 1
+      elsif index.is_a?(Tensor)
+        result = result.index([index])
       elsif index.nil?
         result = result.unsqueeze(dim)
         dim += 1
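
`[]` now accepts a tensor index, delegating to `index`. A hedged example:

    x = Torch.tensor([10, 20, 30, 40])
    i = Torch.tensor([0, 2])
    x[i] # expected: a tensor holding [10, 30]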
@@ -183,24 +213,37 @@ module Torch
       result
     end
 
-    # TODO
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(index, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
-      value = Torch.tensor(value) unless value.is_a?(Tensor)
+      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
       if index.is_a?(Numeric)
-        copy_to(_select_int(0, index), value)
+        index_put!([Torch.tensor(index)], value)
       elsif index.is_a?(Range)
         finish = index.end
         finish += 1 unless index.exclude_end?
-        copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+        _slice_tensor(0, index.begin, finish, 1).copy!(value)
+      elsif index.is_a?(Tensor)
+        index_put!([index], value)
       else
        raise Error, "Unsupported index type: #{index.class.name}"
      end
    end
 
+    # native functions that need to be manually defined
+
+    # value and other are swapped for some methods
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add__scalar(other, value)
+      else
+        _add__tensor(other, value)
+      end
+    end
+
     # native functions overlap, so need to handle manually
     def random!(*args)
       case args.size
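
Assignment now casts scalar values to the receiver's dtype and routes numeric and tensor indexes through `index_put!`; ranges still copy into a slice, now via `copy!` on the slice itself. A hedged example covering the three branches:

    x = Torch.tensor([1, 2, 3, 4])
    x[0] = 9                       # Numeric index -> index_put!
    x[1..2] = Torch.tensor([7, 7]) # Range index   -> _slice_tensor(...).copy!
    x[Torch.tensor([3])] = 0       # Tensor index  -> index_put!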
@@ -213,22 +256,9 @@ module Torch
       end
     end
 
-    private
-
-    def copy_to(dst, src)
-      dst.copy!(src)
-    end
-
-    def reshape_arr(arr, dims)
-      if dims.empty?
-        arr
-      else
-        arr = arr.flatten
-        dims[1..-1].reverse.each do |dim|
-          arr = arr.each_slice(dim)
-        end
-        arr.to_a
-      end
+    def clamp!(min, max)
+      _clamp_min_(min)
+      _clamp_max_(max)
     end
   end
 end
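
`clamp!` replaces the now-unused private helpers (`copy_to` is superseded by `copy!`, `reshape_arr` by the new `to_a`), clamping in place in two steps. A hedged example:

    x = Torch.tensor([-2.0, 0.5, 3.0])
    x.clamp!(0, 1)
    x # expected: [0.0, 0.5, 1.0]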
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.2.4"
+  VERSION = "0.3.1"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.4
+  version: 0.3.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-04-29 00:00:00.000000000 Z
+date: 2020-08-17 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice