tensor_stream 1.0.4 → 1.0.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +12 -2
- data/Dockerfile +1 -1
- data/USAGE_GUIDE.md +68 -0
- data/lib/tensor_stream.rb +1 -0
- data/lib/tensor_stream/evaluator/base_evaluator.rb +21 -1
- data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
- data/lib/tensor_stream/evaluator/evaluator_utils.rb +20 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +26 -0
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +60 -5
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +25 -29
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +7 -11
- data/lib/tensor_stream/evaluator/ruby/storage_manager.rb +40 -0
- data/lib/tensor_stream/evaluator/ruby/variable_ops.rb +74 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +31 -77
- data/lib/tensor_stream/generated_stub/ops.rb +256 -166
- data/lib/tensor_stream/generated_stub/stub_file.erb +4 -4
- data/lib/tensor_stream/graph.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +4 -6
- data/lib/tensor_stream/helpers/infer_shape.rb +1 -7
- data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
- data/lib/tensor_stream/images.rb +4 -0
- data/lib/tensor_stream/math/math_ops.rb +22 -0
- data/lib/tensor_stream/math_gradients.rb +15 -1
- data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
- data/lib/tensor_stream/nn/nn_ops.rb +16 -0
- data/lib/tensor_stream/op_maker.rb +36 -3
- data/lib/tensor_stream/operation.rb +8 -20
- data/lib/tensor_stream/ops.rb +14 -11
- data/lib/tensor_stream/ops/bias_add.rb +16 -0
- data/lib/tensor_stream/ops/equal.rb +4 -0
- data/lib/tensor_stream/ops/greater.rb +4 -0
- data/lib/tensor_stream/ops/greater_equal.rb +4 -0
- data/lib/tensor_stream/ops/less.rb +19 -0
- data/lib/tensor_stream/ops/less_equal.rb +4 -0
- data/lib/tensor_stream/ops/not_equal.rb +19 -0
- data/lib/tensor_stream/ops/rsqrt.rb +11 -0
- data/lib/tensor_stream/ops/strided_slice.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +4 -2
- data/lib/tensor_stream/ops/top_k.rb +23 -0
- data/lib/tensor_stream/session.rb +6 -12
- data/lib/tensor_stream/tensor.rb +1 -0
- data/lib/tensor_stream/tensor_shape.rb +32 -1
- data/lib/tensor_stream/train/saver.rb +2 -3
- data/lib/tensor_stream/utils.rb +18 -13
- data/lib/tensor_stream/utils/freezer.rb +5 -1
- data/lib/tensor_stream/utils/py_ports.rb +11 -0
- data/lib/tensor_stream/variable.rb +9 -6
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/word_embeddings/word_embedding_1.rb +192 -0
- data/samples/word_embeddings/word_embedding_2.rb +203 -0
- data/tensor_stream.gemspec +7 -2
- metadata +67 -10
@@ -9,12 +9,12 @@ module TensorStream
|
|
9
9
|
<%end%> #
|
10
10
|
#<% if op.supports_broadcasting? %> This operation supports broadcasting
|
11
11
|
#<% end %>
|
12
|
-
#
|
13
|
-
<% op.parameters.each do |param| %> # +<%= param[:name] %>+:: <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
|
12
|
+
<% op.parameters.each do |param| %> # @param <%= param[:name] %> <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
|
14
13
|
<% end %> #
|
15
14
|
# Options:
|
16
|
-
<% op.options.each do |k, v| %> #
|
17
|
-
<%end%>
|
15
|
+
<% op.options.each do |k, v| %> # @option <%= k %> <%= v[:description]%><% if v[:default_value] != :nil %> default (<%= v[:default_value] %>)<%end%>
|
16
|
+
<%end%> # @return Tensor
|
17
|
+
def <%= op.operation.to_s %>(<%= (op.expand_params(true) + op.expand_options(true)).join(', ') %>)
|
18
18
|
<%= op.generate_body %>
|
19
19
|
end
|
20
20
|
<% op.aliases.each do |a|%>
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -30,6 +30,7 @@ module TensorStream
|
|
30
30
|
:"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
|
31
31
|
}
|
32
32
|
@constants = {}
|
33
|
+
TensorStream::Evaluator.clear_storages(self)
|
33
34
|
end
|
34
35
|
|
35
36
|
def as_default
|
@@ -129,7 +130,7 @@ module TensorStream
|
|
129
130
|
|
130
131
|
def add_op(operation, *args)
|
131
132
|
options = if args.last.is_a?(Hash)
|
132
|
-
args.pop
|
133
|
+
args.pop || {}
|
133
134
|
else
|
134
135
|
{}
|
135
136
|
end
|
@@ -180,8 +181,7 @@ module TensorStream
|
|
180
181
|
|
181
182
|
def add_variable!(node, options = {})
|
182
183
|
node = add_variable(node, options)
|
183
|
-
op = Graph.get_default_graph.add_op!(:variable_v2,
|
184
|
-
node.name = op.name
|
184
|
+
op = Graph.get_default_graph.add_op!(:variable_v2, var_name: node.name, shape: options[:shape], data_type: options[:data_type])
|
185
185
|
op
|
186
186
|
end
|
187
187
|
|
@@ -31,15 +31,13 @@ module TensorStream
|
|
31
31
|
options = {}
|
32
32
|
|
33
33
|
new_var = nil
|
34
|
-
if op_def.
|
34
|
+
if op_def[:op].to_sym == :variable_v2
|
35
35
|
new_var = Variable.new(op_def.dig(:attrs, :data_type))
|
36
|
-
var_shape = op_def.dig(:attrs, :container, :shape)
|
37
|
-
var_options = op_def.dig(:attrs, :container, :options)
|
38
|
-
var_options[:name] = op_def[:name]
|
39
36
|
|
40
|
-
|
41
|
-
|
37
|
+
var_options = {}
|
38
|
+
var_options[:name] = op_def.dig(:attrs, :var_name)
|
42
39
|
|
40
|
+
new_var.prepare(nil, nil, TensorStream.get_variable_scope, var_options)
|
43
41
|
@graph.add_variable(new_var, var_options)
|
44
42
|
end
|
45
43
|
|
@@ -10,13 +10,7 @@ module TensorStream
|
|
10
10
|
def self.infer_shape(tensor)
|
11
11
|
case tensor.operation
|
12
12
|
when :assign
|
13
|
-
|
14
|
-
tensor.inputs[0].shape.shape
|
15
|
-
else
|
16
|
-
tensor.inputs[1].shape.shape
|
17
|
-
end
|
18
|
-
|
19
|
-
possible_shape
|
13
|
+
tensor.inputs[0]&.shape&.shape
|
20
14
|
when :const
|
21
15
|
shape_eval(tensor.options[:value])
|
22
16
|
when :variable_v2
|
@@ -5,7 +5,16 @@ module TensorStream
|
|
5
5
|
end
|
6
6
|
|
7
7
|
def [](index)
|
8
|
-
|
8
|
+
if index.is_a?(Range)
|
9
|
+
last = if index.end.nil?
|
10
|
+
[TensorStream.shape(self)[0]]
|
11
|
+
else
|
12
|
+
[index.max + 1]
|
13
|
+
end
|
14
|
+
_op(:strided_slice, self, [index.min], last, [1])
|
15
|
+
else
|
16
|
+
_op(:index, self, index)
|
17
|
+
end
|
9
18
|
end
|
10
19
|
|
11
20
|
def *(other)
|
data/lib/tensor_stream/images.rb
CHANGED
@@ -7,6 +7,10 @@ module TensorStream
|
|
7
7
|
_op(:decode_png, contents, channels: channels, data_type: dtype, name: name, new_shape: new_shape)
|
8
8
|
end
|
9
9
|
|
10
|
+
def self.decode_jpeg(contents, channels: 0, dtype: :uint8, name: nil, new_shape: nil)
|
11
|
+
_op(:decode_jpg, contents, channels: channels, data_type: dtype, name: name, new_shape: new_shape)
|
12
|
+
end
|
13
|
+
|
10
14
|
def self.encode_png(contents, compression: -1, name: nil, new_shape: nil, resample_method: nil)
|
11
15
|
check_allowed_types(contents, %i[uint8 uint16])
|
12
16
|
contents = convert_to_tensor(contents, dtype: :uint16)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module TensorStream
|
2
|
+
# High level math functions
|
3
|
+
class Maths
|
4
|
+
extend TensorStream::OpHelper
|
5
|
+
|
6
|
+
module MathFunctions
|
7
|
+
|
8
|
+
##
|
9
|
+
# Normalizes along dimension axis using an L2 norm.
|
10
|
+
def l2_normalize(x, axis: nil, epsilon: 1e-12, name: nil)
|
11
|
+
TensorStream.name_scope(name, "l2_normalize", values: [x]) do |name|
|
12
|
+
x = TensorStream.convert_to_tensor(x, name: "x")
|
13
|
+
square_sum = TensorStream.reduce_sum(TensorStream.square(x), axis, keepdims: true)
|
14
|
+
x_inv_norm = TensorStream.rsqrt(TensorStream.maximum(square_sum, epsilon))
|
15
|
+
TensorStream.multiply(x, x_inv_norm, name: name)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
|
20
|
+
extend MathFunctions
|
21
|
+
end
|
22
|
+
end
|
@@ -136,7 +136,7 @@ module TensorStream
|
|
136
136
|
when :sparse_softmax_cross_entropy_with_logits
|
137
137
|
output = node
|
138
138
|
[_broadcast_mul(grad, output[1]), nil]
|
139
|
-
|
139
|
+
when :zeros_like
|
140
140
|
# non differentiable
|
141
141
|
nil
|
142
142
|
when :transpose
|
@@ -165,12 +165,26 @@ module TensorStream
|
|
165
165
|
ts.stack(grad, axis: node.options[:axis])
|
166
166
|
when :conv2d
|
167
167
|
_Conv2DGrad(node, grad)
|
168
|
+
when :flow_dynamic_stitch
|
169
|
+
num_values = node.inputs.size / 2
|
170
|
+
indices_grad = [nil] * num_values
|
171
|
+
|
172
|
+
inputs = (0...num_values).map { |i| _int32(node, node.inputs[i]) }
|
173
|
+
|
174
|
+
values_grad = inputs.map { |inp| TensorStream.gather(grad, inp) }
|
175
|
+
indices_grad + values_grad
|
176
|
+
when :gather
|
177
|
+
[_op(:gather_grad, grad, node.inputs[1], TensorStream.shape(node.inputs[0])), nil]
|
168
178
|
else
|
169
179
|
TensorStream::OpMaker.gradient_op(self, node, grad)
|
170
180
|
end
|
171
181
|
end
|
172
182
|
end
|
173
183
|
|
184
|
+
def self._int32(node, x)
|
185
|
+
(node.inputs[0].data_type == :int32 ? x : TensorStream.cast(x, :int32))
|
186
|
+
end
|
187
|
+
|
174
188
|
def self._reshape_to_input(node, grad)
|
175
189
|
ts.reshape(grad, ts.shape(node.inputs[0]))
|
176
190
|
end
|
@@ -0,0 +1,114 @@
|
|
1
|
+
require 'tensor_stream/utils/py_ports'
|
2
|
+
##
|
3
|
+
# ruby port of https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/embedding_ops.py
|
4
|
+
#
|
5
|
+
module TensorStream
|
6
|
+
module EmbeddingLookup
|
7
|
+
include TensorStream::PyPorts
|
8
|
+
|
9
|
+
##
|
10
|
+
# Looks up `ids` in a list of embedding tensors.
|
11
|
+
def embedding_lookup(params, ids, partition_strategy: "mod", name: nil, validate_indices: true, max_norm: nil)
|
12
|
+
_embedding_lookup_and_transform(params, ids, partition_strategy: partition_strategy, name: name, max_norm: max_norm, transform_fn: nil)
|
13
|
+
end
|
14
|
+
|
15
|
+
##
|
16
|
+
# Helper function for embedding_lookup and _compute_sampled_logits.
|
17
|
+
def _embedding_lookup_and_transform(params, ids, partition_strategy: "mod", name: nil, max_norm: nil, transform_fn: nil)
|
18
|
+
raise TensorStream::ValueError, "Need at least one param" if params.nil?
|
19
|
+
|
20
|
+
params = [params] unless params.is_a?(Array)
|
21
|
+
|
22
|
+
TensorStream.name_scope(name, "embedding_lookup", values: params + [ids]) do |name|
|
23
|
+
np = params.size
|
24
|
+
ids = TensorStream.convert_to_tensor(ids, name: "ids")
|
25
|
+
if (np == 1) && (transform_fn.nil? || (ids.shape.size == 1))
|
26
|
+
result = nil
|
27
|
+
TensorStream.colocate_with(params[0]) do
|
28
|
+
result = _clip(TensorStream.gather(params[0], ids, name: name), ids, max_norm)
|
29
|
+
result = transform_fn.call(result) if transform_fn
|
30
|
+
end
|
31
|
+
|
32
|
+
return TensorStream.identity(result)
|
33
|
+
else
|
34
|
+
flat_ids = TensorStream.reshape(ids, [-1])
|
35
|
+
original_indices = TensorStream.range(TensorStream.size(flat_ids))
|
36
|
+
|
37
|
+
p_assignments = nil
|
38
|
+
new_ids = nil
|
39
|
+
|
40
|
+
if partition_strategy == "mod"
|
41
|
+
p_assignments = flat_ids % np
|
42
|
+
new_ids = floor_div(flat_ids, np)
|
43
|
+
elsif partition_strategy == "div"
|
44
|
+
raise "not yet supported!"
|
45
|
+
else
|
46
|
+
raise TensorStream::ValueError, "Unrecognized partition strategy: " + partition_strategy
|
47
|
+
end
|
48
|
+
|
49
|
+
p_assignments = TensorStream.cast(p_assignments, :int32)
|
50
|
+
gather_ids = TensorStream.dynamic_partition(new_ids, p_assignments, np)
|
51
|
+
pindices = TensorStream.dynamic_partition(original_indices, p_assignments, np)
|
52
|
+
partitioned_result = []
|
53
|
+
(0...np).each do |p|
|
54
|
+
pids = gather_ids[p]
|
55
|
+
result = nil
|
56
|
+
TensorStream.colocate_with(params[p]) do
|
57
|
+
result = TensorStream.gather(params[p], pids)
|
58
|
+
if transform_fn
|
59
|
+
# If transform_fn is provided, the clip_by_norm precedes
|
60
|
+
# the transform and hence must be co-located. See below
|
61
|
+
# for the counterpart if transform_fn is not proveded.
|
62
|
+
result = transform_fn.call(_clip(result, pids, max_norm))
|
63
|
+
end
|
64
|
+
end
|
65
|
+
partitioned_result << result
|
66
|
+
end
|
67
|
+
ret = TensorStream.dynamic_stitch(pindices, partitioned_result, name: name)
|
68
|
+
|
69
|
+
if transform_fn.nil?
|
70
|
+
element_shape_s = params[0].shape[1..-1]
|
71
|
+
params[1..-1].each { |p| element_shape_s = element_shape_s.merge_with(p.shape[1..-1]) }
|
72
|
+
else
|
73
|
+
element_shape_s = ret.shape[1..-1]
|
74
|
+
end
|
75
|
+
|
76
|
+
# Compute the dynamic element shape.
|
77
|
+
element_shape_d = if element_shape_s.fully_defined?
|
78
|
+
element_shape_s
|
79
|
+
elsif transform_fn.nil?
|
80
|
+
# It's important that we compute params[0].shape on the right device
|
81
|
+
# to avoid data motion.
|
82
|
+
TensorStream.colocate_with(params[0]) do
|
83
|
+
params_shape = TensorStream.shape(params[0])
|
84
|
+
params_shape[1..-1]
|
85
|
+
end
|
86
|
+
else
|
87
|
+
TensorStream.shape(ret)[1..-1]
|
88
|
+
end
|
89
|
+
ret = TensorStream.reshape(ret, TensorStream.concat([TensorStream.shape(ids), element_shape_d], 0))
|
90
|
+
ret = _clip(ret, ids, max_norm) unless transform_fn
|
91
|
+
ret
|
92
|
+
end
|
93
|
+
end
|
94
|
+
end
|
95
|
+
|
96
|
+
def _clip(params, ids, max_norm)
|
97
|
+
return params if max_norm.nil?
|
98
|
+
|
99
|
+
ids_rank, ids_static = _rank(ids)
|
100
|
+
params_rank, params_static = _rank(params)
|
101
|
+
|
102
|
+
TensorStream.clip_by_norm(params, max_norm, axes: ids_static && params_static ? (ids_rank...params_rank).to_a : TensorStream.range(ids_rank, params_rank))
|
103
|
+
end
|
104
|
+
|
105
|
+
def _rank(x)
|
106
|
+
rank = TensorStream.convert_to_tensor(x).shape.ndims
|
107
|
+
if rank
|
108
|
+
[rank, false]
|
109
|
+
else
|
110
|
+
[TensorStream.rank(x), false]
|
111
|
+
end
|
112
|
+
end
|
113
|
+
end
|
114
|
+
end
|
@@ -1,7 +1,10 @@
|
|
1
|
+
require 'tensor_stream/nn/embedding_lookup'
|
1
2
|
module TensorStream
|
2
3
|
# High level machine learning functions
|
3
4
|
class NN
|
4
5
|
extend TensorStream::OpHelper
|
6
|
+
extend TensorStream::EmbeddingLookup
|
7
|
+
extend TensorStream::Maths::MathFunctions
|
5
8
|
|
6
9
|
class << self
|
7
10
|
def softmax(logits, axis: nil, name: nil)
|
@@ -137,6 +140,19 @@ module TensorStream
|
|
137
140
|
def conv2d(input, filter, strides, padding, name: nil)
|
138
141
|
_op(:conv2d, input, filter, strides: strides, padding: padding, name: name)
|
139
142
|
end
|
143
|
+
|
144
|
+
##
|
145
|
+
# Adds bias to value.
|
146
|
+
#
|
147
|
+
# This is a narrow version of tf add where the bias is restructed to 1-D only
|
148
|
+
def bias_add(value, bias, data_format: nil, name: nil)
|
149
|
+
value = TensorStream.convert_to_tensor(value, name: "input")
|
150
|
+
bias = TensorStream.convert_to_tensor(bias, dtype: value.dtype, name: "bias")
|
151
|
+
|
152
|
+
raise TensorStreamError, "value must be at least rank 2" if value.shape.known? && value.shape.ndims < 2
|
153
|
+
|
154
|
+
_op(:bias_add, value, bias, data_format: data_format, name: name)
|
155
|
+
end
|
140
156
|
end
|
141
157
|
end
|
142
158
|
|
@@ -2,7 +2,8 @@ class TensorStream::OpMaker
|
|
2
2
|
attr_reader :operation, :description, :parameters,
|
3
3
|
:options, :gradient, :check_types,
|
4
4
|
:supports_broadcast, :data_type_coercion,
|
5
|
-
:aliases, :custom, :infer_type_proc, :exclude
|
5
|
+
:aliases, :custom, :custom_post, :infer_type_proc, :exclude,
|
6
|
+
:data_type_block
|
6
7
|
|
7
8
|
def initialize(op)
|
8
9
|
@operation = op
|
@@ -15,6 +16,7 @@ class TensorStream::OpMaker
|
|
15
16
|
@description = []
|
16
17
|
@aliases = []
|
17
18
|
@custom = []
|
19
|
+
@custom_post = []
|
18
20
|
@infer_type_proc = lambda { |tensor|
|
19
21
|
next nil if tensor.inputs[0].nil?
|
20
22
|
next tensor.inputs[0].shape.shape if tensor.inputs.size == 1
|
@@ -31,6 +33,10 @@ class TensorStream::OpMaker
|
|
31
33
|
@custom << custom_code
|
32
34
|
end
|
33
35
|
|
36
|
+
def add_custom_post(custom_code)
|
37
|
+
@custom_post << custom_code
|
38
|
+
end
|
39
|
+
|
34
40
|
def self.scan
|
35
41
|
op_files = Dir[File.join(File.dirname(__FILE__), "ops", "*.rb")]
|
36
42
|
op_files.each { |file|
|
@@ -58,6 +64,22 @@ class TensorStream::OpMaker
|
|
58
64
|
context_caller.instance_exec(tensor, &@ops[tensor.operation].infer_type_proc)
|
59
65
|
end
|
60
66
|
|
67
|
+
def self.infer_data_type(context_caller, tensor, passed_data_type)
|
68
|
+
return passed_data_type if passed_data_type
|
69
|
+
|
70
|
+
if @ops[tensor.operation] && @ops[tensor.operation].data_type_block
|
71
|
+
context_caller.instance_exec(tensor, &@ops[tensor.operation].data_type_block)
|
72
|
+
else
|
73
|
+
if tensor.inputs[0]
|
74
|
+
tensor.inputs[0].data_type
|
75
|
+
elsif tensor.inputs[1]
|
76
|
+
tensor.inputs[1].data_type
|
77
|
+
else
|
78
|
+
:unknown
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
|
61
83
|
def self.each_op(&block)
|
62
84
|
@ops.values.sort_by { |op| op.operation }.reject(&:exclude).each do |op|
|
63
85
|
block.call(op)
|
@@ -94,7 +116,14 @@ class TensorStream::OpMaker
|
|
94
116
|
custom.each do |c|
|
95
117
|
body << c
|
96
118
|
end
|
97
|
-
|
119
|
+
if custom_post.empty?
|
120
|
+
body << "_op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
|
121
|
+
else
|
122
|
+
body << "result = _op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
|
123
|
+
end
|
124
|
+
custom_post.each do |c|
|
125
|
+
body << c
|
126
|
+
end
|
98
127
|
body.map { |line| " #{line}"}.join("\n")
|
99
128
|
end
|
100
129
|
|
@@ -122,6 +151,10 @@ class TensorStream::OpMaker
|
|
122
151
|
@infer_type_proc = block
|
123
152
|
end
|
124
153
|
|
154
|
+
def define_data_type(&block)
|
155
|
+
@data_type_block = block
|
156
|
+
end
|
157
|
+
|
125
158
|
def expand_params(print_defaults)
|
126
159
|
@parameters.map { |param|
|
127
160
|
print_defaults && param[:default_value] ? "#{param[:name]} = #{default_with_nil(param[:default_value])}" : "#{param[:name]}"
|
@@ -163,7 +196,7 @@ class TensorStream::OpMaker
|
|
163
196
|
end
|
164
197
|
|
165
198
|
def options_call
|
166
|
-
@options.map { |k, v|
|
199
|
+
@options.reject { |k, v| v.dig(:options, :exclude) }.map { |k, v|
|
167
200
|
if v.dig(:options, :alias)
|
168
201
|
"#{v.dig(:options, :alias)}: #{k}"
|
169
202
|
else
|
@@ -7,7 +7,7 @@ module TensorStream
|
|
7
7
|
attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
|
8
8
|
attr_reader :outputs, :options, :is_const, :data_type, :shape
|
9
9
|
|
10
|
-
def initialize(graph, inputs
|
10
|
+
def initialize(graph, inputs: [], options: {})
|
11
11
|
@consumers = Set.new
|
12
12
|
@outputs = []
|
13
13
|
@op = self
|
@@ -42,14 +42,6 @@ module TensorStream
|
|
42
42
|
@options[:container] ? @options[:container].buffer : nil
|
43
43
|
end
|
44
44
|
|
45
|
-
def container
|
46
|
-
@options[:container].read_value
|
47
|
-
end
|
48
|
-
|
49
|
-
def container=(value)
|
50
|
-
@options[:container].value = value
|
51
|
-
end
|
52
|
-
|
53
45
|
def set_input(index, value)
|
54
46
|
@inputs[index] = value
|
55
47
|
@shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
|
@@ -58,6 +50,10 @@ module TensorStream
|
|
58
50
|
@data_type = set_data_type(@options[:data_type])
|
59
51
|
end
|
60
52
|
|
53
|
+
def set_option(key, value)
|
54
|
+
@options.merge!(key.to_sym => value)
|
55
|
+
end
|
56
|
+
|
61
57
|
def infer_const
|
62
58
|
return false if breakpoint
|
63
59
|
|
@@ -68,7 +64,7 @@ module TensorStream
|
|
68
64
|
true
|
69
65
|
when :placeholder
|
70
66
|
false
|
71
|
-
when :variable_v2
|
67
|
+
when :variable_v2, :assign, :assign_add, :assign_sub
|
72
68
|
false
|
73
69
|
else
|
74
70
|
non_const = @inputs.compact.find { |input| !input.is_const }
|
@@ -96,7 +92,7 @@ module TensorStream
|
|
96
92
|
options[:data_type]
|
97
93
|
when :fill
|
98
94
|
@inputs[1].data_type
|
99
|
-
when :
|
95
|
+
when :logical_and
|
100
96
|
:boolean
|
101
97
|
when :shape, :rank, :shape_n
|
102
98
|
options[:out_type] || :int32
|
@@ -119,15 +115,7 @@ module TensorStream
|
|
119
115
|
@inputs[0].data_type
|
120
116
|
end
|
121
117
|
else
|
122
|
-
|
123
|
-
|
124
|
-
if @inputs[0]
|
125
|
-
@inputs[0].data_type
|
126
|
-
elsif @inputs[1]
|
127
|
-
@inputs[1].data_type
|
128
|
-
else
|
129
|
-
:unknown
|
130
|
-
end
|
118
|
+
OpMaker.infer_data_type(self, self, passed_data_type)
|
131
119
|
end
|
132
120
|
end
|
133
121
|
|