torch-rb 0.2.6 → 0.2.7

This diff has not been reviewed by any users.
Log in in order to be able to vote.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
4
- data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
3
+ metadata.gz: 3451d6140ae6a6a9294a73571239df703a9dc753911c5d97a83bcb020b9d878d
4
+ data.tar.gz: 65689090d9fe4d9dee078b2f0f0f56526d76158306390c0988e61b0e2ca98ff1
5
5
  SHA512:
6
- metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
7
- data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
6
+ metadata.gz: 9f2cc800b8c0e7a3a75bbb9c4705e7e306ed68f52a90530d22659a4d23d8ce0126c1cfd9bc7c33612a842bc20199ba8ec7f488bbba591073d8914f108948e084
7
+ data.tar.gz: dbf34592bef6e869a3814f20e891d2d566339080a46d335a5e42f114477a5769f63ee18ca3ee8b8f1d031faf898dfe4f6861064f0cb0773b6d75622b4a663e0f
@@ -1,3 +1,8 @@
1
+ ## 0.2.7 (2020-06-29)
2
+
3
+ - Made tensors enumerable
4
+ - Improved performance of `inspect` method
5
+
1
6
  ## 0.2.6 (2020-06-29)
2
7
 
3
8
  - Added support for indexing with tensors
@@ -1,89 +1,264 @@
1
+ # mirrors _tensor_str.py
1
2
  module Torch
2
3
  module Inspector
3
- # TODO make more performant, especially when summarizing
4
- # how? only read data that will be displayed
5
- def inspect
6
- data =
7
- if numel == 0
8
- "[]"
9
- elsif dim == 0
10
- item
4
+ PRINT_OPTS = {
5
+ precision: 4,
6
+ threshold: 1000,
7
+ edgeitems: 3,
8
+ linewidth: 80,
9
+ sci_mode: nil
10
+ }
11
+
12
+ class Formatter
13
+ def initialize(tensor)
14
+ @floating_dtype = tensor.floating_point?
15
+ @complex_dtype = tensor.complex?
16
+ @int_mode = true
17
+ @sci_mode = false
18
+ @max_width = 1
19
+
20
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
21
+
22
+ if !@floating_dtype
23
+ tensor_view.each do |value|
24
+ value_str = value.item.to_s
25
+ @max_width = [@max_width, value_str.length].max
26
+ end
11
27
  else
12
- summarize = numel > 1000
28
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
29
+
30
+ # no valid number, do nothing
31
+ return if nonzero_finite_vals.numel == 0
32
+
33
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
34
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
35
+ nonzero_finite_min = nonzero_finite_abs.min.double
36
+ nonzero_finite_max = nonzero_finite_abs.max.double
37
+
38
+ nonzero_finite_vals.each do |value|
39
+ if value.item != value.item.ceil
40
+ @int_mode = false
41
+ break
42
+ end
43
+ end
13
44
 
14
- if dtype == :bool
15
- fmt = "%s"
45
+ if @int_mode
46
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
47
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
48
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
49
+ @sci_mode = true
50
+ nonzero_finite_vals.each do |value|
51
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
52
+ @max_width = [@max_width, value_str.length].max
53
+ end
54
+ else
55
+ nonzero_finite_vals.each do |value|
56
+ value_str = "%.0f" % value.item
57
+ @max_width = [@max_width, value_str.length + 1].max
58
+ end
59
+ end
16
60
  else
17
- values = _flat_data
18
- abs = values.select { |v| v != 0 }.map(&:abs)
19
- max = abs.max || 1
20
- min = abs.min || 1
21
-
22
- total = 0
23
- if values.any? { |v| v < 0 }
24
- total += 1
61
+ # Check if scientific representation should be used.
62
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
63
+ @sci_mode = true
64
+ nonzero_finite_vals.each do |value|
65
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
66
+ @max_width = [@max_width, value_str.length].max
67
+ end
68
+ else
69
+ nonzero_finite_vals.each do |value|
70
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
71
+ @max_width = [@max_width, value_str.length].max
72
+ end
25
73
  end
74
+ end
75
+ end
26
76
 
