torch-rb 0.2.2 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,89 +1,264 @@
# Tensor pretty-printing; mirrors PyTorch's _tensor_str.py.
module Torch
  module Inspector
    # Print settings (mirrors PRINT_OPTS in _tensor_str.py). Not frozen, so the
    # values can be adjusted at runtime.
    #   precision: digits after the decimal point for floating-point values
    #   threshold: tensors with more elements than this are summarized with "..."
    #   edgeitems: items kept at each end of a summarized dimension
    #   linewidth: target line width used when wrapping elements and suffixes
    #   sci_mode:  true/false forces scientific notation on/off; nil decides automatically
    PRINT_OPTS = {
      precision: 4,
      threshold: 1000,
      edgeitems: 3,
      linewidth: 80,
      sci_mode: nil
    }

    # Chooses one rendering mode (integer-like, fixed-point or scientific) and a
    # common column width for all elements of a tensor so they line up when printed.
    class Formatter
      # @param tensor [Torch::Tensor] the values to scan (possibly already summarized)
      def initialize(tensor)
        @floating_dtype = tensor.floating_point?
        @complex_dtype = tensor.complex?
        @int_mode = true
        @sci_mode = false
        @max_width = 1

        tensor_view = Torch.no_grad { tensor.reshape(-1) }

        if @floating_dtype
          nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))

          # no valid number, do nothing (this also skips the PRINT_OPTS[:sci_mode] override below)
          return if nonzero_finite_vals.numel == 0

          # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
          nonzero_finite_abs = nonzero_finite_vals.abs.double
          nonzero_finite_min = nonzero_finite_abs.min.double
          nonzero_finite_max = nonzero_finite_abs.max.double

          # int_mode stays true only if every nonzero finite value is a whole number
          nonzero_finite_vals.each do |value|
            if value.item != value.item.ceil
              @int_mode = false
              break
            end
          end

          if @int_mode
            # In int_mode, #format appends "." to finite values to show the tensor is
            # floating point, so add 1 to the width to account for it.
            if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
              @sci_mode = true
              nonzero_finite_vals.each do |value|
                value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
                @max_width = [@max_width, value_str.length].max
              end
            else
              nonzero_finite_vals.each do |value|
                value_str = "%.0f" % value.item
                @max_width = [@max_width, value_str.length + 1].max
              end
            end
          else
            # Check if scientific representation should be used.
            if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
              @sci_mode = true
              nonzero_finite_vals.each do |value|
                value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
                @max_width = [@max_width, value_str.length].max
              end
            else
              nonzero_finite_vals.each do |value|
                value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
                @max_width = [@max_width, value_str.length].max
              end
            end
          end
        else
          # non-floating tensors: width of the longest printed value
          tensor_view.each do |value|
            value_str = value.item.to_s
            @max_width = [@max_width, value_str.length].max
          end
        end

        @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
      end

      # Column width (in characters) every formatted element is padded to.
      def width
        @max_width
      end

      # Formats one element (a 0-dim tensor), right-aligned to #width.
      def format(value)
        value = value.item

        if @floating_dtype
          if @sci_mode
            ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
          elsif @int_mode
            ret = String.new("%.0f" % value)
            ret << "." unless value.infinite? || value.nan?
          else
            ret = "%.#{PRINT_OPTS[:precision]}f" % value
          end
        elsif @complex_dtype
          raise NotImplementedYet
        else
          ret = value.to_s
        end
        # String#* raises on a negative count (Python just yields ""), so clamp at 0
        " " * [@max_width - ret.size, 0].max + ret
      end
    end

    # Human-readable representation, e.g. "tensor([1, 2, 3])".
    # Any StandardError raised while formatting is not propagated: its backtrace is
    # written to stderr and an "Error inspecting tensor: ..." string is returned
    # (original note: "prevent stack error").
    def inspect
      Torch.no_grad do
        str_intern(self)
      end
    rescue => e
      warn e.backtrace.join("\n")
      "Error inspecting tensor: #{e.inspect}"
    end

    private

    # Builds "tensor(<data>, <suffixes>)" for slf; suffixes report non-default
    # size/dtype/layout and requires_grad.
    # TODO update
    def str_intern(slf)
      prefix = "tensor("
      indent = prefix.length
      suffixes = []

      has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)

      if slf.numel == 0 && !slf.sparse?
        # Explicitly print the shape if it is not (0,), to match NumPy behavior
        suffixes << "size: #{slf.shape.inspect}" if slf.dim != 1

        # In an empty tensor, there are no elements to infer if the dtype
        # should be int64, so it must be shown explicitly.
        suffixes << "dtype: #{slf.dtype.inspect}" if slf.dtype != :int64
        data_str = "[]"
      else
        suffixes << "dtype: #{slf.dtype.inspect}" unless has_default_dtype

        dense = slf.layout == :strided ? slf : slf.to_dense
        data_str = tensor_str(dense, indent)
      end

      suffixes << "layout: #{slf.layout.inspect}" if slf.layout != :strided

      # TODO show grad_fn
      suffixes << "requires_grad: true" if slf.requires_grad?

      add_suffixes(prefix + data_str, suffixes, indent, slf.sparse?)
    end

    # Appends ", suffix" entries and the closing ")" to tensor_str. A suffix is moved
    # onto its own line (indented by indent) when it would push the current line past
    # PRINT_OPTS[:linewidth]; force_newline wraps the first suffix unconditionally.
    def add_suffixes(tensor_str, suffixes, indent, force_newline)
      tensor_strs = [tensor_str]
      # rfind in Python returns -1 when not found
      last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
      suffixes.each do |suffix|
        suffix_len = suffix.length
        if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
          tensor_strs << ",\n" + " " * indent + suffix
          last_line_len = indent + suffix_len
          force_newline = false
        else
          tensor_strs << ", " + suffix
          last_line_len += suffix_len + 2
        end
      end
      tensor_strs << ")"
      tensor_strs.join
    end

    # Formats the elements of slf as (nested) bracketed lists, summarizing with
    # "..." when slf has more than PRINT_OPTS[:threshold] elements.
    def tensor_str(slf, indent)
      return "[]" if slf.numel == 0

      summarize = slf.numel > PRINT_OPTS[:threshold]

      # half-precision dtypes are converted with #float before formatting
      slf = slf.float if slf.dtype == :float16 || slf.dtype == :bfloat16
      formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
      tensor_str_with_formatter(slf, indent, formatter, summarize)
    end

    # Returns only the elements shown when summarizing: the first and last
    # PRINT_OPTS[:edgeitems] entries along every dimension. Used to size the
    # Formatter without scanning the whole tensor.
    def summarized_data(slf)
      edgeitems = PRINT_OPTS[:edgeitems]

      dim = slf.dim
      if dim == 0
        slf
      elsif dim == 1
        # must test this slice's length: recursive calls pass sub-tensors of self
        if slf.size(0) > 2 * edgeitems
          Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
        else
          slf
        end
      elsif slf.size(0) > 2 * edgeitems
        start = edgeitems.times.map { |i| slf[i] }
        finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
        Torch.stack((start + finish).map { |x| summarized_data(x) })
      else
        Torch.stack(slf.map { |x| summarized_data(x) })
      end
    end

    # Recursively formats slf one leading index at a time; sibling slices of a
    # dim-d tensor are separated by (d - 1) newlines plus indentation.
    def tensor_str_with_formatter(slf, indent, formatter, summarize)
      edgeitems = PRINT_OPTS[:edgeitems]

      dim = slf.dim

      return scalar_str(slf, formatter) if dim == 0
      return vector_str(slf, indent, formatter, summarize) if dim == 1

      format_slice = ->(i) { tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
      slices =
        if summarize && slf.size(0) > 2 * edgeitems
          # flat list: head slices, the "..." marker, tail slices
          (0...edgeitems).map(&format_slice) +
            ["..."] +
            ((slf.length - edgeitems)...slf.length).map(&format_slice)
        else
          slf.size(0).times.map(&format_slice)
        end

      "[" + slices.join("," + "\n" * (dim - 1) + " " * (indent + 1)) + "]"
    end

    # Formats a 0-dim tensor.
    def scalar_str(slf, formatter)
      formatter.format(slf)
    end

    # Formats a 1-dim tensor, wrapping so each line holds as many
    # (width + ", ") elements as fit in PRINT_OPTS[:linewidth] - indent (at least one).
    def vector_str(slf, indent, formatter, summarize)
      edgeitems = PRINT_OPTS[:edgeitems]
      # length includes spaces and comma between elements
      element_length = formatter.width + 2
      elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor].max

      data =
        if summarize && slf.size(0) > 2 * edgeitems
          # keep the list flat so the " ..." marker and every element count toward wrapping
          slf[0...edgeitems].map { |val| formatter.format(val) } +
            [" ..."] +
            slf[-edgeitems..-1].map { |val| formatter.format(val) }
        else
          slf.map { |val| formatter.format(val) }
        end

      lines = data.each_slice(elements_per_line).map { |line| line.join(", ") }
      "[" + lines.join(",\n" + " " * (indent + 1)) + "]"
    end
  end
