tensor_stream 0.6.1 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.rubocop.yml +10 -0
- data/CHANGELOG.md +8 -0
- data/README.md +40 -1
- data/benchmark/benchmark.rb +4 -1
- data/lib/tensor_stream.rb +5 -0
- data/lib/tensor_stream/debugging/debugging.rb +4 -2
- data/lib/tensor_stream/device.rb +2 -1
- data/lib/tensor_stream/evaluator/base_evaluator.rb +43 -32
- data/lib/tensor_stream/evaluator/evaluator.rb +0 -1
- data/lib/tensor_stream/evaluator/opencl/kernels/acos.cl +8 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_gradient.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/asin.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/floor_mod.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/log_softmax.cl +26 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/max.cl +5 -5
- data/lib/tensor_stream/evaluator/opencl/kernels/min.cl +46 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/real_div.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +27 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross_grad.cl +28 -0
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +5 -6
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +200 -265
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -8
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +193 -122
- data/lib/tensor_stream/exceptions.rb +6 -0
- data/lib/tensor_stream/graph.rb +21 -6
- data/lib/tensor_stream/graph_builder.rb +67 -0
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +271 -0
- data/lib/tensor_stream/graph_keys.rb +1 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +11 -10
- data/lib/tensor_stream/helpers/op_helper.rb +7 -33
- data/lib/tensor_stream/helpers/string_helper.rb +16 -0
- data/lib/tensor_stream/math_gradients.rb +67 -44
- data/lib/tensor_stream/nn/nn_ops.rb +7 -1
- data/lib/tensor_stream/operation.rb +14 -27
- data/lib/tensor_stream/ops.rb +82 -29
- data/lib/tensor_stream/session.rb +4 -0
- data/lib/tensor_stream/tensor.rb +30 -12
- data/lib/tensor_stream/tensor_shape.rb +1 -1
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +37 -4
- data/lib/tensor_stream/train/saver.rb +46 -0
- data/lib/tensor_stream/train/utils.rb +37 -0
- data/lib/tensor_stream/trainer.rb +2 -0
- data/lib/tensor_stream/utils.rb +24 -14
- data/lib/tensor_stream/variable.rb +5 -11
- data/lib/tensor_stream/variable_scope.rb +15 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +8 -4
- data/samples/linear_regression.rb +1 -1
- data/samples/multigpu.rb +73 -0
- data/samples/nearest_neighbor.rb +3 -3
- data/tensor_stream.gemspec +1 -1
- data/test_samples/raw_neural_net_sample.rb +4 -1
- metadata +21 -6
data/lib/tensor_stream/graph.rb
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# A class that defines a TensorStream graph
|
3
3
|
class Graph
|
4
|
-
attr_accessor :nodes, :collections, :eager_execution, :random_seed, :constants
|
4
|
+
attr_accessor :nodes, :node_keys, :collections, :eager_execution, :random_seed, :constants
|
5
5
|
|
6
6
|
def initialize
|
7
7
|
@eager_execution = false
|
8
8
|
@nodes = {}
|
9
|
+
@node_keys = []
|
9
10
|
@collections = {
|
10
|
-
:"#{GraphKeys::GLOBAL_VARIABLES}" => []
|
11
|
+
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
12
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
11
13
|
}
|
12
14
|
@constants = {}
|
13
15
|
end
|
@@ -19,8 +21,10 @@ module TensorStream
|
|
19
21
|
@op_counter = 0
|
20
22
|
@random_seed = nil
|
21
23
|
@nodes = {}
|
24
|
+
@node_keys = []
|
22
25
|
@collections = {
|
23
|
-
:"#{GraphKeys::GLOBAL_VARIABLES}" => []
|
26
|
+
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
27
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
24
28
|
}
|
25
29
|
@constants = {}
|
26
30
|
end
|
@@ -83,6 +87,7 @@ module TensorStream
|
|
83
87
|
end
|
84
88
|
|
85
89
|
node.device = get_device_scope
|
90
|
+
@node_keys << node.name
|
86
91
|
@nodes[node.name] = node
|
87
92
|
@constants[node.name] = node if node.is_const
|
88
93
|
node.send(:propagate_outputs)
|
@@ -98,6 +103,11 @@ module TensorStream
|
|
98
103
|
@nodes[name]
|
99
104
|
end
|
100
105
|
|
106
|
+
def get_tensor_by_name(name)
|
107
|
+
raise TensorStream::KeyError, "#{name} not found" unless @nodes.key?(name)
|
108
|
+
get_node(name)
|
109
|
+
end
|
110
|
+
|
101
111
|
def add_node!(name, node)
|
102
112
|
@nodes[name] = node
|
103
113
|
node
|
@@ -120,12 +130,12 @@ module TensorStream
|
|
120
130
|
add_node(node)
|
121
131
|
end
|
122
132
|
|
123
|
-
def control_dependencies(control_inputs = []
|
133
|
+
def control_dependencies(control_inputs = [])
|
124
134
|
Thread.current["ts_graph_#{object_id}"] ||= {}
|
125
135
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies] ||= []
|
126
136
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Operation.new(:no_op, *control_inputs)
|
127
137
|
begin
|
128
|
-
|
138
|
+
yield
|
129
139
|
ensure
|
130
140
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies].pop
|
131
141
|
end
|
@@ -201,6 +211,11 @@ module TensorStream
|
|
201
211
|
TensorStream::Pbtext.new.get_string(self)
|
202
212
|
end
|
203
213
|
|
214
|
+
def self.parse_from_string(buffer)
|
215
|
+
builder = TensorStream::GraphBuilder.new(Graph.new)
|
216
|
+
builder.build(buffer)
|
217
|
+
end
|
218
|
+
|
204
219
|
def graph_def_versions
|
205
220
|
"producer: 26"
|
206
221
|
end
|
@@ -208,7 +223,7 @@ module TensorStream
|
|
208
223
|
protected
|
209
224
|
|
210
225
|
def _variable_scope
|
211
|
-
return
|
226
|
+
return VariableScope.new(name: '', reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
|
212
227
|
scope = Thread.current[:tensor_stream_variable_scope].last
|
213
228
|
scope
|
214
229
|
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module TensorStream
  # Rebuilds TensorStream nodes from a text-format graph buffer that has been
  # parsed by TensorStream::Protobuf.
  class GraphBuilder
    include TensorStream::OpHelper
    include TensorStream::StringHelper

    attr_accessor :graph

    # @param graph [TensorStream::Graph] graph handed to every created node
    #   through the :__graph option
    def initialize(graph)
      @graph = graph
    end

    # Walks the parsed blocks in file order and instantiates one TensorStream
    # object per block whose type is "node":
    # - "Const" becomes a constant Tensor.
    # - "VariableV2" becomes a Variable.
    # - "Placeholder" becomes a Placeholder.
    # - Anything else becomes an Operation whose op symbol is the underscored
    #   op name.
    #
    # NOTE(review): the created objects are not collected here; presumably their
    # constructors register themselves with options[:__graph]. Confirm in
    # Tensor/Operation.
    #
    # @param buffer [String] text-format graph definition
    # @return [TensorStream::Graph] the graph passed to #initialize
    # @raise [TensorStream::KeyError] (from Graph#get_tensor_by_name) when an
    #   input names a node that has not been built yet
    # @raise [RuntimeError] when an indexed input resolves to nil
    def build(buffer)
      protobuf = TensorStream::Protobuf.new
      parsed_tree = protobuf.load_from_string(buffer)
      parsed_tree.each do |node|
        next unless node['type'] == 'node'

        options = protobuf.options_evaluator(node)
        options[:name] = node['name']
        options[:__graph] = @graph
        value = options.delete('value')
        options = symbolize_keys(options)

        case node['op']
        when 'Const'
          dimension = shape_eval(value)
          rank = dimension.size
          options[:value] = value
          options[:const] = true
          TensorStream::Tensor.new(options[:dtype] || options[:T], rank, dimension, options)
        when 'VariableV2'
          shape = options[:shape]
          TensorStream::Variable.new(options[:dtype] || options[:T], nil, shape, nil, options)
        when 'Placeholder'
          shape = options[:shape]
          TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
        else
          op = underscore(node['op']).to_sym
          puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)

          # ops without any input lines (node['input'] absent) get no inputs
          inputs = Array(node['input']).map { |input| lookup_input_tensor(input) }

          options[:data_type] = options.delete(:T)
          TensorStream::Operation.new(op, *inputs, options)
        end
      end

      @graph
    end

    private

    # Resolves one entry of a node's input list to an already-built tensor.
    #
    # A leading "^" is dropped, and the input is otherwise treated like a data
    # input. NOTE(review): in GraphDef "^" marks a control dependency; confirm
    # whether it should be wired as such instead.
    # "name:N" with N > 0 selects output N of node "name". "name" and
    # "name:0" both resolve to the node itself.
    #
    # The parsed tree is not modified.
    def lookup_input_tensor(input)
      name = input.start_with?('^') ? input[1..-1] : input
      base_name, index = name.split(':')

      tensor = if index && index.to_i > 0
                 @graph.get_tensor_by_name(base_name)[index.to_i]
               else
                 # base_name (not the raw string) so an explicit ":0" suffix resolves too
                 @graph.get_tensor_by_name(base_name)
               end

      raise "tensor not found by name #{name}" if tensor.nil?

      tensor
    end
  end
end
|
@@ -0,0 +1,271 @@
|
|
1
|
+
require 'yaml'
|
2
|
+
|
3
|
+
module TensorStream
  # A .pb graph deserializer
  #
  # Parses the text serialization of a graph (nested `name {` ... `}` blocks
  # with `key: value` lines) into an Array of Hashes, one per top-level block.
  # Each Hash contains:
  # - 'type': the first word of the block's opening line, e.g. "node".
  # - scalar fields as Strings.
  # - 'input': an Array of Strings.
  # - 'attributes': an Array of { 'key' => ..., 'value' => ... } Hashes.
  class Protobuf
    # Single-character escapes understood by #unescape. The "\\\\" key is two
    # characters long on purpose: it is spliced into a regexp character class,
    # where it has to appear as an escaped backslash.
    UNESCAPES = {
      'a' => "\x07", 'b' => "\x08", 't' => "\x09",
      'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
      'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
      "\"" => "\x22", "'" => "\x27"
    }.freeze

    # Matches \<char>, \uXXXX and \xHH / \0xHH; built once instead of per call.
    UNESCAPE_PATTERN = /\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/.freeze

    # Protobuf DataType enum names mapped to TensorStream dtype symbols.
    TS_TYPES = {
      'DT_FLOAT' => :float32,
      'DT_INT32' => :int32,
      'DT_INT64' => :int64,
      'DT_STRING' => :string,
      'DT_BOOL' => :boolean
    }.freeze

    # Parses an in-memory text graph.
    #
    # @param buffer [String]
    # @return [Array<Hash>] parsed top-level blocks
    def load_from_string(buffer)
      evaluate_lines(buffer.split("\n").map(&:strip))
    end

    ##
    # Parses a text graph file and returns the same structure as
    # #load_from_string. File.readlines closes the file before returning.
    #
    # @param pbfile [String] path to the file
    # @return [Array<Hash>]
    def load(pbfile)
      evaluate_lines(File.readlines(pbfile).map(&:strip))
    end

    # @param value_node [Hash] an attribute's 'value' Hash
    # @return [Object, nil] the evaluated tensor, or nil when there is no 'tensor' entry
    def parse_value(value_node)
      evaluate_tensor_node(value_node['tensor']) if value_node['tensor']
    end

    # Converts a parsed `tensor { ... }` Hash into Ruby values.
    #
    # With a non-empty shape and a tensor_content field, the raw bytes are
    # unpacked as 32-bit floats (DT_FLOAT) or ints (DT_INT32) and reshaped.
    # Otherwise float_val / int_val are converted (a repeated field yields an
    # Array), or string_val is returned as-is.
    #
    # @param node [Hash]
    # @raise [RuntimeError] for an unsupported dtype
    def evaluate_tensor_node(node)
      shape = node['shape']
      if shape && !shape.empty? && node['tensor_content']
        content = node['tensor_content']
        # SECURITY(review): eval reads tensor_content as a Ruby double-quoted
        # literal so its octal/hex escapes become raw bytes. The text comes from
        # the parsed file, so a crafted file containing `#{...}` runs arbitrary
        # Ruby, and a stray `"` raises SyntaxError. The content was also already
        # passed through #unescape by #process_value, so escapes are decoded
        # twice. Replace with an explicit byte-escape decoder before loading
        # untrusted graphs.
        unpacked = eval(%Q{"#{content}"})

        case node['dtype']
        when 'DT_FLOAT' then TensorShape.reshape(unpacked.unpack('f*'), shape)
        when 'DT_INT32' then TensorShape.reshape(unpacked.unpack('l*'), shape)
        when 'DT_STRING' then node['string_val']
        else raise "unknown dtype #{node['dtype']}"
        end
      else
        val = case node['dtype']
              when 'DT_FLOAT' then cast_values(node['float_val'], :to_f)
              when 'DT_INT32' then cast_values(node['int_val'], :to_i)
              when 'DT_STRING' then node['string_val']
              else raise "unknown dtype #{node['dtype']}"
              end

        # values declared with shape [1] are wrapped in an Array
        shape == [1] ? [val] : val
      end
    end

    # @param attr_value [String] a DataType enum name such as "DT_FLOAT"
    # @return [Symbol] the TensorStream dtype
    # @raise [RuntimeError] for an unknown name
    def map_type_to_ts(attr_value)
      TS_TYPES.fetch(attr_value) { raise "unknown type #{attr_value}" }
    end

    # Builds a { attribute key => Ruby value } Hash from a parsed node's
    # attributes. Tensor values are evaluated, type values are mapped to dtype
    # symbols, and b values become true/false. Other values are left as parsed.
    #
    # @param node [Hash] a parsed block
    # @return [Hash{String => Object}] empty when the node has no attributes
    def options_evaluator(node)
      return {} if node['attributes'].nil?

      node['attributes'].map do |attribute|
        attr_type, attr_value = attribute['value'].collect { |k, v| [k, v] }.flatten(1)

        attr_value = case attr_type
                     when 'tensor' then evaluate_tensor_node(attr_value)
                     when 'type' then map_type_to_ts(attr_value)
                     when 'b' then attr_value == 'true'
                     else attr_value
                     end

        [attribute['key'], attr_value]
      end.to_h
    end

    protected

    # Line-oriented state machine over stripped lines. Each `xxx {` line enters
    # a nested context and each `}` line returns to the enclosing one. Blank
    # lines are ignored; previously they were taken as a block name or crashed
    # on `nil.strip`.
    def evaluate_lines(lines = [])
      block = []
      node = {}
      node_attr = {}
      state = :top

      lines.each do |str|
        next if str.empty?

        case state
        when :top
          node['type'] = parse_node_name(str)
          state = :node_context
        when :node_context
          if str == 'attr {'
            state = :attr_context
            node_attr = {}
            node['attributes'] ||= []
            node['attributes'] << node_attr
          elsif str == '}'
            state = :top
            block << node
            node = {}
          else
            key, value = str.split(':', 2)
            if key == 'input'
              node['input'] ||= []
              node['input'] << process_value(value.strip)
            else
              node[key] = process_value(value.strip)
            end
          end
        when :attr_context
          if str == 'value {'
            state = :value_context
            node_attr['value'] = {}
          elsif str == '}'
            state = :node_context
          else
            key, value = str.split(':', 2)
            node_attr[key] = process_value(value.strip)
          end
        when :value_context
          if str == 'list {'
            state = :list_context
            node_attr['value'] = []
          elsif str == 'shape {'
            state = :shape_context
            node_attr['value']['shape'] = []
          elsif str == 'tensor {'
            state = :tensor_context
            node_attr['value']['tensor'] = {}
          elsif str == '}'
            state = :attr_context
          else
            key, value = str.split(':', 2)
            if key == 'dtype' || key == 'type'
              # enum names are stored verbatim (no quote stripping / unescaping)
              node_attr['value'][key] = value.strip
            else
              node_attr['value'][key] = process_value(value.strip)
            end
          end
        when :list_context
          if str == '}'
            state = :value_context
          else
            key, value = str.split(':', 2)
            node_attr['value'] << { key => value }
          end
        when :tensor_context
          if str == 'tensor_shape {'
            state = :tensor_shape_context
            node_attr['value']['tensor']['shape'] = []
          elsif str == '}'
            state = :value_context
          else
            key, value = str.split(':', 2)
            tensor = node_attr['value']['tensor']
            existing = tensor[key]
            parsed = process_value(value.strip)
            if existing
              # a repeated field (e.g. several float_val lines) becomes an Array
              tensor[key] = [existing] unless existing.is_a?(Array)
              tensor[key] << parsed
            else
              tensor[key] = parsed
            end
          end
        when :tensor_shape_context
          if str == 'dim {'
            state = :tensor_shape_dim_context
          elsif str == '}'
            state = :tensor_context
          end
        when :shape_context
          if str == '}'
            state = :value_context
          elsif str == 'dim {'
            state = :shape_dim_context
          end
        when :shape_dim_context
          if str == '}'
            state = :shape_context
          else
            _key, value = str.split(':', 2)
            node_attr['value']['shape'] << value.strip.to_i
          end
        when :tensor_shape_dim_context
          if str == '}'
            state = :tensor_shape_context
          else
            _key, value = str.split(':', 2)
            node_attr['value']['tensor']['shape'] << value.strip.to_i
          end
        end
      end

      block
    end

    # @return [String, nil] the first whitespace-separated word of +str+
    def parse_node_name(str)
      str.split(' ').first
    end

    # Strips surrounding double quotes (when +value+ starts with one) and
    # resolves escape sequences. The argument is not mutated.
    def process_value(value)
      value = value.gsub(/\A"|"\Z/, '') if value.start_with?('"')
      unescape(value)
    end

    # Replaces \<char>, \uXXXX and \xHH / \0xHH escape sequences in +str+.
    def unescape(str)
      str.gsub(UNESCAPE_PATTERN) do
        simple, unicode, hex = Regexp.last_match.captures
        if simple
          # UNESCAPES stores the backslash under a two-character key, so it is
          # special-cased here
          simple == '\\' ? '\\' : UNESCAPES[simple]
        elsif unicode
          [unicode.hex].pack('U*')
        elsif hex
          [hex].pack('H2')
        end
      end
    end

    private

    # Converts a *_val field with +conversion+ (:to_f or :to_i). A missing field
    # yields []. A repeated field arrives as an Array and is converted
    # element-wise; calling .to_f on the Array used to raise NoMethodError.
    def cast_values(raw, conversion)
      return [] unless raw

      raw.is_a?(Array) ? raw.map(&conversion) : raw.public_send(conversion)
    end
  end
