tensor_stream 0.6.1 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +10 -0
- data/CHANGELOG.md +8 -0
- data/README.md +40 -1
- data/benchmark/benchmark.rb +4 -1
- data/lib/tensor_stream.rb +5 -0
- data/lib/tensor_stream/debugging/debugging.rb +4 -2
- data/lib/tensor_stream/device.rb +2 -1
- data/lib/tensor_stream/evaluator/base_evaluator.rb +43 -32
- data/lib/tensor_stream/evaluator/evaluator.rb +0 -1
- data/lib/tensor_stream/evaluator/opencl/kernels/acos.cl +8 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_gradient.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/asin.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/floor_mod.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/log_softmax.cl +26 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/max.cl +5 -5
- data/lib/tensor_stream/evaluator/opencl/kernels/min.cl +46 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/real_div.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +27 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross_grad.cl +28 -0
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +5 -6
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +200 -265
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -8
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +193 -122
- data/lib/tensor_stream/exceptions.rb +6 -0
- data/lib/tensor_stream/graph.rb +21 -6
- data/lib/tensor_stream/graph_builder.rb +67 -0
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +271 -0
- data/lib/tensor_stream/graph_keys.rb +1 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +11 -10
- data/lib/tensor_stream/helpers/op_helper.rb +7 -33
- data/lib/tensor_stream/helpers/string_helper.rb +16 -0
- data/lib/tensor_stream/math_gradients.rb +67 -44
- data/lib/tensor_stream/nn/nn_ops.rb +7 -1
- data/lib/tensor_stream/operation.rb +14 -27
- data/lib/tensor_stream/ops.rb +82 -29
- data/lib/tensor_stream/session.rb +4 -0
- data/lib/tensor_stream/tensor.rb +30 -12
- data/lib/tensor_stream/tensor_shape.rb +1 -1
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +37 -4
- data/lib/tensor_stream/train/saver.rb +46 -0
- data/lib/tensor_stream/train/utils.rb +37 -0
- data/lib/tensor_stream/trainer.rb +2 -0
- data/lib/tensor_stream/utils.rb +24 -14
- data/lib/tensor_stream/variable.rb +5 -11
- data/lib/tensor_stream/variable_scope.rb +15 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +8 -4
- data/samples/linear_regression.rb +1 -1
- data/samples/multigpu.rb +73 -0
- data/samples/nearest_neighbor.rb +3 -3
- data/tensor_stream.gemspec +1 -1
- data/test_samples/raw_neural_net_sample.rb +4 -1
- metadata +21 -6
data/lib/tensor_stream/graph.rb
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# A class that defines a TensorStream graph
|
3
3
|
class Graph
|
4
|
-
attr_accessor :nodes, :collections, :eager_execution, :random_seed, :constants
|
4
|
+
attr_accessor :nodes, :node_keys, :collections, :eager_execution, :random_seed, :constants
|
5
5
|
|
6
6
|
def initialize
|
7
7
|
@eager_execution = false
|
8
8
|
@nodes = {}
|
9
|
+
@node_keys = []
|
9
10
|
@collections = {
|
10
|
-
:"#{GraphKeys::GLOBAL_VARIABLES}" => []
|
11
|
+
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
12
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
11
13
|
}
|
12
14
|
@constants = {}
|
13
15
|
end
|
@@ -19,8 +21,10 @@ module TensorStream
|
|
19
21
|
@op_counter = 0
|
20
22
|
@random_seed = nil
|
21
23
|
@nodes = {}
|
24
|
+
@node_keys = []
|
22
25
|
@collections = {
|
23
|
-
:"#{GraphKeys::GLOBAL_VARIABLES}" => []
|
26
|
+
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
27
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
24
28
|
}
|
25
29
|
@constants = {}
|
26
30
|
end
|
@@ -83,6 +87,7 @@ module TensorStream
|
|
83
87
|
end
|
84
88
|
|
85
89
|
node.device = get_device_scope
|
90
|
+
@node_keys << node.name
|
86
91
|
@nodes[node.name] = node
|
87
92
|
@constants[node.name] = node if node.is_const
|
88
93
|
node.send(:propagate_outputs)
|
@@ -98,6 +103,11 @@ module TensorStream
|
|
98
103
|
@nodes[name]
|
99
104
|
end
|
100
105
|
|
106
|
+
# Looks up a registered node by its name.
#
# @param name [String] name the node was registered under in @nodes
# @return the node returned by get_node for that name
# @raise [TensorStream::KeyError] when no node is registered under +name+
def get_tensor_by_name(name)
  return get_node(name) if @nodes.key?(name)

  raise TensorStream::KeyError, "#{name} not found"