27
- if floating_point?
28
- sci = max > 1e8 || max < 1e-4
77
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
78
+ end
29
79
 
30
- all_int = values.all? { |v| v.finite? && v == v.to_i }
31
- decimal = all_int ? 1 : 4
80
+ def width
81
+ @max_width
82
+ end
32
83
 
33
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
84
+ def format(value)
85
+ value = value.item
34
86
 
35
- if sci
36
- fmt = "%#{total}.4e"
37
- else
38
- fmt = "%#{total}.#{decimal}f"
39
- end
40
- else
41
- total += max.to_s.size
42
- fmt = "%#{total}d"
87
+ if @floating_dtype
88
+ if @sci_mode
89
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
90
+ elsif @int_mode
91
+ ret = String.new("%.0f" % value)
92
+ unless value.infinite? || value.nan?
93
+ ret += "."
43
94
  end
95
+ else
96
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
44
97
  end
98
+ elsif @complex_dtype
99
+ p = PRINT_OPTS[:precision]
100
+ raise NotImplementedYet
101
+ else
102
+ ret = value.to_s
103
+ end
104
+ # Ruby throws error when negative, Python doesn't
105
+ " " * [@max_width - ret.size, 0].max + ret
106
+ end
107
+ end
108
+
109
+ def inspect
110
+ Torch.no_grad do
111
+ str_intern(self)
112
+ end
113
+ rescue => e
114
+ # prevent stack error
115
+ puts e.backtrace.join("\n")
116
+ "Error inspecting tensor: #{e.inspect}"
117
+ end
118
+
119
+ private
120
+
121
+ # TODO update
122
+ def str_intern(slf)
123
+ prefix = "tensor("
124
+ indent = prefix.length
125
+ suffixes = []
126
+
127
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
128
+
129
+ if slf.numel == 0 && !slf.sparse?
130
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
131
+ if slf.dim != 1
132
+ suffixes << "size: #{shape.inspect}"
133
+ end
45
134
 
46
- inspect_level(to_a, fmt, dim - 1, 0, summarize)
135
+ # In an empty tensor, there are no elements to infer if the dtype
136
+ # should be int64, so it must be shown explicitly.
137
+ if slf.dtype != :int64
138
+ suffixes << "dtype: #{slf.dtype.inspect}"
47
139
  end
140
+ tensor_str = "[]"
141
+ else
142
+ if !has_default_dtype
143
+ suffixes << "dtype: #{slf.dtype.inspect}"
144
+ end
145
+
146
+ if slf.layout != :strided
147
+ tensor_str = tensor_str(slf.to_dense, indent)
148
+ else
149
+ tensor_str = tensor_str(slf, indent)
150
+ end
151
+ end
48
152
 
49
- attributes = []
50
- if requires_grad
51
- attributes << "requires_grad: true"
153
+ if slf.layout != :strided
154
+ suffixes << "layout: #{slf.layout.inspect}"
52
155
  end
53
- if ![:float32, :int64, :bool].include?(dtype)
54
- attributes << "dtype: #{dtype.inspect}"
156
+
157
+ # TODO show grad_fn
158
+ if slf.requires_grad?
159
+ suffixes << "requires_grad: true"
55
160
  end
56
161
 
57
- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
162
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
58
163
  end
59
164
 
60
- private
165
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
166
+ tensor_strs = [tensor_str]
167
+ # rfind in Python returns -1 when not found
168
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
169
+ suffixes.each do |suffix|
170
+ suffix_len = suffix.length
171
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
172
+ tensor_strs << ",\n" + " " * indent + suffix
173
+ last_line_len = indent + suffix_len
174
+ force_newline = false
175
+ else
176
+ tensor_strs.append(", " + suffix)
177
+ last_line_len += suffix_len + 2
178
+ end
179
+ end
180
+ tensor_strs.append(")")
181
+ tensor_strs.join("")
182
+ end
61
183
 