end
|
@@ -6,7 +6,8 @@ module TensorStream
|
|
6
6
|
def get_string(tensor_or_graph, session = nil)
|
7
7
|
graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
|
8
8
|
@lines = []
|
9
|
-
graph.
|
9
|
+
graph.node_keys.each do |k|
|
10
|
+
node = graph.get_tensor_by_name(k)
|
10
11
|
@lines << "node {"
|
11
12
|
@lines << " name: #{node.name.to_json}"
|
12
13
|
if node.is_a?(TensorStream::Operation)
|
@@ -16,16 +17,16 @@ module TensorStream
|
|
16
17
|
@lines << " input: #{input.name.to_json}"
|
17
18
|
end
|
18
19
|
# type
|
19
|
-
pb_attr('T', "
|
20
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
20
21
|
process_options(node)
|
21
22
|
elsif node.is_a?(TensorStream::Tensor) && node.is_const
|
22
23
|
@lines << " op: \"Const\""
|
23
24
|
# type
|
24
|
-
pb_attr('T', "
|
25
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
25
26
|
pb_attr('value', tensor_value(node))
|
26
27
|
elsif node.is_a?(TensorStream::Variable)
|
27
28
|
@lines << " op: \"VariableV2\""
|
28
|
-
pb_attr('T', "
|
29
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
29
30
|
pb_attr('shape', shape_buf(node, 'shape'))
|
30
31
|
process_options(node)
|
31
32
|
end
|
@@ -42,16 +43,16 @@ module TensorStream
|
|
42
43
|
def process_options(node)
|
43
44
|
return if node.options.nil?
|
44
45
|
node.options.each do |k, v|
|
45
|
-
next if %w[name].include?(k.to_s)
|
46
|
+
next if %w[name].include?(k.to_s) || k.to_s.start_with?('__')
|
46
47
|
@lines << " attr {"
|
47
48
|
@lines << " key: \"#{k}\""
|
48
49
|
@lines << " value {"
|
49
50
|
if (v.is_a?(TrueClass) || v.is_a?(FalseClass))
|
50
|
-
@lines << "
|
51
|
+
@lines << " b: #{v.to_s}"
|
51
52
|
elsif (v.is_a?(Integer))
|
52
|
-
@lines << "
|
53
|
+
@lines << " int_val: #{v}"
|
53
54
|
elsif (v.is_a?(Float))
|
54
|
-
@lines << "
|
55
|
+
@lines << " float_val: #{v}"
|
55
56
|
end
|
56
57
|
@lines << " }"
|
57
58
|
@lines << " }"
|
@@ -65,7 +66,7 @@ module TensorStream
|
|
65
66
|
def pack_arr_int(int_arr)
|
66
67
|
int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
|
67
68
|
end
|
68
|
-
|
69
|
+
|
69
70
|
def shape_buf(tensor, shape_type = 'tensor_shape')
|
70
71
|
arr = []
|
71
72
|
arr << " #{shape_type} {"
|
@@ -77,6 +78,7 @@ module TensorStream
|
|
77
78
|
arr << " }"
|
78
79
|
arr
|
79
80
|
end
|
81
|
+
|
80
82
|
def tensor_value(tensor)
|
81
83
|
arr = []
|
82
84
|
arr << "tensor {"
|
@@ -146,5 +148,4 @@ module TensorStream
|
|
146
148
|
@lines << " }"
|
147
149
|
end
|
148
150
|
end
|
149
|
-
|
150
151
|
end
|