end
@@ -66,6 +66,7 @@ module Torch
66
66
  end
67
67
 
68
68
  next if t == "Generator?"
69
+ next if t == "MemoryFormat"
69
70
  next if t == "MemoryFormat?"
70
71
  args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
71
72
  end
@@ -18,7 +18,7 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
34
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
35
  skip_args.any? { |sa| a[:type].include?(sa) } ||
36
36
  # native_functions.yaml is missing size argument for normal
37
37
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
112
112
  "OptionalScalarType"
113
113
  when "Tensor[]"
114
114
  "TensorList"
115
+ when "Tensor?[]"
116
+ # TODO make optional
117
+ "TensorList"
115
118
  when "int"
116
119
  "int64_t"
117
120
  when "float"
@@ -75,7 +75,7 @@ module Torch
75
75
  v.is_a?(Tensor)
76
76
  when "Tensor?"
77
77
  v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]"
78
+ when "Tensor[]", "Tensor?[]"
79
79
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
80
  when "int"
81
81
  if k == "reduction"
@@ -70,6 +70,11 @@ module Torch
70
70
  momentum: exponential_average_factor, eps: @eps
71
71
  )
72
72
  end
73
+
74
# Summary of the batch-norm settings shown by Module#inspect, e.g.
# "4, eps: 1.0e-05, momentum: 0.1, affine: true, track_running_stats: true".
# Values come from #dict via named format references.
def extra_inspect
  labelled = %w[eps momentum affine track_running_stats].map { |key| "#{key}: %{#{key}}" }
  template = ["%{num_features}", *labelled].join(", ")
  format(template, **dict)