end
|
110
|
+
|
101
111
|
def add_node!(name, node)
|
102
112
|
@nodes[name] = node
|
103
113
|
node
|
@@ -120,12 +130,12 @@ module TensorStream
|
|
120
130
|
add_node(node)
|
121
131
|
end
|
122
132
|
|
123
|
-
def control_dependencies(control_inputs = []
|
133
|
+
def control_dependencies(control_inputs = [])
|
124
134
|
Thread.current["ts_graph_#{object_id}"] ||= {}
|
125
135
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies] ||= []
|
126
136
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Operation.new(:no_op, *control_inputs)
|
127
137
|
begin
|
128
|
-
|
138
|
+
yield
|
129
139
|
ensure
|
130
140
|
Thread.current["ts_graph_#{object_id}"][:control_dependencies].pop
|
131
141
|
end
|
@@ -201,6 +211,11 @@ module TensorStream
|
|
201
211
|
TensorStream::Pbtext.new.get_string(self)
|
202
212
|
end
|
203
213
|
|
214
|
+
# Builds a graph from a serialized graph definition.
#
# A fresh Graph is handed to a TensorStream::GraphBuilder, and whatever
# the builder's #build returns for +buffer+ is returned to the caller.
#
# @param buffer [String] serialized graph text passed straight to the builder
def self.parse_from_string(buffer)
  TensorStream::GraphBuilder.new(Graph.new).build(buffer)
end
|
218
|
+
|
204
219
|
# Returns the fixed "versions" entry text for a graph definition.
#
# @return [String] always the literal producer line
def graph_def_versions
  'producer: 26'
end
|
@@ -208,7 +223,7 @@ module TensorStream
|
|
208
223
|
protected
|
209
224
|
|
210
225
|
def _variable_scope
|
211
|
-
return
|
226
|
+
return VariableScope.new(name: '', reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
|
212
227
|
scope = Thread.current[:tensor_stream_variable_scope].last
|
213
228
|
scope
|
214
229
|
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module TensorStream
  # Rebuilds a graph from its protobuf text serialization.
  #
  # The text is parsed with TensorStream::Protobuf and every entry of type
  # "node" is turned into a Tensor, Variable, Placeholder or Operation that
  # is attached to the graph passed to the constructor (through the
  # +:__graph+ option).
  class GraphBuilder
    include TensorStream::OpHelper
    include TensorStream::StringHelper

    attr_accessor :graph

    # @param graph the graph that rebuilt nodes are attached to
    def initialize(graph)
      @graph = graph
    end

    # Parses +buffer+ and recreates each node it describes.
    #
    # Nodes are processed in file order, so an input must be declared before
    # the operation that consumes it.
    #
    # @param buffer [String] protobuf text describing the graph
    # @return the graph given to the constructor
    # @raise [TensorStream::KeyError] (from Graph#get_tensor_by_name) when an
    #   input refers to a node that has not been built yet
    def build(buffer)
      protobuf = TensorStream::Protobuf.new

      protobuf.load_from_string(buffer).each do |node|
        next unless node['type'] == 'node'

        options = protobuf.options_evaluator(node)
        options[:name] = node['name']
        options[:__graph] = @graph
        # 'value' is pulled out before the keys are symbolized; it is only
        # put back (as :value) for constants.
        value = options.delete('value')
        options = symbolize_keys(options)

        case node['op']
        when 'Const'
          dimension = shape_eval(value)
          rank = dimension.size
          options[:value] = value
          options[:const] = true
          TensorStream::Tensor.new(options[:dtype] || options[:T], rank, dimension, options)
        when 'VariableV2'
          shape = options[:shape]
          TensorStream::Variable.new(options[:dtype] || options[:T], nil, shape, nil, options)
        when 'Placeholder'
          shape = options[:shape]
          TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
        else
          op = underscore(node['op']).to_sym
          unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
            puts "warning unsupported op #{op}"
          end

          # Operations without any "input:" line have no 'input' entry at all.
          inputs = Array(node['input']).map do |input|
            # A leading '^' marks a control input; drop the marker without
            # mutating the string held by the parsed tree.
            input_name = input.start_with?('^') ? input[1..-1] : input
            base_name, index = input_name.split(':')

            # Always look the node up by its bare name: "foo:0" refers to
            # node "foo", not to a node literally named "foo:0".
            tensor = @graph.get_tensor_by_name(base_name)
            tensor = tensor[index.to_i] if index && index.to_i > 0

            raise "tensor not found by name #{input_name}" if tensor.nil?

            tensor
          end

          options[:data_type] = options.delete(:T)
          TensorStream::Operation.new(op, *inputs, options)
        end
      end

      @graph
    end
  end
end
|
@@ -0,0 +1,271 @@
|
|
1
|
+
require 'yaml'
|
2
|
+
|
3
|
+
module TensorStream
  # A .pb graph deserializer.
  #
  # Reads the protobuf *text* representation of a graph with a small
  # line-oriented state machine and returns it as an array of plain Ruby
  # hashes (one per top-level block).
  class Protobuf
    # Escape letters understood by #unescape, mapped to the character each
    # one stands for.
    UNESCAPES = {
      'a' => "\x07", 'b' => "\x08", 't' => "\x09",
      'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
      'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
      "\"" => "\x22", "'" => "\x27"
    }.freeze

    # Matches \<letter>, \uXXXX and \xHH / \0xHH escape sequences.
    # Built once instead of on every #unescape call.
    UNESCAPE_PATTERN = /\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/.freeze

    # Parses graph text held in memory.
    #
    # @param buffer [String] protobuf text
    # @return [Array<Hash>] one hash per top-level block
    def load_from_string(buffer)
      evaluate_lines(buffer.split("\n").map(&:strip))
    end

    ##
    # Parses a protobuf text file and returns its content as an array of
    # Ruby hashes.
    #
    # @param pbfile [String] path of the file to read
    # @return [Array<Hash>] one hash per top-level block
    def load(pbfile)
      # File.foreach closes the file once iteration finishes, so no handle
      # is left open.
      evaluate_lines(File.foreach(pbfile).map(&:strip))
    end

    # Evaluates the 'tensor' entry of a parsed value hash, if there is one.
    #
    # @param value_node [Hash] parsed attribute value
    # @return the evaluated tensor, or nil when there is no 'tensor' entry
    def parse_value(value_node)
      evaluate_tensor_node(value_node['tensor']) if value_node['tensor']
    end

    # Converts a parsed tensor hash into Ruby values.
    #
    # @param node [Hash] parsed tensor ('dtype', 'shape' plus either
    #   'tensor_content' or 'float_val' / 'int_val' / 'string_val')
    # @return the decoded value(s); an empty array when a numeric tensor
    #   carries no value entry
    # @raise [RuntimeError] for a dtype other than DT_FLOAT, DT_INT32, DT_STRING
    def evaluate_tensor_node(node)
      shape = node['shape']
      dtype = node['dtype']

      if shape && !shape.empty? && node['tensor_content']
        content = node['tensor_content']
        # SECURITY(review): `eval` executes whatever the graph file puts in
        # tensor_content (it is used here to expand octal escapes). Only load
        # graph files from trusted sources; replacing this with a dedicated
        # escape decoder is recommended.
        unpacked = eval(%Q{"#{content}"})

        case dtype
        when 'DT_FLOAT'
          TensorShape.reshape(unpacked.unpack('f*'), shape)
        when 'DT_INT32'
          TensorShape.reshape(unpacked.unpack('l*'), shape)
        when 'DT_STRING'
          node['string_val']
        else
          raise "unknown dtype #{dtype}"
        end
      else
        val = case dtype
              when 'DT_FLOAT'
                coerce_scalars(node['float_val'], :to_f)
              when 'DT_INT32'
                coerce_scalars(node['int_val'], :to_i)
              when 'DT_STRING'
                node['string_val']
              else
                raise "unknown dtype #{dtype}"
              end

        shape == [1] ? [val] : val
      end
    end

    # Maps a protobuf data type name to the matching data type symbol.
    #
    # @param attr_value [String] e.g. 'DT_FLOAT'
    # @return [Symbol]
    # @raise [RuntimeError] when the type name is not one of the known ones
    def map_type_to_ts(attr_value)
      case attr_value
      when 'DT_FLOAT'
        :float32
      when 'DT_INT32'
        :int32
      when 'DT_INT64'
        :int64
      when 'DT_STRING'
        :string
      when 'DT_BOOL'
        :boolean
      else
        raise "unknown type #{attr_value}"
      end
    end

    # Turns the parsed attributes of a node into a key => value hash.
    #
    # 'tensor' values are decoded, 'type' values become symbols and 'b'
    # values become booleans; anything else is passed through unchanged.
    #
    # @param node [Hash] parsed node
    # @return [Hash{String => Object}] empty when the node has no attributes
    def options_evaluator(node)
      return {} if node['attributes'].nil?

      node['attributes'].map do |attribute|
        # Only the first entry of the value decides the attribute's kind.
        attr_type, attr_value = attribute['value'].first

        if attr_type == 'tensor'
          attr_value = evaluate_tensor_node(attr_value)
        elsif attr_type == 'type'
          attr_value = map_type_to_ts(attr_value)
        elsif attr_type == 'b'
          attr_value = attr_value == 'true'
        end

        [attribute['key'], attr_value]
      end.to_h
    end

    protected

    # Runs the line-oriented state machine over already stripped lines.
    #
    # States follow the nesting of the text format:
    # top -> node -> attr -> value -> (list | shape -> dim | tensor ->
    # tensor_shape -> dim). A closing '}' always returns to the parent state.
    #
    # @param lines [Array<String>] stripped lines of the protobuf text
    # @return [Array<Hash>] one hash per completed top-level block
    def evaluate_lines(lines = [])
      block = []
      node = {}
      node_attr = {}
      state = :top

      lines.each do |str|
        # Blank lines carry no information; without this guard they would be
        # mistaken for a block header or fail on the key/value split.
        next if str.strip.empty?

        case state
        when :top
          node['type'] = parse_node_name(str)
          state = :node_context
        when :node_context
          if str == 'attr {'
            state = :attr_context
            node_attr = {}
            node['attributes'] ||= []
            node['attributes'] << node_attr
          elsif str == '}'
            state = :top
            block << node
            node = {}
          else
            key, value = str.split(':', 2)
            if key == 'input'
              # 'input' may repeat, so it is always collected into an array
              node['input'] ||= []
              node['input'] << process_value(value.strip)
            else
              node[key] = process_value(value.strip)
            end
          end
        when :attr_context
          if str == 'value {'
            state = :value_context
            node_attr['value'] = {}
          elsif str == '}'
            state = :node_context
          else
            key, value = str.split(':', 2)
            node_attr[key] = process_value(value.strip)
          end
        when :value_context
          if str == 'list {'
            state = :list_context
            node_attr['value'] = []
          elsif str == 'shape {'
            state = :shape_context
            node_attr['value']['shape'] = []
          elsif str == 'tensor {'
            state = :tensor_context
            node_attr['value']['tensor'] = {}
          elsif str == '}'
            state = :attr_context
          else
            key, value = str.split(':', 2)
            if key == 'dtype'
              node_attr['value']['dtype'] = value.strip
            elsif key == 'type'
              node_attr['value']['type'] = value.strip
            else
              node_attr['value'][key] = process_value(value.strip)
            end
          end
        when :list_context
          if str == '}'
            state = :value_context
          else
            key, value = str.split(':', 2)
            # list entries are stored raw (not stripped or unescaped)
            node_attr['value'] << { key => value }
          end
        when :tensor_context
          if str == 'tensor_shape {'
            state = :tensor_shape_context
            node_attr['value']['tensor']['shape'] = []
          elsif str == '}'
            state = :value_context
          else
            key, value = str.split(':', 2)
            tensor = node_attr['value']['tensor']
            existing = tensor[key]
            parsed = process_value(value.strip)

            # A key seen more than once is promoted to an array of values.
            if existing.is_a?(Array)
              existing << parsed
            elsif existing
              tensor[key] = [existing, parsed]
            else
              tensor[key] = parsed
            end
          end
        when :tensor_shape_context
          if str == 'dim {'
            state = :tensor_shape_dim_context
          elsif str == '}'
            state = :tensor_context
          end
        when :shape_context
          if str == '}'
            state = :value_context
          elsif str == 'dim {'
            state = :shape_dim_context
          end
        when :shape_dim_context
          if str == '}'
            state = :shape_context
          else
            _key, value = str.split(':', 2)
            node_attr['value']['shape'] << value.strip.to_i
          end
        when :tensor_shape_dim_context
          if str == '}'
            state = :tensor_shape_context
          else
            _key, value = str.split(':', 2)
            node_attr['value']['tensor']['shape'] << value.strip.to_i
          end
        end
      end

      block
    end

    # Extracts the block name from a header line such as "node {".
    #
    # @param str [String] header line
    # @return [String, nil] first whitespace-separated token
    def parse_node_name(str)
      str.split(' ').first
    end

    # Removes surrounding double quotes (when the value starts with one) and
    # expands escape sequences. The argument itself is left untouched.
    #
    # @param value [String] raw value text
    # @return [String]
    def process_value(value)
      value = value.gsub(/\A"|"\Z/, '') if value.start_with?('"')
      unescape(value)
    end

    # Expands \<letter>, \uXXXX and \xHH / \0xHH escape sequences.
    # Octal escapes are not touched here.
    #
    # @param str [String]
    # @return [String] a new string with the escapes expanded
    def unescape(str)
      str.gsub(UNESCAPE_PATTERN) do
        match = Regexp.last_match
        if match[1]
          match[1] == '\\' ? '\\' : UNESCAPES[match[1]]
        elsif match[2] # \u0000 unicode escape
          [match[2].hex].pack('U*')
        elsif match[3] # \0xff or \xff escape
          [match[3]].pack('H2')
        end
      end
    end

    # Converts a scalar entry, or the array built from a repeated entry,
    # with the given conversion method.
    #
    # @param raw [String, Array<String>, nil] parsed value(s)
    # @param converter [Symbol] :to_f or :to_i
    # @return converted value(s); an empty array when +raw+ is missing
    def coerce_scalars(raw, converter)
      return [] unless raw

      raw.is_a?(Array) ? raw.map(&converter) : raw.public_send(converter)
    end
  end
end
|
@@ -6,7 +6,8 @@ module TensorStream
|
|
6
6
|
def get_string(tensor_or_graph, session = nil)
|
7
7
|
graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
|
8
8
|
@lines = []
|
9
|
-
graph.
|
9
|
+
graph.node_keys.each do |k|
|
10
|
+
node = graph.get_tensor_by_name(k)
|
10
11
|
@lines << "node {"
|
11
12
|
@lines << " name: #{node.name.to_json}"
|
12
13
|
if node.is_a?(TensorStream::Operation)
|
@@ -16,16 +17,16 @@ module TensorStream
|
|
16
17
|
@lines << " input: #{input.name.to_json}"
|
17
18
|
end
|
18
19
|
# type
|
19
|
-
pb_attr('T', "
|
20
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
20
21
|
process_options(node)
|
21
22
|
elsif node.is_a?(TensorStream::Tensor) && node.is_const
|
22
23
|
@lines << " op: \"Const\""
|
23
24
|
# type
|
24
|
-
pb_attr('T', "
|
25
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
25
26
|
pb_attr('value', tensor_value(node))
|
26
27
|
elsif node.is_a?(TensorStream::Variable)
|
27
28
|
@lines << " op: \"VariableV2\""
|
28
|
-
pb_attr('T', "
|
29
|
+
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
29
30
|
pb_attr('shape', shape_buf(node, 'shape'))
|
30
31
|
process_options(node)
|
31
32
|
end
|
@@ -42,16 +43,16 @@ module TensorStream
|
|
42
43
|
def process_options(node)
|
43
44
|
return if node.options.nil?
|
44
45
|
node.options.each do |k, v|
|
45
|
-
next if %w[name].include?(k.to_s)
|
46
|
+
next if %w[name].include?(k.to_s) || k.to_s.start_with?('__')
|
46
47
|
@lines << " attr {"
|
47
48
|
@lines << " key: \"#{k}\""
|
48
49
|
@lines << " value {"
|
49
50
|
if (v.is_a?(TrueClass) || v.is_a?(FalseClass))
|
50
|
-
@lines << "
|
51
|
+
@lines << " b: #{v.to_s}"
|
51
52
|
elsif (v.is_a?(Integer))
|
52
|
-
@lines << "
|
53
|
+
@lines << " int_val: #{v}"
|
53
54
|
elsif (v.is_a?(Float))
|
54
|
-
@lines << "
|
55
|
+
@lines << " float_val: #{v}"
|
55
56
|
end
|
56
57
|
@lines << " }"
|
57
58
|
@lines << " }"
|
@@ -65,7 +66,7 @@ module TensorStream
|
|
65
66
|
# Packs integers as 32-bit values and renders the bytes as text that is safe
# inside a double-quoted protobuf text string.
#
# Printable ASCII bytes are written as-is. Every other byte, plus the
# double quote, single quote and backslash, is written as a three-digit
# octal escape; emitting those three raw would end the quoted string early
# or start an unintended escape sequence when the text is parsed again.
#
# @param int_arr [Array] integers, possibly nested
# @return [String] escaped byte string
def pack_arr_int(int_arr)
  int_arr.flatten.pack('l*').each_byte.map do |byte|
    printable = byte.between?(0x20, 0x7e)
    needs_escape = [0x22, 0x27, 0x5c].include?(byte) # " ' \

    if printable && !needs_escape
      byte.chr
    else
      format('\\%03o', byte)
    end
  end.join
end
|
68
|
-
|
69
|
+
|
69
70
|
def shape_buf(tensor, shape_type = 'tensor_shape')
|
70
71
|
arr = []
|
71
72
|
arr << " #{shape_type} {"
|
@@ -77,6 +78,7 @@ module TensorStream
|
|
77
78
|
arr << " }"
|
78
79
|
arr
|
79
80
|
end
|
81
|
+
|
80
82
|
def tensor_value(tensor)
|
81
83
|
arr = []
|
82
84
|
arr << "tensor {"
|
@@ -146,5 +148,4 @@ module TensorStream
|
|
146
148
|
@lines << " }"
|
147
149
|
end
|
148
150
|
end
|
149
|
-
|
150
151
|
end
|