torch-rb 0.2.5 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
4
- data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
3
+ metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
4
+ data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
5
5
  SHA512:
6
- metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
7
- data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
6
+ metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
7
+ data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
@@ -1,3 +1,9 @@
1
+ ## 0.2.6 (2020-06-29)
2
+
3
+ - Added support for indexing with tensors
4
+ - Added `contiguous` methods
5
+ - Fixed named parameters for nested parameters
6
+
1
7
  ## 0.2.5 (2020-06-07)
2
8
 
3
9
  - Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
data/README.md CHANGED
@@ -409,7 +409,7 @@ Here’s the list of compatible versions.
409
409
 
410
410
  Torch.rb | LibTorch
411
411
  --- | ---
412
- 0.2.0 | 1.5.0
412
+ 0.2.0+ | 1.5.0+
413
413
  0.1.8 | 1.4.0
414
414
  0.1.0-0.1.7 | 1.3.1
415
415
 
@@ -329,6 +329,11 @@ void Init_ext()
329
329
  .define_method("numel", &torch::Tensor::numel)
330
330
  .define_method("element_size", &torch::Tensor::element_size)
331
331
  .define_method("requires_grad", &torch::Tensor::requires_grad)
332
+ .define_method(
333
+ "contiguous?",
334
+ *[](Tensor& self) {
335
+ return self.is_contiguous();
336
+ })
332
337
  .define_method(
333
338
  "addcmul!",
334
339
  *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -66,6 +66,7 @@ module Torch
66
66
  end
67
67
 
68
68
  next if t == "Generator?"
69
+ next if t == "MemoryFormat"
69
70
  next if t == "MemoryFormat?"
70
71
  args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
71
72
  end
@@ -18,7 +18,7 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
34
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
35
  skip_args.any? { |sa| a[:type].include?(sa) } ||
36
36
  # native_functions.yaml is missing size argument for normal
37
37
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
112
112
  "OptionalScalarType"
113
113
  when "Tensor[]"
114
114
  "TensorList"
115
+ when "Tensor?[]"
116
+ # TODO make optional
117
+ "TensorList"
115
118
  when "int"
116
119
  "int64_t"
117
120
  when "float"
@@ -75,7 +75,7 @@ module Torch
75
75
  v.is_a?(Tensor)
76
76
  when "Tensor?"
77
77
  v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]"
78
+ when "Tensor[]", "Tensor?[]"
79
79
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
80
  when "int"
81
81
  if k == "reduction"
@@ -145,7 +145,7 @@ module Torch
145
145
  params = {}
146
146
  if recurse
147
147
  named_children.each do |name, mod|
148
- params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
148
+ params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
149
149
  end
150
150
  end
151
151
  instance_variables.each do |name|
@@ -11,9 +11,6 @@ module Torch
11
11
  end
12
12
 
13
13
  def step(closure = nil)
14
- # TODO implement []=
15
- raise NotImplementedYet
16
-
17
14
  loss = nil
18
15
  if closure
19
16
  loss = closure.call
@@ -158,7 +158,8 @@ module Torch
158
158
  item <=> other
159
159
  end
160
160
 
161
- # based on python_variable_indexing.cpp
161
+ # based on python_variable_indexing.cpp and
162
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
162
163
  def [](*indexes)
163
164
  result = self
164
165
  dim = 0
@@ -170,6 +171,8 @@ module Torch
170
171
  finish += 1 unless index.exclude_end?
171
172
  result = result._slice_tensor(dim, index.begin, finish, 1)
172
173
  dim += 1
174
+ elsif index.is_a?(Tensor)
175
+ result = result.index([index])
173
176
  elsif index.nil?
174
177
  result = result.unsqueeze(dim)
175
178
  dim += 1
@@ -183,12 +186,12 @@ module Torch
183
186
  result
184
187
  end
185
188
 
186
- # TODO
187
- # based on python_variable_indexing.cpp
189
+ # based on python_variable_indexing.cpp and
190
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
188
191
  def []=(index, value)
189
192
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
190
193
 
191
- value = Torch.tensor(value) unless value.is_a?(Tensor)
194
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
192
195
 
193
196
  if index.is_a?(Numeric)
194
197
  copy_to(_select_int(0, index), value)
@@ -196,6 +199,8 @@ module Torch
196
199
  finish = index.end
197
200
  finish += 1 unless index.exclude_end?
198
201
  copy_to(_slice_tensor(0, index.begin, finish, 1), value)
202
+ elsif index.is_a?(Tensor)
203
+ index_put!([index], value)
199
204
  else
200
205
  raise Error, "Unsupported index type: #{index.class.name}"
201
206
  end
@@ -224,6 +229,11 @@ module Torch
224
229
  end
225
230
  end
226
231
 
232
+ def clamp!(min, max)
233
+ _clamp_min_(min)
234
+ _clamp_max_(max)
235
+ end
236
+
227
237
  private
228
238
 
229
239
  def copy_to(dst, src)
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.5"
2
+ VERSION = "0.2.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.5
4
+ version: 0.2.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-07 00:00:00.000000000 Z
11
+ date: 2020-06-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice