torch-rb 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/tensor.rb +21 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 3451d6140ae6a6a9294a73571239df703a9dc753911c5d97a83bcb020b9d878d
|
4
|
+
data.tar.gz: 65689090d9fe4d9dee078b2f0f0f56526d76158306390c0988e61b0e2ca98ff1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9f2cc800b8c0e7a3a75bbb9c4705e7e306ed68f52a90530d22659a4d23d8ce0126c1cfd9bc7c33612a842bc20199ba8ec7f488bbba591073d8914f108948e084
|
7
|
+
data.tar.gz: dbf34592bef6e869a3814f20e891d2d566339080a46d335a5e42f114477a5769f63ee18ca3ee8b8f1d031faf898dfe4f6861064f0cb0773b6d75622b4a663e0f
|
data/CHANGELOG.md
CHANGED
data/lib/torch/inspector.rb
CHANGED
@@ -1,89 +1,264 @@
|
|
1
|
+
# mirrors _tensor_str.py
|
1
2
|
module Torch
  # Tensor#inspect implementation, ported from PyTorch's _tensor_str.py.
  module Inspector
    # Print options:
    # precision - digits after the decimal point for floating values
    # threshold - element count above which output is summarized with "..."
    # edgeitems - items kept at each end of a summarized dimension
    # linewidth - target maximum characters per line
    # sci_mode  - force scientific notation on (true) / off (false), or auto-detect (nil)
    PRINT_OPTS = {
      precision: 4,
      threshold: 1000,
      edgeitems: 3,
      linewidth: 80,
      sci_mode: nil
    }

    # Picks one rendering mode (whole-number, fixed-point or scientific) for
    # every element of a tensor and records the widest rendered element so
    # that columns line up.
    class Formatter
      # @param tensor [Torch::Tensor] the (possibly summarized) data to scan
      def initialize(tensor)
        @floating_dtype = tensor.floating_point?
        @complex_dtype = tensor.complex?
        @int_mode = true
        @sci_mode = false
        @max_width = 1

        tensor_view = Torch.no_grad { tensor.reshape(-1) }

        if @floating_dtype
          nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))

          # no valid number, do nothing
          # (note: this early return also skips the PRINT_OPTS[:sci_mode] override below)
          return if nonzero_finite_vals.numel == 0

          # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
          nonzero_finite_abs = nonzero_finite_vals.abs.double
          nonzero_finite_min = nonzero_finite_abs.min.item
          nonzero_finite_max = nonzero_finite_abs.max.item
          precision = PRINT_OPTS[:precision]
          wide_range = nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8

          @int_mode = nonzero_finite_vals.all? { |value| value.item == value.item.ceil }

          if @int_mode
            # every value is a whole number; format appends "." to finite values to
            # show the dtype is floating, so reserve one extra character for it
            if wide_range
              @sci_mode = true
              fit_width(nonzero_finite_vals) { |v| "%.#{precision}e" % v }
            else
              fit_width(nonzero_finite_vals, 1) { |v| "%.0f" % v }
            end
          elsif wide_range || nonzero_finite_min < 1.0e-4
            # scientific notation for a wide dynamic range or very small magnitudes
            @sci_mode = true
            fit_width(nonzero_finite_vals) { |v| "%.#{precision}e" % v }
          else
            fit_width(nonzero_finite_vals) { |v| "%.#{precision}f" % v }
          end
        else
          fit_width(tensor_view, &:to_s)
        end

        @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
      end

      # @return [Integer] width every formatted element is padded to
      def width
        @max_width
      end

      # Renders one element, right-aligned to #width.
      #
      # @param value [#item] a 0-dim tensor
      # @return [String]
      # @raise [NotImplementedError] for complex tensors
      def format(value)
        value = value.item

        ret =
          if @floating_dtype
            if @sci_mode
              "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
            elsif @int_mode
              whole = "%.0f" % value
              value.infinite? || value.nan? ? whole : "#{whole}."
            else
              "%.#{PRINT_OPTS[:precision]}f" % value
            end
          elsif @complex_dtype
            raise NotImplementedError, "inspecting complex tensors is not supported yet"
          else
            value.to_s
          end

        # rjust never truncates, so values wider than @max_width are left as-is
        ret.rjust(@max_width)
      end

      private

      # Grows @max_width to fit the string the block renders for each element's
      # Ruby value; extra reserves room for characters appended later by format.
      def fit_width(values, extra = 0)
        values.each do |value|
          @max_width = [@max_width, yield(value.item).length + extra].max
        end
      end
    end

    # @return [String] PyTorch-style representation, e.g. "tensor([1, 2, 3])"
    def inspect
      Torch.no_grad do
        str_intern(self)
      end
    rescue StandardError, NotImplementedError => e
      # prevent stack error: return a placeholder instead of raising from inspect;
      # the backtrace is diagnostic output, so it goes to stderr
      warn Array(e.backtrace).join("\n")
      "Error inspecting tensor: #{e.inspect}"
    end

    private

    # Builds "tensor(<data>, <suffixes>)" for slf.
    def str_intern(slf)
      prefix = "tensor("
      indent = prefix.length
      suffixes = []

      has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)

      if slf.numel == 0 && !slf.sparse?
        # Explicitly print the shape if it is not (0,), to match NumPy behavior
        suffixes << "size: #{slf.shape.inspect}" if slf.dim != 1

        # In an empty tensor, there are no elements to infer if the dtype
        # should be int64, so it must be shown explicitly.
        suffixes << "dtype: #{slf.dtype.inspect}" if slf.dtype != :int64
        body = "[]"
      else
        suffixes << "dtype: #{slf.dtype.inspect}" unless has_default_dtype
        body = tensor_str(slf.layout == :strided ? slf : slf.to_dense, indent)
      end

      suffixes << "layout: #{slf.layout.inspect}" if slf.layout != :strided

      # TODO show grad_fn
      suffixes << "requires_grad: true" if slf.requires_grad?

      add_suffixes(prefix + body, suffixes, indent, slf.sparse?)
    end

    # Appends ", suffix" entries and the closing paren, wrapping onto a new
    # indented line when a suffix would overflow PRINT_OPTS[:linewidth].
    # force_newline puts the first suffix on its own line.
    def add_suffixes(tensor_str, suffixes, indent, force_newline)
      tensor_strs = [tensor_str]
      # rfind in Python returns -1 when not found
      last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
      suffixes.each do |suffix|
        suffix_len = suffix.length
        if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
          tensor_strs << ",\n" + " " * indent + suffix
          last_line_len = indent + suffix_len
          force_newline = false
        else
          tensor_strs << ", " + suffix
          last_line_len += suffix_len + 2
        end
      end
      tensor_strs << ")"
      tensor_strs.join
    end

    # Renders the bracketed data of slf, summarizing when numel exceeds the threshold.
    def tensor_str(slf, indent)
      return "[]" if slf.numel == 0

      summarize = slf.numel > PRINT_OPTS[:threshold]

      # half-precision tensors are converted to float32 before formatting
      slf = slf.float if slf.dtype == :float16 || slf.dtype == :bfloat16

      formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
      tensor_str_with_formatter(slf, indent, formatter, summarize)
    end

    # Keeps only the edge items of every dimension, so the Formatter sizes
    # columns from the values that will actually be printed.
    def summarized_data(slf)
      edgeitems = PRINT_OPTS[:edgeitems]

      dim = slf.dim
      if dim == 0
        slf
      elsif dim == 1
        if slf.size(0) > 2 * edgeitems
          Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
        else
          slf
        end
      elsif slf.size(0) > 2 * edgeitems
        length = slf.size(0)
        rows = (0...edgeitems).map { |i| slf[i] } + ((length - edgeitems)...length).map { |i| slf[i] }
        Torch.stack(rows.map { |x| summarized_data(x) })
      else
        Torch.stack(slf.map { |x| summarized_data(x) })
      end
    end

    # Recursively renders slf; rows of an n-dim tensor are separated by
    # n - 1 newlines, and "..." stands in for elided rows when summarizing.
    def tensor_str_with_formatter(slf, indent, formatter, summarize)
      edgeitems = PRINT_OPTS[:edgeitems]

      dim = slf.dim

      return scalar_str(slf, formatter) if dim == 0
      return vector_str(slf, indent, formatter, summarize) if dim == 1

      length = slf.size(0)
      child = ->(i) { tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
      slices =
        if summarize && length > 2 * edgeitems
          (0...edgeitems).map(&child) + ["..."] + ((length - edgeitems)...length).map(&child)
        else
          (0...length).map(&child)
        end

      body = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
      "[" + body + "]"
    end

    def scalar_str(slf, formatter)
      formatter.format(slf)
    end

    # Renders a 1-D tensor, wrapping so each line holds as many elements as
    # fit in PRINT_OPTS[:linewidth] (at least one).
    def vector_str(slf, indent, formatter, summarize)
      edgeitems = PRINT_OPTS[:edgeitems]
      # length includes spaces and comma between elements
      element_length = formatter.width + 2
      elements_per_line = [1, (PRINT_OPTS[:linewidth] - indent).fdiv(element_length).floor].max

      data =
        if summarize && slf.size(0) > 2 * edgeitems
          # concatenate (not nest) so every element counts toward line wrapping
          slf[0...edgeitems].map { |val| formatter.format(val) } +
            [" ..."] +
            slf[-edgeitems..-1].map { |val| formatter.format(val) }
        else
          slf.map { |val| formatter.format(val) }
        end

      lines = data.each_slice(elements_per_line).map { |line| line.join(", ") }
      "[" + lines.join(",\n" + " " * (indent + 1)) + "]"
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -286,8 +286,11 @@ module Torch
|
|
286
286
|
str % vars
|
287
287
|
end
|
288
288
|
|
289
|
+
# Instance variables as a symbol-keyed hash, used when formatting the module.
# Tensor-valued variables are left out so they never need to be inspected,
# which keeps formatting cheap.
def dict
  instance_variables.each_with_object({}) do |ivar, attrs|
    value = instance_variable_get(ivar)
    attrs[ivar.to_s.delete_prefix("@").to_sym] = value unless value.is_a?(Tensor)
  end
end
|
292
295
|
end
|
293
296
|
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,6 +26,14 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
# Enumerable hook: yields self[0], self[1], ... along the first dimension.
# Without a block it returns an Enumerator; with one it returns size(0).
def each(&block)
  return to_enum(__method__) unless block

  size(0).times { |i| block.call(self[i]) }
end
|
36
|
+
|
28
37
|
# TODO make more performant
|
29
38
|
def to_a
|
30
39
|
arr = _flat_data
|
@@ -153,6 +162,18 @@ module Torch
|
|
153
162
|
neg
|
154
163
|
end
|
155
164
|
|
165
|
+
# `a & b` delegates to logical_and.
def &(other)
  __send__(:logical_and, other)
end
|
168
|
+
|
169
|
+
# `a | b` delegates to logical_or.
def |(other)
  __send__(:logical_or, other)
end
|
172
|
+
|
173
|
+
# `a ^ b` delegates to logical_xor.
def ^(other)
  __send__(:logical_xor, other)
end
|
176
|
+
|
156
177
|
# TODO better compare?
|
157
178
|
def <=>(other)
|
158
179
|
item <=> other
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.6
|
4
|
+
version: 0.2.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-06-
|
11
|
+
date: 2020-06-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|