end
73
78
  end
74
79
  end
75
80
  end
@@ -20,7 +20,14 @@ module Torch
20
20
 
21
21
  # TODO add more parameters
22
22
  def extra_inspect
23
- format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
23
+ s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
+ s += ", padding: %{padding}" if @padding != [0] * @padding.size
25
+ s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
26
+ s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
27
+ s += ", groups: %{groups}" if @groups != 1
28
+ s += ", bias: false" unless @bias
29
+ s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
30
+ format(s, **dict)
24
31
  end
25
32
  end
26
33
  end
@@ -23,7 +23,7 @@ module Torch
23
23
  if bias
24
24
  @bias = Parameter.new(Tensor.new(out_channels))
25
25
  else
26
- raise NotImplementedError
26
+ register_parameter("bias", nil)
27
27
  end
28
28
  reset_parameters
29
29
  end
@@ -12,7 +12,8 @@ module Torch
12
12
  end
13
13
 
14
14
  def extra_inspect
15
- format("kernel_size: %s", @kernel_size)
15
+ s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
16
+ format(s, **dict)
16
17
  end
17
18
  end
18
19
  end
@@ -145,7 +145,7 @@ module Torch
145
145
  params = {}
146
146
  if recurse
147
147
  named_children.each do |name, mod|
148
- params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
148
+ params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
149
149
  end
150
150
  end
151
151
  instance_variables.each do |name|
@@ -186,8 +186,22 @@ module Torch
186
186
  named_modules.values
187
187
  end
188
188
 
189
# Returns a Hash of dotted path => module for the receiver and every Module
# reachable through named_children, in depth-first pre-order. The memo Set
# ensures each module object appears once; the receiver is keyed by +prefix+.
# TODO return enumerator?
def named_modules(memo: nil, prefix: "")
  memo ||= Set.new
  result = {}
  return result if memo.include?(self)

  memo << self
  result[prefix] = self
  named_children.each do |child_name, child|
    # only Module children are traversed; other values are ignored
    next unless child.is_a?(Module)

    child_prefix = prefix + (prefix.empty? ? "" : ".") + child_name
    result.merge!(child.named_modules(memo: memo, prefix: child_prefix))
  end
  result
end
192
206
 
193
207
  def train(mode = true)
@@ -224,13 +238,15 @@ module Torch
224
238
 
225
239
  def inspect
226
240
  name = self.class.name.split("::").last
227
- if children.empty?
241
+ if named_children.empty?
228
242
  "#{name}(#{extra_inspect})"
229
243
  else
230
244
  str = String.new
231
245
  str << "#{name}(\n"
232
- children.each do |name, mod|
233
- str << " (#{name}): #{mod.inspect}\n"
246
+ named_children.each do |name, mod|
247
+ mod_str = mod.inspect
248
+ mod_str = mod_str.lines.join(" ")
249
+ str << " (#{name}): #{mod_str}\n"
234
250
  end
235
251
  str << ")"
236
252
  end
@@ -270,8 +286,11 @@ module Torch
270
286
  str % vars
271
287
  end
272
288
 
289
# Substitution table for the format-based extra_inspect strings: maps each
# instance variable name (without "@") as a Symbol to its value.
# Tensor-valued ivars are left out, so formatting never has to inspect them
# (an original motivation was performance).
def dict
  instance_variables.each_with_object({}) do |ivar, table|
    value = instance_variable_get(ivar)
    table[ivar[1..-1].to_sym] = value unless value.is_a?(Tensor)
  end
end
276
295
  end
277
296
  end