62
- # TODO DRY code
63
- def inspect_level(arr, fmt, total, level, summarize)
64
- if level == total
65
- cols =
66
- if summarize && arr.size > 7
67
- arr[0..2].map { |v| fmt % v } +
68
- ["..."] +
69
- arr[-3..-1].map { |v| fmt % v }
70
- else
71
- arr.map { |v| fmt % v }
72
- end
184
+ def tensor_str(slf, indent)
185
+ return "[]" if slf.numel == 0
186
+
187
+ summarize = slf.numel > PRINT_OPTS[:threshold]
188
+
189
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
190
+ slf = slf.float
191
+ end
192
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
193
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
194
+ end
195
+
196
+ def summarized_data(slf)
197
+ edgeitems = PRINT_OPTS[:edgeitems]
73
198
 
74
- "[#{cols.join(", ")}]"
199
+ dim = slf.dim
200
+ if dim == 0
201
+ slf
202
+ elsif dim == 1
203
+ if size(0) > 2 * edgeitems
204
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
205
+ else
206
+ slf
207
+ end
208
+ elsif slf.size(0) > 2 * edgeitems
209
+ start = edgeitems.times.map { |i| slf[i] }
210
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
211
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
75
212
  else
76
- rows =
77
- if summarize && arr.size > 7
78
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
79
- ["..."] +
80
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
81
- else
82
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
83
- end
213
+ Torch.stack(slf.map { |x| summarized_data(x) })
214
+ end
215
+ end
216
+
217
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
218
+ edgeitems = PRINT_OPTS[:edgeitems]
219
+
220
+ dim = slf.dim
84
221
 
85
- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
222
+ return scalar_str(slf, formatter) if dim == 0
223
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
224
+
225
+ if summarize && slf.size(0) > 2 * edgeitems
226
+ slices = (
227
+ [edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
228
+ ["..."] +
229
+ [((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
230
+ )
231
+ else
232
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
86
233
  end
234
+
235
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
236
+ "[" + tensor_str + "]"
237
+ end
238
+
239
+ def scalar_str(slf, formatter)
240
+ formatter.format(slf)
241
+ end
242
+
243
+ def vector_str(slf, indent, formatter, summarize)
244
+ # length includes spaces and comma between elements
245
+ element_length = formatter.width + 2
246
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
247
+ char_per_line = element_length * elements_per_line
248
+
249
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
250
+ data = (
251
+ [slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
252
+ [" ..."] +
253
+ [slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
254
+ )
255
+ else
256
+ data = slf.map { |val| formatter.format(val) }
257
+ end
258
+
259
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
260
+ lines = data_lines.map { |line| line.join(", ") }
261
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
87
262
  end
88
263
  end
89
264
  end
@@ -286,8 +286,11 @@ module Torch
286
286
  str % vars
287
287
  end
288
288
 
289
+ # used for format
290
+ # remove tensors for performance
291
+ # so we can skip call to inspect
289
292
  def dict
290
- instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
293
+ instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
291
294
  end
292
295
  end
293
296
  end
@@ -1,6 +1,7 @@
1
1
  module Torch
2
2
  class Tensor
3
3
  include Comparable
4
+ include Enumerable
4
5
  include Inspector
5
6
 
6
7
  alias_method :requires_grad?, :requires_grad
@@ -25,6 +26,14 @@ module Torch
25
26
  inspect
26
27
  end
27
28
 
29
+ def each
30
+ return enum_for(:each) unless block_given?
31
+
32
+ size(0).times do |i|
33
+ yield self[i]
34
+ end
35
+ end
36
+
28
37
  # TODO make more performant
29
38
  def to_a
30
39
  arr = _flat_data
@@ -153,6 +162,18 @@ module Torch
153
162
  neg
154
163
  end
155
164
 
165
+ def &(other)
166
+ logical_and(other)
167
+ end
168
+
169
+ def |(other)
170
+ logical_or(other)
171
+ end
172
+
173
+ def ^(other)
174
+ logical_xor(other)
175
+ end
176
+
156
177
  # TODO better compare?
157
178
  def <=>(other)
158
179
  item <=> other
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.6"
2
+ VERSION = "0.2.7"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.6
4
+ version: 0.2.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-29 00:00:00.000000000 Z
11
+ date: 2020-06-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice