tensor_stream 1.0.4 → 1.0.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +12 -2
- data/Dockerfile +1 -1
- data/USAGE_GUIDE.md +68 -0
- data/lib/tensor_stream.rb +1 -0
- data/lib/tensor_stream/evaluator/base_evaluator.rb +21 -1
- data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
- data/lib/tensor_stream/evaluator/evaluator_utils.rb +20 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +26 -0
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +60 -5
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +25 -29
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +7 -11
- data/lib/tensor_stream/evaluator/ruby/storage_manager.rb +40 -0
- data/lib/tensor_stream/evaluator/ruby/variable_ops.rb +74 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +31 -77
- data/lib/tensor_stream/generated_stub/ops.rb +256 -166
- data/lib/tensor_stream/generated_stub/stub_file.erb +4 -4
- data/lib/tensor_stream/graph.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +4 -6
- data/lib/tensor_stream/helpers/infer_shape.rb +1 -7
- data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
- data/lib/tensor_stream/images.rb +4 -0
- data/lib/tensor_stream/math/math_ops.rb +22 -0
- data/lib/tensor_stream/math_gradients.rb +15 -1
- data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
- data/lib/tensor_stream/nn/nn_ops.rb +16 -0
- data/lib/tensor_stream/op_maker.rb +36 -3
- data/lib/tensor_stream/operation.rb +8 -20
- data/lib/tensor_stream/ops.rb +14 -11
- data/lib/tensor_stream/ops/bias_add.rb +16 -0
- data/lib/tensor_stream/ops/equal.rb +4 -0
- data/lib/tensor_stream/ops/greater.rb +4 -0
- data/lib/tensor_stream/ops/greater_equal.rb +4 -0
- data/lib/tensor_stream/ops/less.rb +19 -0
- data/lib/tensor_stream/ops/less_equal.rb +4 -0
- data/lib/tensor_stream/ops/not_equal.rb +19 -0
- data/lib/tensor_stream/ops/rsqrt.rb +11 -0
- data/lib/tensor_stream/ops/strided_slice.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +4 -2
- data/lib/tensor_stream/ops/top_k.rb +23 -0
- data/lib/tensor_stream/session.rb +6 -12
- data/lib/tensor_stream/tensor.rb +1 -0
- data/lib/tensor_stream/tensor_shape.rb +32 -1
- data/lib/tensor_stream/train/saver.rb +2 -3
- data/lib/tensor_stream/utils.rb +18 -13
- data/lib/tensor_stream/utils/freezer.rb +5 -1
- data/lib/tensor_stream/utils/py_ports.rb +11 -0
- data/lib/tensor_stream/variable.rb +9 -6
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/word_embeddings/word_embedding_1.rb +192 -0
- data/samples/word_embeddings/word_embedding_2.rb +203 -0
- data/tensor_stream.gemspec +7 -2
- metadata +67 -10
@@ -9,12 +9,12 @@ module TensorStream
|
|
9
9
|
<%end%> #
|
10
10
|
#<% if op.supports_broadcasting? %> This operation supports broadcasting
|
11
11
|
#<% end %>
|
12
|
-
#
|
13
|
-
<% op.parameters.each do |param| %> # +<%= param[:name] %>+:: <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
|
12
|
+
<% op.parameters.each do |param| %> # @param <%= param[:name] %> <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
|
14
13
|
<% end %> #
|
15
14
|
# Options:
|
16
|
-
<% op.options.each do |k, v| %> #
|
17
|
-
<%end%>
|
15
|
+
<% op.options.each do |k, v| %> # @option <%= k %> <%= v[:description]%><% if v[:default_value] != :nil %> default (<%= v[:default_value] %>)<%end%>
|
16
|
+
<%end%> # @return Tensor
|
17
|
+
def <%= op.operation.to_s %>(<%= (op.expand_params(true) + op.expand_options(true)).join(', ') %>)
|
18
18
|
<%= op.generate_body %>
|
19
19
|
end
|
20
20
|
<% op.aliases.each do |a|%>
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -30,6 +30,7 @@ module TensorStream
|
|
30
30
|
:"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
|
31
31
|
}
|
32
32
|
@constants = {}
|
33
|
+
TensorStream::Evaluator.clear_storages(self)
|
33
34
|
end
|
34
35
|
|
35
36
|
def as_default
|
@@ -129,7 +130,7 @@ module TensorStream
|
|
129
130
|
|
130
131
|
def add_op(operation, *args)
|
131
132
|
options = if args.last.is_a?(Hash)
|
132
|
-
args.pop
|
133
|
+
args.pop || {}
|
133
134
|
else
|
134
135
|
{}
|
135
136
|
end
|
@@ -180,8 +181,7 @@ module TensorStream
|
|
180
181
|
|
181
182
|
def add_variable!(node, options = {})
|
182
183
|
node = add_variable(node, options)
|
183
|
-
op = Graph.get_default_graph.add_op!(:variable_v2,
|
184
|
-
node.name = op.name
|
184
|
+
op = Graph.get_default_graph.add_op!(:variable_v2, var_name: node.name, shape: options[:shape], data_type: options[:data_type])
|
185
185
|
op
|
186
186
|
end
|
187
187
|
|
@@ -31,15 +31,13 @@ module TensorStream
|
|
31
31
|
options = {}
|
32
32
|
|
33
33
|
new_var = nil
|
34
|
-
if op_def.
|
34
|
+
if op_def[:op].to_sym == :variable_v2
|
35
35
|
new_var = Variable.new(op_def.dig(:attrs, :data_type))
|
36
|
-
var_shape = op_def.dig(:attrs, :container, :shape)
|
37
|
-
var_options = op_def.dig(:attrs, :container, :options)
|
38
|
-
var_options[:name] = op_def[:name]
|
39
36
|
|
40
|
-
|
41
|
-
|
37
|
+
var_options = {}
|
38
|
+
var_options[:name] = op_def.dig(:attrs, :var_name)
|
42
39
|
|
40
|
+
new_var.prepare(nil, nil, TensorStream.get_variable_scope, var_options)
|
43
41
|
@graph.add_variable(new_var, var_options)
|
44
42
|
end
|
45
43
|
|
@@ -10,13 +10,7 @@ module TensorStream
|
|
10
10
|
def self.infer_shape(tensor)
|
11
11
|
case tensor.operation
|
12
12
|
when :assign
|
13
|
-
|
14
|
-
tensor.inputs[0].shape.shape
|
15
|
-
else
|
16
|
-
tensor.inputs[1].shape.shape
|
17
|
-
end
|
18
|
-
|
19
|
-
possible_shape
|
13
|
+
tensor.inputs[0]&.shape&.shape
|
20
14
|
when :const
|
21
15
|
shape_eval(tensor.options[:value])
|
22
16
|
when :variable_v2
|
@@ -5,7 +5,16 @@ module TensorStream
|
|
5
5
|
end
|
6
6
|
|
7
7
|
def [](index)
|
8
|
-
|
8
|
+
if index.is_a?(Range)
|
9
|
+
last = if index.end.nil?
|
10
|
+
[TensorStream.shape(self)[0]]
|
11
|
+
else
|
12
|
+
[index.max + 1]
|
13
|
+
end
|
14
|
+
_op(:strided_slice, self, [index.min], last, [1])
|
15
|
+
else
|
16
|
+
_op(:index, self, index)
|
17
|
+
end
|
9
18
|
end
|
10
19
|
|
11
20
|
def *(other)
|
data/lib/tensor_stream/images.rb
CHANGED
@@ -7,6 +7,10 @@ module TensorStream
|
|
7
7
|
_op(:decode_png, contents, channels: channels, data_type: dtype, name: name, new_shape: new_shape)
|
8
8
|
end
|
9
9
|
|
10
|
+
def self.decode_jpeg(contents, channels: 0, dtype: :uint8, name: nil, new_shape: nil)
|
11
|
+
_op(:decode_jpg, contents, channels: channels, data_type: dtype, name: name, new_shape: new_shape)
|
12
|
+
end
|
13
|
+
|
10
14
|
def self.encode_png(contents, compression: -1, name: nil, new_shape: nil, resample_method: nil)
|
11
15
|
check_allowed_types(contents, %i[uint8 uint16])
|
12
16
|
contents = convert_to_tensor(contents, dtype: :uint16)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module TensorStream
|
2
|
+
# High level math functions
|
3
|
+
class Maths
|
4
|
+
extend TensorStream::OpHelper
|
5
|
+
|
6
|
+
module MathFunctions
|
7
|
+
|
8
|
+
##
|
9
|
+
# Normalizes along dimension axis using an L2 norm.
|
10
|
+
def l2_normalize(x, axis: nil, epsilon: 1e-12, name: nil)
|
11
|
+
TensorStream.name_scope(name, "l2_normalize", values: [x]) do |name|
|
12
|
+
x = TensorStream.convert_to_tensor(x, name: "x")
|
13
|
+
square_sum = TensorStream.reduce_sum(TensorStream.square(x), axis, keepdims: true)
|
14
|
+
x_inv_norm = TensorStream.rsqrt(TensorStream.maximum(square_sum, epsilon))
|
15
|
+
TensorStream.multiply(x, x_inv_norm, name: name)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
|
20
|
+
extend MathFunctions
|
21
|
+
end
|
22
|
+
end
|
@@ -136,7 +136,7 @@ module TensorStream
|
|
136
136
|
when :sparse_softmax_cross_entropy_with_logits
|
137
137
|
output = node
|
138
138
|
[_broadcast_mul(grad, output[1]), nil]
|
139
|
-
|
139
|
+
when :zeros_like
|
140
140
|
# non differentiable
|
141
141
|
nil
|
142
142
|
when :transpose
|
@@ -165,12 +165,26 @@ module TensorStream
|
|
165
165
|
ts.stack(grad, axis: node.options[:axis])
|
166
166
|
when :conv2d
|
167
167
|
_Conv2DGrad(node, grad)
|
168
|
+
when :flow_dynamic_stitch
|
169
|
+
num_values = node.inputs.size / 2
|
170
|
+
indices_grad = [nil] * num_values
|
171
|
+
|
172
|
+
inputs = (0...num_values).map { |i| _int32(node, node.inputs[i]) }
|
173
|
+
|
174
|
+
values_grad = inputs.map { |inp| TensorStream.gather(grad, inp) }
|
175
|
+
indices_grad + values_grad
|
176
|
+
when :gather
|
177
|
+
[_op(:gather_grad, grad, node.inputs[1], TensorStream.shape(node.inputs[0])), nil]
|
168
178
|
else
|
169
179
|
TensorStream::OpMaker.gradient_op(self, node, grad)
|
170
180
|
end
|
171
181
|
end
|
172
182
|
end
|
173
183
|
|
184
|
+
def self._int32(node, x)
|
185
|
+
(node.inputs[0].data_type == :int32 ? x : TensorStream.cast(x, :int32))
|
186
|
+
end
|
187
|
+
|
174
188
|
def self._reshape_to_input(node, grad)
|
175
189
|
ts.reshape(grad, ts.shape(node.inputs[0]))
|
176
190
|
end
|
@@ -0,0 +1,114 @@
|
|
1
|
+
require 'tensor_stream/utils/py_ports'
|
2
|
+
##
|
3
|
+
# ruby port of https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/embedding_ops.py
|
4
|
+
#
|
5
|
+
module TensorStream
|
6
|
+
module EmbeddingLookup
|
7
|
+
include TensorStream::PyPorts
|
8
|
+
|
9
|
+
##
|
10
|
+
# Looks up `ids` in a list of embedding tensors.
|
11
|
+
def embedding_lookup(params, ids, partition_strategy: "mod", name: nil, validate_indices: true, max_norm: nil)
|
12
|
+
_embedding_lookup_and_transform(params, ids, partition_strategy: partition_strategy, name: name, max_norm: max_norm, transform_fn: nil)
|
13
|
+
end
|
14
|
+
|
15
|
+
##
|
16
|
+
# Helper function for embedding_lookup and _compute_sampled_logits.
|
17
|
+
def _embedding_lookup_and_transform(params, ids, partition_strategy: "mod", name: nil, max_norm: nil, transform_fn: nil)
|
18
|
+
raise TensorStream::ValueError, "Need at least one param" if params.nil?
|
19
|
+
|
20
|
+
params = [params] unless params.is_a?(Array)
|
21
|
+
|
22
|
+
TensorStream.name_scope(name, "embedding_lookup", values: params + [ids]) do |name|
|
23
|
+
np = params.size
|
24
|
+
ids = TensorStream.convert_to_tensor(ids, name: "ids")
|
25
|
+
if (np == 1) && (transform_fn.nil? || (ids.shape.size == 1))
|
26
|
+
result = nil
|
27
|
+
TensorStream.colocate_with(params[0]) do
|
28
|
+
result = _clip(TensorStream.gather(params[0], ids, name: name), ids, max_norm)
|
29
|
+
result = transform_fn.call(result) if transform_fn
|
30
|
+
end
|
31
|
+
|
32
|
+
return TensorStream.identity(result)
|
33
|
+
else
|
34
|
+
flat_ids = TensorStream.reshape(ids, [-1])
|
35
|
+
original_indices = TensorStream.range(TensorStream.size(flat_ids))
|
36
|
+
|
37
|
+
p_assignments = nil
|
38
|
+
new_ids = nil
|
39
|
+
|
40
|
+
if partition_strategy == "mod"
|
41
|
+
p_assignments = flat_ids % np
|
42
|
+
new_ids = floor_div(flat_ids, np)
|
43
|
+
elsif partition_strategy == "div"
|
44
|
+
raise "not yet supported!"
|
45
|
+
else
|
46
|
+
raise TensorStream::ValueError, "Unrecognized partition strategy: " + partition_strategy
|
47
|
+
end
|
48
|
+
|
49
|
+
p_assignments = TensorStream.cast(p_assignments, :int32)
|
50
|
+
gather_ids = TensorStream.dynamic_partition(new_ids, p_assignments, np)
|
51
|
+
pindices = TensorStream.dynamic_partition(original_indices, p_assignments, np)
|
52
|
+
partitioned_result = []
|
53
|
+
(0...np).each do |p|
|
54
|
+
pids = gather_ids[p]
|
55
|
+
result = nil
|
56
|
+
TensorStream.colocate_with(params[p]) do
|
57
|
+
result = TensorStream.gather(params[p], pids)
|
58
|
+
if transform_fn
|
59
|
+
# If transform_fn is provided, the clip_by_norm precedes
|
60
|
+
# the transform and hence must be co-located. See below
|
61
|
+
# for the counterpart if transform_fn is not proveded.
|
62
|
+
result = transform_fn.call(_clip(result, pids, max_norm))
|
63
|
+
end
|
64
|
+
end
|
65
|
+
partitioned_result << result
|
66
|
+
end
|
67
|
+
ret = TensorStream.dynamic_stitch(pindices, partitioned_result, name: name)
|
68
|
+
|
69
|
+
if transform_fn.nil?
|
70
|
+
element_shape_s = params[0].shape[1..-1]
|
71
|
+
params[1..-1].each { |p| element_shape_s = element_shape_s.merge_with(p.shape[1..-1]) }
|
72
|
+
else
|
73
|
+
element_shape_s = ret.shape[1..-1]
|
74
|
+
end
|
75
|
+
|
76
|
+
# Compute the dynamic element shape.
|
77
|
+
element_shape_d = if element_shape_s.fully_defined?
|
78
|
+
element_shape_s
|
79
|
+
elsif transform_fn.nil?
|
80
|
+
# It's important that we compute params[0].shape on the right device
|
81
|
+
# to avoid data motion.
|
82
|
+
TensorStream.colocate_with(params[0]) do
|
83
|
+
params_shape = TensorStream.shape(params[0])
|
84
|
+
params_shape[1..-1]
|
85
|
+
end
|
86
|
+
else
|
87
|
+
TensorStream.shape(ret)[1..-1]
|
88
|
+
end
|
89
|
+
ret = TensorStream.reshape(ret, TensorStream.concat([TensorStream.shape(ids), element_shape_d], 0))
|
90
|
+
ret = _clip(ret, ids, max_norm) unless transform_fn
|
91
|
+
ret
|
92
|
+
end
|
93
|
+
end
|
94
|
+
end
|
95
|
+
|
96
|
+
def _clip(params, ids, max_norm)
|
97
|
+
return params if max_norm.nil?
|
98
|
+
|
99
|
+
ids_rank, ids_static = _rank(ids)
|
100
|
+
params_rank, params_static = _rank(params)
|
101
|
+
|
102
|
+
TensorStream.clip_by_norm(params, max_norm, axes: ids_static && params_static ? (ids_rank...params_rank).to_a : TensorStream.range(ids_rank, params_rank))
|
103
|
+
end
|
104
|
+
|
105
|
+
def _rank(x)
|
106
|
+
rank = TensorStream.convert_to_tensor(x).shape.ndims
|
107
|
+
if rank
|
108
|
+
[rank, false]
|
109
|
+
else
|
110
|
+
[TensorStream.rank(x), false]
|
111
|
+
end
|
112
|
+
end
|
113
|
+
end
|
114
|
+
end
|
@@ -1,7 +1,10 @@
|
|
1
|
+
require 'tensor_stream/nn/embedding_lookup'
|
1
2
|
module TensorStream
|
2
3
|
# High level machine learning functions
|
3
4
|
class NN
|
4
5
|
extend TensorStream::OpHelper
|
6
|
+
extend TensorStream::EmbeddingLookup
|
7
|
+
extend TensorStream::Maths::MathFunctions
|
5
8
|
|
6
9
|
class << self
|
7
10
|
def softmax(logits, axis: nil, name: nil)
|
@@ -137,6 +140,19 @@ module TensorStream
|
|
137
140
|
def conv2d(input, filter, strides, padding, name: nil)
|
138
141
|
_op(:conv2d, input, filter, strides: strides, padding: padding, name: name)
|
139
142
|
end
|
143
|
+
|
144
|
+
##
|
145
|
+
# Adds bias to value.
|
146
|
+
#
|
147
|
+
# This is a narrow version of tf add where the bias is restructed to 1-D only
|
148
|
+
def bias_add(value, bias, data_format: nil, name: nil)
|
149
|
+
value = TensorStream.convert_to_tensor(value, name: "input")
|
150
|
+
bias = TensorStream.convert_to_tensor(bias, dtype: value.dtype, name: "bias")
|
151
|
+
|
152
|
+
raise TensorStreamError, "value must be at least rank 2" if value.shape.known? && value.shape.ndims < 2
|
153
|
+
|
154
|
+
_op(:bias_add, value, bias, data_format: data_format, name: name)
|
155
|
+
end
|
140
156
|
end
|
141
157
|
end
|
142
158
|
|
@@ -2,7 +2,8 @@ class TensorStream::OpMaker
|
|
2
2
|
attr_reader :operation, :description, :parameters,
|
3
3
|
:options, :gradient, :check_types,
|
4
4
|
:supports_broadcast, :data_type_coercion,
|
5
|
-
:aliases, :custom, :infer_type_proc, :exclude
|
5
|
+
:aliases, :custom, :custom_post, :infer_type_proc, :exclude,
|
6
|
+
:data_type_block
|
6
7
|
|
7
8
|
def initialize(op)
|
8
9
|
@operation = op
|
@@ -15,6 +16,7 @@ class TensorStream::OpMaker
|
|
15
16
|
@description = []
|
16
17
|
@aliases = []
|
17
18
|
@custom = []
|
19
|
+
@custom_post = []
|
18
20
|
@infer_type_proc = lambda { |tensor|
|
19
21
|
next nil if tensor.inputs[0].nil?
|
20
22
|
next tensor.inputs[0].shape.shape if tensor.inputs.size == 1
|
@@ -31,6 +33,10 @@ class TensorStream::OpMaker
|
|
31
33
|
@custom << custom_code
|
32
34
|
end
|
33
35
|
|
36
|
+
def add_custom_post(custom_code)
|
37
|
+
@custom_post << custom_code
|
38
|
+
end
|
39
|
+
|
34
40
|
def self.scan
|
35
41
|
op_files = Dir[File.join(File.dirname(__FILE__), "ops", "*.rb")]
|
36
42
|
op_files.each { |file|
|
@@ -58,6 +64,22 @@ class TensorStream::OpMaker
|
|
58
64
|
context_caller.instance_exec(tensor, &@ops[tensor.operation].infer_type_proc)
|
59
65
|
end
|
60
66
|
|
67
|
+
def self.infer_data_type(context_caller, tensor, passed_data_type)
|
68
|
+
return passed_data_type if passed_data_type
|
69
|
+
|
70
|
+
if @ops[tensor.operation] && @ops[tensor.operation].data_type_block
|
71
|
+
context_caller.instance_exec(tensor, &@ops[tensor.operation].data_type_block)
|
72
|
+
else
|
73
|
+
if tensor.inputs[0]
|
74
|
+
tensor.inputs[0].data_type
|
75
|
+
elsif tensor.inputs[1]
|
76
|
+
tensor.inputs[1].data_type
|
77
|
+
else
|
78
|
+
:unknown
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
|
61
83
|
def self.each_op(&block)
|
62
84
|
@ops.values.sort_by { |op| op.operation }.reject(&:exclude).each do |op|
|
63
85
|
block.call(op)
|
@@ -94,7 +116,14 @@ class TensorStream::OpMaker
|
|
94
116
|
custom.each do |c|
|
95
117
|
body << c
|
96
118
|
end
|
97
|
-
|
119
|
+
if custom_post.empty?
|
120
|
+
body << "_op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
|
121
|
+
else
|
122
|
+
body << "result = _op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
|
123
|
+
end
|
124
|
+
custom_post.each do |c|
|
125
|
+
body << c
|
126
|
+
end
|
98
127
|
body.map { |line| " #{line}"}.join("\n")
|
99
128
|
end
|
100
129
|
|
@@ -122,6 +151,10 @@ class TensorStream::OpMaker
|
|
122
151
|
@infer_type_proc = block
|
123
152
|
end
|
124
153
|
|
154
|
+
def define_data_type(&block)
|
155
|
+
@data_type_block = block
|
156
|
+
end
|
157
|
+
|
125
158
|
def expand_params(print_defaults)
|
126
159
|
@parameters.map { |param|
|
127
160
|
print_defaults && param[:default_value] ? "#{param[:name]} = #{default_with_nil(param[:default_value])}" : "#{param[:name]}"
|
@@ -163,7 +196,7 @@ class TensorStream::OpMaker
|
|
163
196
|
end
|
164
197
|
|
165
198
|
def options_call
|
166
|
-
@options.map { |k, v|
|
199
|
+
@options.reject { |k, v| v.dig(:options, :exclude) }.map { |k, v|
|
167
200
|
if v.dig(:options, :alias)
|
168
201
|
"#{v.dig(:options, :alias)}: #{k}"
|
169
202
|
else
|
@@ -7,7 +7,7 @@ module TensorStream
|
|
7
7
|
attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
|
8
8
|
attr_reader :outputs, :options, :is_const, :data_type, :shape
|
9
9
|
|
10
|
-
def initialize(graph, inputs
|
10
|
+
def initialize(graph, inputs: [], options: {})
|
11
11
|
@consumers = Set.new
|
12
12
|
@outputs = []
|
13
13
|
@op = self
|
@@ -42,14 +42,6 @@ module TensorStream
|
|
42
42
|
@options[:container] ? @options[:container].buffer : nil
|
43
43
|
end
|
44
44
|
|
45
|
-
def container
|
46
|
-
@options[:container].read_value
|
47
|
-
end
|
48
|
-
|
49
|
-
def container=(value)
|
50
|
-
@options[:container].value = value
|
51
|
-
end
|
52
|
-
|
53
45
|
def set_input(index, value)
|
54
46
|
@inputs[index] = value
|
55
47
|
@shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
|
@@ -58,6 +50,10 @@ module TensorStream
|
|
58
50
|
@data_type = set_data_type(@options[:data_type])
|
59
51
|
end
|
60
52
|
|
53
|
+
def set_option(key, value)
|
54
|
+
@options.merge!(key.to_sym => value)
|
55
|
+
end
|
56
|
+
|
61
57
|
def infer_const
|
62
58
|
return false if breakpoint
|
63
59
|
|
@@ -68,7 +64,7 @@ module TensorStream
|
|
68
64
|
true
|
69
65
|
when :placeholder
|
70
66
|
false
|
71
|
-
when :variable_v2
|
67
|
+
when :variable_v2, :assign, :assign_add, :assign_sub
|
72
68
|
false
|
73
69
|
else
|
74
70
|
non_const = @inputs.compact.find { |input| !input.is_const }
|
@@ -96,7 +92,7 @@ module TensorStream
|
|
96
92
|
options[:data_type]
|
97
93
|
when :fill
|
98
94
|
@inputs[1].data_type
|
99
|
-
when :
|
95
|
+
when :logical_and
|
100
96
|
:boolean
|
101
97
|
when :shape, :rank, :shape_n
|
102
98
|
options[:out_type] || :int32
|
@@ -119,15 +115,7 @@ module TensorStream
|
|
119
115
|
@inputs[0].data_type
|
120
116
|
end
|
121
117
|
else
|
122
|
-
|
123
|
-
|
124
|
-
if @inputs[0]
|
125
|
-
@inputs[0].data_type
|
126
|
-
elsif @inputs[1]
|
127
|
-
@inputs[1].data_type
|
128
|
-
else
|
129
|
-
:unknown
|
130
|
-
end
|
118
|
+
OpMaker.infer_data_type(self, self, passed_data_type)
|
131
119
|
end
|
132
120
|
end
|
133
121
|
|