tensor_stream 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.rubocop.yml +1 -0
- data/Gemfile +1 -1
- data/LICENSE.txt +1 -1
- data/README.md +34 -34
- data/Rakefile +3 -3
- data/USAGE_GUIDE.md +235 -0
- data/bin/stubgen +20 -0
- data/exe/model_utils +2 -2
- data/lib/tensor_stream.rb +45 -44
- data/lib/tensor_stream/constant.rb +2 -2
- data/lib/tensor_stream/control_flow.rb +1 -1
- data/lib/tensor_stream/debugging/debugging.rb +2 -2
- data/lib/tensor_stream/dynamic_stitch.rb +2 -2
- data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
- data/lib/tensor_stream/evaluator/buffer.rb +1 -1
- data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
- data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
- data/lib/tensor_stream/exceptions.rb +1 -1
- data/lib/tensor_stream/generated_stub/ops.rb +691 -0
- data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
- data/lib/tensor_stream/graph.rb +18 -18
- data/lib/tensor_stream/graph_builder.rb +17 -17
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
- data/lib/tensor_stream/graph_keys.rb +3 -3
- data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
- data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
- data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
- data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
- data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
- data/lib/tensor_stream/helpers/op_helper.rb +8 -9
- data/lib/tensor_stream/helpers/string_helper.rb +15 -15
- data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
- data/lib/tensor_stream/images.rb +1 -1
- data/lib/tensor_stream/initializer.rb +1 -1
- data/lib/tensor_stream/math_gradients.rb +28 -187
- data/lib/tensor_stream/monkey_patches/array.rb +1 -1
- data/lib/tensor_stream/monkey_patches/float.rb +1 -1
- data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
- data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
- data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
- data/lib/tensor_stream/nn/nn_ops.rb +17 -15
- data/lib/tensor_stream/op_maker.rb +180 -0
- data/lib/tensor_stream/operation.rb +17 -17
- data/lib/tensor_stream/ops.rb +95 -384
- data/lib/tensor_stream/ops/add.rb +23 -0
- data/lib/tensor_stream/ops/argmax.rb +14 -0
- data/lib/tensor_stream/ops/argmin.rb +14 -0
- data/lib/tensor_stream/ops/case.rb +17 -0
- data/lib/tensor_stream/ops/cast.rb +15 -0
- data/lib/tensor_stream/ops/ceil.rb +15 -0
- data/lib/tensor_stream/ops/const.rb +0 -0
- data/lib/tensor_stream/ops/cos.rb +10 -0
- data/lib/tensor_stream/ops/div.rb +21 -0
- data/lib/tensor_stream/ops/equal.rb +15 -0
- data/lib/tensor_stream/ops/expand_dims.rb +17 -0
- data/lib/tensor_stream/ops/fill.rb +19 -0
- data/lib/tensor_stream/ops/floor.rb +15 -0
- data/lib/tensor_stream/ops/floor_div.rb +15 -0
- data/lib/tensor_stream/ops/greater.rb +11 -0
- data/lib/tensor_stream/ops/greater_equal.rb +11 -0
- data/lib/tensor_stream/ops/less_equal.rb +15 -0
- data/lib/tensor_stream/ops/log.rb +14 -0
- data/lib/tensor_stream/ops/mat_mul.rb +60 -0
- data/lib/tensor_stream/ops/max.rb +15 -0
- data/lib/tensor_stream/ops/min.rb +15 -0
- data/lib/tensor_stream/ops/mod.rb +23 -0
- data/lib/tensor_stream/ops/mul.rb +21 -0
- data/lib/tensor_stream/ops/negate.rb +14 -0
- data/lib/tensor_stream/ops/ones_like.rb +19 -0
- data/lib/tensor_stream/ops/pow.rb +25 -0
- data/lib/tensor_stream/ops/prod.rb +60 -0
- data/lib/tensor_stream/ops/random_uniform.rb +18 -0
- data/lib/tensor_stream/ops/range.rb +20 -0
- data/lib/tensor_stream/ops/rank.rb +13 -0
- data/lib/tensor_stream/ops/reshape.rb +24 -0
- data/lib/tensor_stream/ops/round.rb +15 -0
- data/lib/tensor_stream/ops/shape.rb +14 -0
- data/lib/tensor_stream/ops/sigmoid.rb +10 -0
- data/lib/tensor_stream/ops/sign.rb +12 -0
- data/lib/tensor_stream/ops/sin.rb +10 -0
- data/lib/tensor_stream/ops/size.rb +16 -0
- data/lib/tensor_stream/ops/sub.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +27 -0
- data/lib/tensor_stream/ops/tan.rb +12 -0
- data/lib/tensor_stream/ops/tanh.rb +10 -0
- data/lib/tensor_stream/ops/tile.rb +19 -0
- data/lib/tensor_stream/ops/zeros.rb +15 -0
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/profile/report_tool.rb +3 -3
- data/lib/tensor_stream/session.rb +36 -38
- data/lib/tensor_stream/tensor.rb +2 -2
- data/lib/tensor_stream/tensor_shape.rb +4 -4
- data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
- data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
- data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
- data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
- data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
- data/lib/tensor_stream/train/optimizer.rb +9 -9
- data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
- data/lib/tensor_stream/train/saver.rb +14 -14
- data/lib/tensor_stream/train/slot_creator.rb +6 -6
- data/lib/tensor_stream/train/utils.rb +12 -12
- data/lib/tensor_stream/trainer.rb +10 -10
- data/lib/tensor_stream/types.rb +1 -1
- data/lib/tensor_stream/utils.rb +33 -32
- data/lib/tensor_stream/utils/freezer.rb +5 -5
- data/lib/tensor_stream/variable.rb +5 -5
- data/lib/tensor_stream/variable_scope.rb +1 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/{iris.data → datasets/iris.data} +0 -0
- data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
- data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
- data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
- data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
- data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
- data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
- data/samples/regression/linear_regression.rb +63 -0
- data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
- data/tensor_stream.gemspec +9 -8
- metadata +89 -19
- data/data_1.json +0 -4764
- data/data_2.json +0 -4764
- data/data_actual.json +0 -28
- data/data_expected.json +0 -28
- data/data_input.json +0 -28
- data/samples/error.graphml +0 -2755
- data/samples/gradient_sample.graphml +0 -1255
- data/samples/linear_regression.rb +0 -69
- data/samples/multigpu.rb +0 -73
- data/samples/raw_neural_net_sample.rb +0 -112
@@ -0,0 +1,24 @@
|
|
1
|
+
# This file has ben automatically generated by stubgen
|
2
|
+
# DO NOT EDIT
|
3
|
+
#
|
4
|
+
module TensorStream
|
5
|
+
module OpStub
|
6
|
+
<% TensorStream::OpMaker.each_op do |op|%>
|
7
|
+
##
|
8
|
+
<% op.description_lines.each do |line|%> # <%= line %>
|
9
|
+
<%end%> #
|
10
|
+
#<% if op.supports_broadcasting? %> This operation supports broadcasting
|
11
|
+
#<% end %>
|
12
|
+
# Params:
|
13
|
+
<% op.parameters.each do |param| %> # +<%= param[:name] %>+:: <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
|
14
|
+
<% end %> #
|
15
|
+
# Options:
|
16
|
+
<% op.options.each do |k, v| %> # +:<%= k %>+:: <%= v[:description]%><% if v[:default_value] != :nil %> default (<%= v[:default_value] %>)<%end%>
|
17
|
+
<%end%> def <%= op.operation.to_s %>(<%= (op.expand_params(true) + op.expand_options(true)).join(', ') %>)
|
18
|
+
<%= op.generate_body %>
|
19
|
+
end
|
20
|
+
<% op.aliases.each do |a|%>
|
21
|
+
alias_method :<%= a %>, :<%= op.operation %><%end%>
|
22
|
+
<% end %>
|
23
|
+
end
|
24
|
+
end
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -12,7 +12,7 @@ module TensorStream
|
|
12
12
|
@node_keys = []
|
13
13
|
@collections = {
|
14
14
|
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
15
|
-
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
15
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
|
16
16
|
}
|
17
17
|
@constants = {}
|
18
18
|
end
|
@@ -27,7 +27,7 @@ module TensorStream
|
|
27
27
|
@node_keys = []
|
28
28
|
@collections = {
|
29
29
|
:"#{GraphKeys::GLOBAL_VARIABLES}" => [],
|
30
|
-
:"#{GraphKeys::TRAINABLE_VARIABLES}" => []
|
30
|
+
:"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
|
31
31
|
}
|
32
32
|
@constants = {}
|
33
33
|
end
|
@@ -85,14 +85,14 @@ module TensorStream
|
|
85
85
|
end
|
86
86
|
|
87
87
|
def add_node(node, name = nil)
|
88
|
-
raise
|
88
|
+
raise "Placeholder cannot be used when eager_execution is enabled" if @eager_execution && node.is_a?(Placeholder)
|
89
89
|
|
90
90
|
if name.nil?
|
91
91
|
node.name = if @nodes[node.name]
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
92
|
+
uniqunify(node.name)
|
93
|
+
else
|
94
|
+
node.name
|
95
|
+
end
|
96
96
|
end
|
97
97
|
|
98
98
|
node.device = get_device_scope
|
@@ -129,10 +129,10 @@ module TensorStream
|
|
129
129
|
|
130
130
|
def add_op(operation, *args)
|
131
131
|
options = if args.last.is_a?(Hash)
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
132
|
+
args.pop
|
133
|
+
else
|
134
|
+
{}
|
135
|
+
end
|
136
136
|
|
137
137
|
inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }
|
138
138
|
|
@@ -141,7 +141,7 @@ module TensorStream
|
|
141
141
|
new_op.operation = operation
|
142
142
|
new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
|
143
143
|
new_op.rank = new_op.shape.rank
|
144
|
-
new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join(
|
144
|
+
new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join("/")
|
145
145
|
new_op.internal = options[:internal]
|
146
146
|
|
147
147
|
new_op.data_type = new_op.set_data_type(options[:data_type])
|
@@ -211,7 +211,7 @@ module TensorStream
|
|
211
211
|
def get_operation_counter
|
212
212
|
@op_counter ||= 0
|
213
213
|
|
214
|
-
name = @op_counter.zero? ?
|
214
|
+
name = @op_counter.zero? ? "" : "_#{@op_counter}"
|
215
215
|
|
216
216
|
@op_counter += 1
|
217
217
|
|
@@ -222,7 +222,7 @@ module TensorStream
|
|
222
222
|
@placeholder_counter ||= 0
|
223
223
|
@placeholder_counter += 1
|
224
224
|
|
225
|
-
return
|
225
|
+
return "" if @placeholder_counter == 1
|
226
226
|
|
227
227
|
"_#{@placeholder_counter}"
|
228
228
|
end
|
@@ -231,14 +231,14 @@ module TensorStream
|
|
231
231
|
@var_counter ||= 0
|
232
232
|
@var_counter += 1
|
233
233
|
|
234
|
-
return
|
234
|
+
return "" if @var_counter == 1
|
235
235
|
"_#{@var_counter}"
|
236
236
|
end
|
237
237
|
|
238
238
|
def get_const_counter
|
239
239
|
@const_counter ||= 0
|
240
240
|
|
241
|
-
name = @const_counter.zero? ?
|
241
|
+
name = @const_counter.zero? ? "" : "_#{@const_counter}"
|
242
242
|
|
243
243
|
@const_counter += 1
|
244
244
|
name
|
@@ -248,7 +248,7 @@ module TensorStream
|
|
248
248
|
graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
|
249
249
|
return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
|
250
250
|
|
251
|
-
graph_thread_storage[:current_scope].join(
|
251
|
+
graph_thread_storage[:current_scope].join("/")
|
252
252
|
end
|
253
253
|
|
254
254
|
def get_dependency_scope
|
@@ -279,7 +279,7 @@ module TensorStream
|
|
279
279
|
protected
|
280
280
|
|
281
281
|
def _variable_scope
|
282
|
-
return VariableScope.new(name:
|
282
|
+
return VariableScope.new(name: "", reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
|
283
283
|
scope = Thread.current[:tensor_stream_variable_scope].last
|
284
284
|
scope
|
285
285
|
end
|
@@ -13,48 +13,48 @@ module TensorStream
|
|
13
13
|
protobuf = TensorStream::Protobuf.new
|
14
14
|
parsed_tree = protobuf.load_from_string(buffer)
|
15
15
|
parsed_tree.each do |node|
|
16
|
-
next unless node[
|
16
|
+
next unless node["type"] == "node"
|
17
17
|
|
18
18
|
# puts "build #{node['name']}"
|
19
19
|
options = protobuf.options_evaluator(node)
|
20
|
-
options[:name] = node[
|
20
|
+
options[:name] = node["name"]
|
21
21
|
options[:__graph] = @graph
|
22
|
-
value = options.delete(
|
22
|
+
value = options.delete("value")
|
23
23
|
options = symbolize_keys(options)
|
24
|
-
case node[
|
25
|
-
when
|
24
|
+
case node["op"]
|
25
|
+
when "Const"
|
26
26
|
dimension = shape_eval(value)
|
27
27
|
rank = dimension.size
|
28
28
|
options[:value] = value
|
29
29
|
options[:const] = true
|
30
30
|
TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
|
31
|
-
when
|
31
|
+
when "VariableV2"
|
32
32
|
# evaluate options
|
33
33
|
shape = options[:shape]
|
34
34
|
i_var(options[:dtype] || options[:T], nil, shape, nil, options)
|
35
|
-
when
|
35
|
+
when "Placeholder"
|
36
36
|
shape = options[:shape]
|
37
37
|
TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
|
38
38
|
else
|
39
|
-
op = underscore(node[
|
39
|
+
op = underscore(node["op"]).to_sym
|
40
40
|
puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
|
41
41
|
|
42
42
|
# map input tensor
|
43
|
-
inputs = node[
|
44
|
-
input[0] =
|
43
|
+
inputs = node["input"].map { |input|
|
44
|
+
input[0] = "" if input.start_with?("^")
|
45
45
|
|
46
|
-
input_indexed, index = input.split(
|
46
|
+
input_indexed, index = input.split(":")
|
47
47
|
|
48
48
|
tensor = if index && index.to_i > 0
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
49
|
+
@graph.get_tensor_by_name(input_indexed)[index.to_i]
|
50
|
+
else
|
51
|
+
@graph.get_tensor_by_name(input)
|
52
|
+
end
|
53
53
|
|
54
54
|
raise "tensor not found by name #{input}" if tensor.nil?
|
55
55
|
|
56
56
|
tensor
|
57
|
-
|
57
|
+
}
|
58
58
|
|
59
59
|
options[:data_type] = options.delete(:T)
|
60
60
|
Graph.get_default_graph.add_op!(op, *inputs, options)
|
@@ -64,4 +64,4 @@ module TensorStream
|
|
64
64
|
@graph
|
65
65
|
end
|
66
66
|
end
|
67
|
-
end
|
67
|
+
end
|
@@ -1,4 +1,4 @@
|
|
1
|
-
require
|
1
|
+
require "yaml"
|
2
2
|
|
3
3
|
module TensorStream
|
4
4
|
# A .pb graph deserializer
|
@@ -14,7 +14,7 @@ module TensorStream
|
|
14
14
|
# parsers a protobuf file and spits out
|
15
15
|
# a ruby hash
|
16
16
|
def load(pbfile)
|
17
|
-
f = File.new(pbfile,
|
17
|
+
f = File.new(pbfile, "r")
|
18
18
|
lines = []
|
19
19
|
while !f.eof? && (str = f.readline.strip)
|
20
20
|
lines << str
|
@@ -23,38 +23,38 @@ module TensorStream
|
|
23
23
|
end
|
24
24
|
|
25
25
|
def parse_value(value_node)
|
26
|
-
return unless value_node[
|
26
|
+
return unless value_node["tensor"]
|
27
27
|
|
28
|
-
evaluate_tensor_node(value_node[
|
28
|
+
evaluate_tensor_node(value_node["tensor"])
|
29
29
|
end
|
30
30
|
|
31
31
|
def evaluate_tensor_node(node)
|
32
|
-
if !node[
|
33
|
-
content = node[
|
34
|
-
unpacked = eval(%
|
32
|
+
if !node["shape"].empty? && node["tensor_content"]
|
33
|
+
content = node["tensor_content"]
|
34
|
+
unpacked = eval(%("#{content}"))
|
35
35
|
|
36
|
-
if node[
|
37
|
-
TensorShape.reshape(unpacked.unpack(
|
38
|
-
elsif node[
|
39
|
-
TensorShape.reshape(unpacked.unpack(
|
40
|
-
elsif node[
|
41
|
-
node[
|
36
|
+
if node["dtype"] == "DT_FLOAT"
|
37
|
+
TensorShape.reshape(unpacked.unpack("f*"), node["shape"])
|
38
|
+
elsif node["dtype"] == "DT_INT32"
|
39
|
+
TensorShape.reshape(unpacked.unpack("l*"), node["shape"])
|
40
|
+
elsif node["dtype"] == "DT_STRING"
|
41
|
+
node["string_val"]
|
42
42
|
else
|
43
|
-
raise "unknown dtype #{node[
|
43
|
+
raise "unknown dtype #{node["dtype"]}"
|
44
44
|
end
|
45
45
|
else
|
46
46
|
|
47
|
-
val = if node[
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
47
|
+
val = if node["dtype"] == "DT_FLOAT"
|
48
|
+
node["float_val"] ? node["float_val"].to_f : []
|
49
|
+
elsif node["dtype"] == "DT_INT32"
|
50
|
+
node["int_val"] ? node["int_val"].to_i : []
|
51
|
+
elsif node["dtype"] == "DT_STRING"
|
52
|
+
node["string_val"]
|
53
|
+
else
|
54
|
+
raise "unknown dtype #{node["dtype"]}"
|
55
|
+
end
|
56
56
|
|
57
|
-
if node[
|
57
|
+
if node["shape"] == [1]
|
58
58
|
[val]
|
59
59
|
else
|
60
60
|
val
|
@@ -63,16 +63,16 @@ module TensorStream
|
|
63
63
|
end
|
64
64
|
|
65
65
|
def map_type_to_ts(attr_value)
|
66
|
-
case
|
67
|
-
when
|
66
|
+
case attr_value
|
67
|
+
when "DT_FLOAT"
|
68
68
|
:float32
|
69
|
-
when
|
69
|
+
when "DT_INT32"
|
70
70
|
:int32
|
71
|
-
when
|
71
|
+
when "DT_INT64"
|
72
72
|
:int64
|
73
|
-
when
|
73
|
+
when "DT_STRING"
|
74
74
|
:string
|
75
|
-
when
|
75
|
+
when "DT_BOOL"
|
76
76
|
:boolean
|
77
77
|
else
|
78
78
|
raise "unknown type #{attr_value}"
|
@@ -80,21 +80,21 @@ module TensorStream
|
|
80
80
|
end
|
81
81
|
|
82
82
|
def options_evaluator(node)
|
83
|
-
return {} if node[
|
83
|
+
return {} if node["attributes"].nil?
|
84
84
|
|
85
|
-
node[
|
86
|
-
attr_type, attr_value = attribute[
|
85
|
+
node["attributes"].map { |attribute|
|
86
|
+
attr_type, attr_value = attribute["value"].flat_map { |k, v| [k, v] }
|
87
87
|
|
88
|
-
if attr_type ==
|
88
|
+
if attr_type == "tensor"
|
89
89
|
attr_value = evaluate_tensor_node(attr_value)
|
90
|
-
elsif attr_type ==
|
90
|
+
elsif attr_type == "type"
|
91
91
|
attr_value = map_type_to_ts(attr_value)
|
92
|
-
elsif attr_type ==
|
93
|
-
attr_value = attr_value ==
|
92
|
+
elsif attr_type == "b"
|
93
|
+
attr_value = attr_value == "true"
|
94
94
|
end
|
95
95
|
|
96
|
-
[attribute[
|
97
|
-
|
96
|
+
[attribute["key"], attr_value]
|
97
|
+
}.to_h
|
98
98
|
end
|
99
99
|
|
100
100
|
protected
|
@@ -108,126 +108,126 @@ module TensorStream
|
|
108
108
|
lines.each do |str|
|
109
109
|
case state
|
110
110
|
when :top
|
111
|
-
node[
|
111
|
+
node["type"] = parse_node_name(str)
|
112
112
|
state = :node_context
|
113
113
|
next
|
114
114
|
when :node_context
|
115
|
-
if str ==
|
115
|
+
if str == "attr {"
|
116
116
|
state = :attr_context
|
117
117
|
node_attr = {}
|
118
|
-
node[
|
119
|
-
node[
|
118
|
+
node["attributes"] ||= []
|
119
|
+
node["attributes"] << node_attr
|
120
120
|
next
|
121
|
-
elsif str ==
|
121
|
+
elsif str == "}"
|
122
122
|
state = :top
|
123
123
|
block << node
|
124
124
|
node = {}
|
125
125
|
next
|
126
126
|
else
|
127
|
-
key, value = str.split(
|
128
|
-
if key ==
|
129
|
-
node[
|
130
|
-
node[
|
127
|
+
key, value = str.split(":", 2)
|
128
|
+
if key == "input"
|
129
|
+
node["input"] ||= []
|
130
|
+
node["input"] << process_value(value.strip)
|
131
131
|
else
|
132
132
|
node[key] = process_value(value.strip)
|
133
133
|
end
|
134
134
|
end
|
135
135
|
when :attr_context
|
136
|
-
if str ==
|
136
|
+
if str == "value {"
|
137
137
|
state = :value_context
|
138
|
-
node_attr[
|
138
|
+
node_attr["value"] = {}
|
139
139
|
next
|
140
|
-
elsif str ==
|
140
|
+
elsif str == "}"
|
141
141
|
state = :node_context
|
142
142
|
next
|
143
143
|
else
|
144
|
-
key, value = str.split(
|
144
|
+
key, value = str.split(":", 2)
|
145
145
|
node_attr[key] = process_value(value.strip)
|
146
146
|
end
|
147
147
|
when :value_context
|
148
|
-
if str ==
|
148
|
+
if str == "list {"
|
149
149
|
state = :list_context
|
150
|
-
node_attr[
|
150
|
+
node_attr["value"] = []
|
151
151
|
next
|
152
|
-
elsif str ==
|
152
|
+
elsif str == "shape {"
|
153
153
|
state = :shape_context
|
154
|
-
node_attr[
|
154
|
+
node_attr["value"]["shape"] = []
|
155
155
|
next
|
156
|
-
elsif str ==
|
156
|
+
elsif str == "tensor {"
|
157
157
|
state = :tensor_context
|
158
|
-
node_attr[
|
158
|
+
node_attr["value"]["tensor"] = {}
|
159
159
|
next
|
160
|
-
elsif str ==
|
160
|
+
elsif str == "}"
|
161
161
|
state = :attr_context
|
162
162
|
next
|
163
163
|
else
|
164
|
-
key, value = str.split(
|
165
|
-
if key ==
|
166
|
-
node_attr[
|
167
|
-
elsif key ==
|
168
|
-
node_attr[
|
164
|
+
key, value = str.split(":", 2)
|
165
|
+
if key == "dtype"
|
166
|
+
node_attr["value"]["dtype"] = value.strip
|
167
|
+
elsif key == "type"
|
168
|
+
node_attr["value"]["type"] = value.strip
|
169
169
|
else
|
170
|
-
node_attr[
|
170
|
+
node_attr["value"][key] = process_value(value.strip)
|
171
171
|
end
|
172
172
|
end
|
173
173
|
when :list_context
|
174
|
-
if str ==
|
174
|
+
if str == "}"
|
175
175
|
state = :value_context
|
176
176
|
next
|
177
177
|
else
|
178
|
-
key, value = str.split(
|
179
|
-
node_attr[
|
178
|
+
key, value = str.split(":", 2)
|
179
|
+
node_attr["value"] << {key => value}
|
180
180
|
end
|
181
181
|
when :tensor_context
|
182
|
-
if str ==
|
182
|
+
if str == "tensor_shape {"
|
183
183
|
state = :tensor_shape_context
|
184
|
-
node_attr[
|
184
|
+
node_attr["value"]["tensor"]["shape"] = []
|
185
185
|
next
|
186
|
-
elsif str ==
|
186
|
+
elsif str == "}"
|
187
187
|
state = :value_context
|
188
188
|
next
|
189
189
|
else
|
190
|
-
key, value = str.split(
|
191
|
-
if node_attr[
|
192
|
-
node_attr[
|
193
|
-
node_attr[
|
194
|
-
elsif node_attr[
|
195
|
-
node_attr[
|
190
|
+
key, value = str.split(":", 2)
|
191
|
+
if node_attr["value"]["tensor"][key] && !node_attr["value"]["tensor"][key].is_a?(Array)
|
192
|
+
node_attr["value"]["tensor"][key] = [node_attr["value"]["tensor"][key]]
|
193
|
+
node_attr["value"]["tensor"][key] << process_value(value.strip)
|
194
|
+
elsif node_attr["value"]["tensor"][key]
|
195
|
+
node_attr["value"]["tensor"][key] << process_value(value.strip)
|
196
196
|
else
|
197
|
-
node_attr[
|
197
|
+
node_attr["value"]["tensor"][key] = process_value(value.strip)
|
198
198
|
end
|
199
199
|
end
|
200
200
|
when :tensor_shape_context
|
201
|
-
if str ==
|
201
|
+
if str == "dim {"
|
202
202
|
state = :tensor_shape_dim_context
|
203
203
|
next
|
204
|
-
elsif str ==
|
204
|
+
elsif str == "}"
|
205
205
|
state = :tensor_context
|
206
206
|
next
|
207
207
|
end
|
208
208
|
when :shape_context
|
209
|
-
if str ==
|
209
|
+
if str == "}"
|
210
210
|
state = :value_context
|
211
211
|
next
|
212
|
-
elsif str ==
|
212
|
+
elsif str == "dim {"
|
213
213
|
state = :shape_dim_context
|
214
214
|
next
|
215
215
|
end
|
216
216
|
when :shape_dim_context
|
217
|
-
if str ==
|
217
|
+
if str == "}"
|
218
218
|
state = :shape_context
|
219
219
|
next
|
220
220
|
else
|
221
|
-
_key, value = str.split(
|
222
|
-
node_attr[
|
221
|
+
_key, value = str.split(":", 2)
|
222
|
+
node_attr["value"]["shape"] << value.strip.to_i
|
223
223
|
end
|
224
224
|
when :tensor_shape_dim_context
|
225
|
-
if str ==
|
225
|
+
if str == "}"
|
226
226
|
state = :tensor_shape_context
|
227
227
|
next
|
228
228
|
else
|
229
|
-
_key, value = str.split(
|
230
|
-
node_attr[
|
229
|
+
_key, value = str.split(":", 2)
|
230
|
+
node_attr["value"]["tensor"]["shape"] << value.strip.to_i
|
231
231
|
end
|
232
232
|
end
|
233
233
|
end
|
@@ -236,22 +236,22 @@ module TensorStream
|
|
236
236
|
end
|
237
237
|
|
238
238
|
def parse_node_name(str)
|
239
|
-
str.split(
|
239
|
+
str.split(" ")[0]
|
240
240
|
end
|
241
241
|
|
242
242
|
def process_value(value)
|
243
243
|
if value.start_with?('"')
|
244
|
-
unescape(value.gsub!(/\A"|"\Z/,
|
244
|
+
unescape(value.gsub!(/\A"|"\Z/, ""))
|
245
245
|
else
|
246
246
|
unescape(value)
|
247
247
|
end
|
248
248
|
end
|
249
249
|
|
250
250
|
UNESCAPES = {
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
"\"" => "\x22", "'" => "\x27"
|
251
|
+
"a" => "\x07", "b" => "\x08", "t" => "\x09",
|
252
|
+
"n" => "\x0a", "v" => "\x0b", "f" => "\x0c",
|
253
|
+
"r" => "\x0d", "e" => "\x1b", "\\\\" => "\x5c",
|
254
|
+
"\"" => "\x22", "'" => "\x27",
|
255
255
|
}.freeze
|
256
256
|
|
257
257
|
def unescape(str)
|
@@ -260,11 +260,11 @@ module TensorStream
|
|
260
260
|
if $1
|
261
261
|
$1 == '\\' ? '\\' : UNESCAPES[$1]
|
262
262
|
elsif $2 # escape \u0000 unicode
|
263
|
-
[
|
263
|
+
[$2.to_s.hex].pack("U*")
|
264
264
|
elsif $3 # escape \0xff or \xff
|
265
|
-
[$3].pack(
|
265
|
+
[$3].pack("H2")
|
266
266
|
end
|
267
267
|
end
|
268
268
|
end
|
269
269
|
end
|
270
|
-
end
|
270
|
+
end
|