tensor_stream 0.6.1 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +10 -0
- data/CHANGELOG.md +8 -0
- data/README.md +40 -1
- data/benchmark/benchmark.rb +4 -1
- data/lib/tensor_stream.rb +5 -0
- data/lib/tensor_stream/debugging/debugging.rb +4 -2
- data/lib/tensor_stream/device.rb +2 -1
- data/lib/tensor_stream/evaluator/base_evaluator.rb +43 -32
- data/lib/tensor_stream/evaluator/evaluator.rb +0 -1
- data/lib/tensor_stream/evaluator/opencl/kernels/acos.cl +8 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_gradient.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/asin.cl +9 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/floor_mod.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/log_softmax.cl +26 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/max.cl +5 -5
- data/lib/tensor_stream/evaluator/opencl/kernels/min.cl +46 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/real_div.cl +3 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +27 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross_grad.cl +28 -0
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +5 -6
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +200 -265
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -8
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +193 -122
- data/lib/tensor_stream/exceptions.rb +6 -0
- data/lib/tensor_stream/graph.rb +21 -6
- data/lib/tensor_stream/graph_builder.rb +67 -0
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +271 -0
- data/lib/tensor_stream/graph_keys.rb +1 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +11 -10
- data/lib/tensor_stream/helpers/op_helper.rb +7 -33
- data/lib/tensor_stream/helpers/string_helper.rb +16 -0
- data/lib/tensor_stream/math_gradients.rb +67 -44
- data/lib/tensor_stream/nn/nn_ops.rb +7 -1
- data/lib/tensor_stream/operation.rb +14 -27
- data/lib/tensor_stream/ops.rb +82 -29
- data/lib/tensor_stream/session.rb +4 -0
- data/lib/tensor_stream/tensor.rb +30 -12
- data/lib/tensor_stream/tensor_shape.rb +1 -1
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +37 -4
- data/lib/tensor_stream/train/saver.rb +46 -0
- data/lib/tensor_stream/train/utils.rb +37 -0
- data/lib/tensor_stream/trainer.rb +2 -0
- data/lib/tensor_stream/utils.rb +24 -14
- data/lib/tensor_stream/variable.rb +5 -11
- data/lib/tensor_stream/variable_scope.rb +15 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +8 -4
- data/samples/linear_regression.rb +1 -1
- data/samples/multigpu.rb +73 -0
- data/samples/nearest_neighbor.rb +3 -3
- data/tensor_stream.gemspec +1 -1
- data/test_samples/raw_neural_net_sample.rb +4 -1
- metadata +21 -6
data/lib/tensor_stream/session.rb
CHANGED
@@ -115,6 +115,10 @@ module TensorStream
     def delegate_to_evaluator(tensor_arr, session_context, context)
       arr = tensor_arr.is_a?(Array) ? tensor_arr : [tensor_arr]
       result = arr.collect do |tensor|
+        if session_context[:_cache][:placement][tensor.name].nil?
+          session_context[:_cache][:placement][tensor.name] = assign_evaluator(tensor)
+        end
+
         session_context[:_cache][:placement][tensor.name][1].run_with_buffer(tensor, session_context, context)
       end
       result.size == 1 ? result.first : result
data/lib/tensor_stream/tensor.rb
CHANGED
@@ -100,6 +100,10 @@ module TensorStream
       TensorStream.ceil(self)
     end
 
+    def zero?
+      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
+    end
+
     def ==(other)
       _a, other = TensorStream.check_data_types(self, other)
       _op(:equal, self, other)
@@ -137,12 +141,28 @@ module TensorStream
 
     def matmul(other)
       _a, other = TensorStream.check_data_types(self, other)
-      _op(:
+      _op(:mat_mul, self, other)
     end
 
     def dot(other)
       _a, other = TensorStream.check_data_types(self, other)
-      _op(:
+      _op(:mat_mul, self, other)
+    end
+
+    ##
+    # Apply a reduction to tensor
+    def reduce(op_type)
+      reduce_op = case op_type.to_sym
+                  when :+
+                    :sum
+                  when :*
+                    :prod
+                  else
+                    raise "unsupported reduce op type #{op_type}"
+                  end
+      raise "blocks are not supported for tensors" if block_given?
+
+      _op(reduce_op, self, nil)
     end
 
     def collect(&block)
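Usage sketch for the Tensor#zero? and Tensor#reduce helpers added above; the tensor values, variable names and the exact run results are illustrative, not taken from the package:

require 'tensor_stream'

ts = TensorStream
a  = ts.constant([1.0, 2.0, 3.0])

sum_op  = a.reduce(:+)   # builds a :sum op over all elements
prod_op = a.reduce(:*)   # builds a :prod op
is_zero = a.zero?        # element-wise equality against a constant 0

sess = ts.session
sess.run(sum_op, prod_op, is_zero)   # expected along the lines of [6.0, 6.0, [false, false, false]]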
@@ -154,7 +174,7 @@ module TensorStream
     end
 
     def op
-      is_const ? _op(:const, self, nil, name:
+      is_const ? _op(:const, self, nil, name: name) : _op(:variable, self, nil, name: name)
     end
 
     def eval(options = {})
@@ -197,12 +217,12 @@ module TensorStream
       end
     end
 
-    def auto_math(tensor, name_only = false, max_depth = 99,
-      tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth,
+    def auto_math(tensor, name_only = false, max_depth = 99, cur_depth = 0)
+      tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth, cur_depth) : tensor
     end
 
     def self.detect_type(value)
-      if !!value==value
+      if !!value == value
        :boolean
       elsif value.is_a?(String)
        :string
@@ -211,7 +231,7 @@ module TensorStream
       elsif value.is_a?(Integer)
        :int32
       elsif value.is_a?(Array)
-
+        detect_type(value[0])
       elsif value.is_a?(Tensor)
        value.data_type
       else
@@ -229,9 +249,7 @@ module TensorStream
        end
       end
 
-      if dtype.is_a?(Hash)
-        dtype = dtype[:dtype]
-      end
+      dtype = dtype[:dtype] if dtype.is_a?(Hash)
 
       case dtype.to_sym
       when :float64, :float32, :float
@@ -257,7 +275,7 @@ module TensorStream
       end
     end
 
-    def breakpoint!(&
+    def breakpoint!(&_block)
       self
     end
 
@@ -269,7 +287,7 @@ module TensorStream
 
     def setup_initial_state(options)
       @outputs = []
-      @graph = options[:
+      @graph = options[:__graph] || TensorStream.get_default_graph
       @source = format_source(caller_locations)
     end
 
data/lib/tensor_stream/tensor_shape.rb
CHANGED
@@ -71,7 +71,7 @@ module TensorStream
       return shape if shape.empty?
 
       current_size = shape.inject(1) { |product, n| n > 0 ? product * n : product }
-      inferred_size = total_size / current_size
+      inferred_size = total_size.nil? ? nil : total_size / current_size
       shape.map { |s| s == -1 ? inferred_size : s }
     end
   end
data/lib/tensor_stream/train/gradient_descent_optimizer.rb
CHANGED
@@ -2,17 +2,50 @@ module TensorStream
   module Train
     # High Level implementation of the gradient descent algorithm
     class GradientDescentOptimizer
+      include TensorStream::OpHelper
+
       attr_accessor :learning_rate
 
       def initialize(learning_rate, _options = {})
         @learning_rate = learning_rate
       end
 
-      def minimize(
-
-
+      def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil)
+        grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)
+        apply_gradients(grads_and_vars, global_step: global_step)
+      end
+
+      ##
+      # Apply gradients to variables.
+      # This is the second part of minimize(). It returns an Operation that applies gradients.
+      def apply_gradients(grads_and_vars, global_step: nil)
+        apply_ops = grads_and_vars.map do |grad, var|
+          i_op(:apply_gradient_descent, var, TensorStream.cast(@learning_rate, grad.data_type), grad)
+        end
+
+        if global_step.nil?
+          apply_ops
+        else
+          apply_ops + [global_step.assign_add(1)]
+        end
+      end
+
+      ##
+      # Compute gradients of loss for the variables in var_list.
+      #
+      # This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable".
+      def compute_gradients(loss, var_list: nil, grad_loss: nil)
+        trainable_vars = if var_list
+                           raise "var_list must be an array" unless var_list.is_a?(Array)
+                           var_list.each_with_index { |var, index| raise "var #{index} not a Variable" unless var.is_a?(Variable) }
+
+                           var_list
+                         else
+                           loss.graph.get_collection(TensorStream::GraphKeys::TRAINABLE_VARIABLES)
+                         end
+        all_grads = grad_loss || TensorStream.gradients(loss, trainable_vars)
         trainable_vars.each_with_index.collect do |var, index|
-
+          [all_grads[index], var]
         end
       end
     end
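A minimal training sketch against the reworked optimizer API above, where minimize is now composed of compute_gradients and apply_gradients; the toy model and feed values are invented for illustration:

require 'tensor_stream'

ts = TensorStream
x = ts.placeholder(:float32)
y = ts.placeholder(:float32)
w = ts.variable(0.1, name: 'weight')

diff = w * x - y
loss = ts.reduce_sum(diff * diff)

optimizer = TensorStream::Train::GradientDescentOptimizer.new(0.01)

# One-shot form: returns the apply ops (plus a global_step increment when one is passed).
train_op = optimizer.minimize(loss)

# Equivalent two-phase form exposed by this release:
grads_and_vars = optimizer.compute_gradients(loss)   # [[gradient, variable], ...]
train_op_alt   = optimizer.apply_gradients(grads_and_vars)

sess = ts.session
sess.run(TensorStream::Variable.global_variables_initializer)
sess.run(train_op, feed_dict: { x => [1.0, 2.0, 3.0], y => [2.0, 4.0, 6.0] })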
data/lib/tensor_stream/train/saver.rb
CHANGED
@@ -4,6 +4,8 @@ module TensorStream
   module Train
     # High level class used for loading and saving variables
     class Saver
+      include TensorStream::OpHelper
+
       def save(session, outputfile, global_step: nil,
                latest_filename: nil,
                meta_graph_suffix: 'meta',
@@ -45,6 +47,50 @@ module TensorStream
 
       private
 
+      def build_internal(names_to_saveables, reshape: false, sharded: false, max_to_keep: 5,
+                         keep_checkpoint_every_n_hours: 10000.0,
+                         name: nil,
+                         restore_sequentially: false,
+                         filename: "model",
+                         build_save: true,
+                         build_restore: true)
+        saveables = _validate_and_slice_inputs(names_to_saveables)
+
+      end
+
+      def _validate_and_slice_inputs(names_to_saveables)
+        saveables = []
+        seen_ops = []
+
+        names_to_saveables.values.sort_by { |item| item[0] }.each do |name, op|
+          _saveable_objects_for_op(op, name).each do |converted_saveable_object|
+            _add_saveable(saveables, seen_ops, converted_saveable_object)
+          end
+        end
+        saveables
+      end
+
+      def _add_saveable(saveables, seen_ops, saveable)
+        raise TensorStreamm::ValueError, "The same saveable will be restored with two names: #{saveable.name}" if seen_ops.include?(saveable.op)
+        saveables << saveable
+        seen_ops << saveable.op
+      end
+
+      def save_op(filename_tensor, saveables)
+        tensor_names = []
+        tensors = []
+        tensor_slices = []
+
+        saveables.each do |saveable|
+          saveable.specs.each do |spec|
+            tensor_names << spec.name
+            tensors << spec.tensor
+            tensor_slices << spec.slice_spec
+          end
+        end
+        i_op(:save_v2, filename_tensor, tensor_names, tensor_slices, tensors)
+      end
+
       def eval_global_step(session, global_step)
         return nil if global_step.nil?
 
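A hedged sketch of driving the Saver from user code; only the save signature appears in the hunk above, so the constructor call and the surrounding setup are assumed context rather than documented API:

ts = TensorStream
v = ts.variable(5.0, name: 'v1')

sess = ts.session
sess.run(TensorStream::Variable.global_variables_initializer)

saver = TensorStream::Train::Saver.new
saver.save(sess, 'model.ckpt')                     # persist current variable values
saver.save(sess, 'model.ckpt', global_step: 100)   # checkpoint tagged with a step number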
data/lib/tensor_stream/train/utils.rb
ADDED
@@ -0,0 +1,37 @@
+module TensorStream
+  module Train
+    # convenience methods used for training
+    module Utils
+      def create_global_step(graph = nil)
+        target_graph = graph || TensorStream.get_default_graph
+        raise TensorStream::ValueError, '"global_step" already exists.' unless get_global_step(target_graph).nil?
+
+        TensorStream.variable_scope.get_variable(
+          TensorStream::GraphKeys::GLOBAL_STEP, shape: [],
+          dtype: :int64,
+          initializer: TensorStream.zeros_initializer,
+          trainable: false,
+          collections: [TensorStream::GraphKeys::GLOBAL_VARIABLES,
+                        TensorStream::GraphKeys::GLOBAL_STEP])
+      end
+
+      def get_global_step(graph = nil)
+        target_graph = graph || TensorStream.get_default_graph
+        global_step_tensors = target_graph.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)
+        global_step_tensor = if global_step_tensors.nil? || global_step_tensors.empty?
+                               begin
+                                 target_graph.get_tensor_by_name('global_step:0')
+                               rescue TensorStream::KeyError
+                                 nil
+                               end
+                             elsif global_step_tensors.size == 1
+                               global_step_tensors[0]
+                             else
+                               TensorStream.logger.error("Multiple tensors in global_step collection.")
+                               nil
+                             end
+        global_step_tensor
+      end
+    end
+  end
+end
data/lib/tensor_stream/trainer.rb
CHANGED
@@ -3,6 +3,8 @@ require 'tensor_stream/train/saver'
 
 module TensorStream
   module Trainer
+    extend TensorStream::Train::Utils
+
    def self.write_graph(graph, path, filename, as_text: true, serializer: TensorStream::Pbtext)
      raise "only supports as_text=true for now" unless as_text
      new_filename = File.join(path, filename)
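Because Trainer now extends Train::Utils, the global-step helpers become reachable at module level. A sketch under that assumption (the toy loss is made up):

ts = TensorStream
w = ts.variable(1.0, name: 'w')
loss = (w - 5.0) * (w - 5.0)

global_step = TensorStream::Trainer.create_global_step   # :int64 variable in the GLOBAL_STEP collection
optimizer   = TensorStream::Train::GradientDescentOptimizer.new(0.1)
train_op    = optimizer.minimize(loss, global_step: global_step)

# Each run of train_op also runs global_step.assign_add(1);
# get_global_step finds the same variable again from the graph.
TensorStream::Trainer.get_global_step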
data/lib/tensor_stream/utils.rb
CHANGED
@@ -41,6 +41,9 @@ module TensorStream
     end.flatten
   end
 
+  ##
+  # Creates a variable
+  # A variable maintains state across sessions
   def variable(value, name: nil, initializer: nil, graph: nil, dtype: nil, trainable: true)
     op = Operation.new(:assign, nil, value)
     common_options = {
@@ -51,13 +54,13 @@ module TensorStream
       trainable: trainable
     }
     tensor = if value.is_a?(String)
-               TensorStream::Variable.new(dtype || :string, 0, [], common_options)
+               TensorStream::Variable.new(dtype || :string, 0, [], get_variable_scope, common_options)
             elsif value.is_a?(Integer)
-               TensorStream::Variable.new(dtype || :int32, 0, [], common_options)
+               TensorStream::Variable.new(dtype || :int32, 0, [], get_variable_scope, common_options)
             elsif value.is_a?(Float)
-               TensorStream::Variable.new(dtype || :float32, 0, [], common_options)
+               TensorStream::Variable.new(dtype || :float32, 0, [], get_variable_scope, common_options)
             else
-               TensorStream::Variable.new(dtype || :float32, 0, nil, common_options)
+               TensorStream::Variable.new(dtype || :float32, 0, nil, get_variable_scope, common_options)
     end
     op.inputs[0] = tensor
     tensor
@@ -65,16 +68,19 @@ module TensorStream
 
   def variable_scope(scope = nil, reuse: nil, initializer: nil)
     Thread.current[:tensor_stream_variable_scope] ||= []
-
+    variable_scope = VariableScope.new(name: scope, reuse: reuse, initializer: initializer)
+    Thread.current[:tensor_stream_variable_scope] << variable_scope
     scope_name = __v_scope_name
-
-
+    if block_given?
+      begin
         TensorStream.get_default_graph.name_scope(scope) do
           yield(scope_name)
         end
+      ensure
+        Thread.current[:tensor_stream_variable_scope].pop
       end
-
-
+    else
+      variable_scope
     end
   end
 
@@ -94,8 +100,8 @@ module TensorStream
   end
 
   def get_variable_scope
-    return
-
+    return VariableScope.new unless Thread.current[:tensor_stream_variable_scope]
+    Thread.current[:tensor_stream_variable_scope].last || VariableScope.new
   end
 
   def __v_scope_name
@@ -125,6 +131,8 @@ module TensorStream
       TensorStream::Tensor.new(dtype || :int32, 0, shape || [], shared_options)
     elsif value.is_a?(String)
       TensorStream::Tensor.new(dtype || :string, 0, shape || [], shared_options)
+    elsif !!value == value
+      TensorStream::Tensor.new(dtype || :boolean, 0, shape || [], shared_options)
     elsif value.is_a?(Array)
       dimension = shape || shape_eval(value)
       rank = dimension.size
@@ -146,7 +154,7 @@ module TensorStream
   end
 
   def get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
-
+    get_variable_scope.get_variable(name, dtype: dtype, shape: shape, initializer: initializer, trainable: trainable, collections: collections)
   end
 
   def get_collection(name, options = {})
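Sketch of the scope-aware variable workflow enabled by the variable_scope and get_variable changes above; the names, shapes and initializer choice are illustrative:

ts = TensorStream

ts.variable_scope('layer1', initializer: ts.zeros_initializer) do
  w = ts.get_variable('weights', shape: [2, 2])   # resolved through the active VariableScope
  b = ts.get_variable('bias', shape: [2])
  # Scope names are prefixed onto variable names, e.g. "layer1/weights".
end

# Without a block, variable_scope returns the VariableScope it pushed,
# and get_variable_scope falls back to a fresh VariableScope when none is active.
current_scope = ts.get_variable_scope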
@@ -158,6 +166,8 @@ module TensorStream
     ref.assign(value, name: name)
   end
 
+  ##
+  # Inserts a placeholder for a tensor that will be always fed.
   def placeholder(dtype, shape: nil, name: nil)
     TensorStream::Placeholder.new(dtype, nil, shape, name: name)
   end
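Small sketch covering the new boolean branch in constant creation and the now-documented placeholder helper; the values fed here are illustrative:

ts = TensorStream

flag = ts.constant(true)          # boolean values now map to dtype :boolean
x    = ts.placeholder(:float32)   # a placeholder must always be fed at run time

sess = ts.session
sess.run(flag)                                # => true
sess.run(x * 2.0, feed_dict: { x => 3.0 })    # expected to be 6.0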
@@ -208,9 +218,9 @@ module TensorStream
       input_a = convert_to_tensor(input_a)
       input_b = convert_to_tensor(input_b)
     end
-
+
     if norm_dtype(input_a.data_type) != norm_dtype(input_b.data_type)
-      raise "Value Error: Tensor conversion requested dtype #{input_a.data_type} for tensor type #{input_b.data_type}"
+      raise TensorStream::ValueError, "Value Error: Tensor conversion requested dtype #{input_a.data_type} for tensor type #{input_b.data_type}"
     end
 
     [input_a, input_b]
data/lib/tensor_stream/variable.rb
CHANGED
@@ -2,7 +2,7 @@ module TensorStream
   # Class that defines a TensorStream variable
   class Variable < Tensor
     attr_accessor :trainable, :options, :buffer
-    def initialize(data_type, rank, shape, options = {})
+    def initialize(data_type, rank, shape, variable_scope, options = {})
       setup_initial_state(options)
 
       @options = {
@@ -11,8 +11,10 @@ module TensorStream
       @rank = rank
       @value = nil
       @is_const = false
-
-
+      scope_name = variable_scope ? variable_scope.name : nil
+      variable_scope_initializer = variable_scope ? variable_scope.initializer : nil
+      @name = [scope_name, options[:name] || build_name].compact.reject(&:empty?).join('/')
+      @initalizer_tensor = options[:initializer] ? options[:initializer] : variable_scope_initializer || TensorStream.glorot_uniform_initializer
       if shape.nil? && @initalizer_tensor && @initalizer_tensor.shape
         shape = @initalizer_tensor.shape.shape
       end
@@ -66,13 +68,5 @@ module TensorStream
     def self.global_variables_initializer
       variables_initializer(TensorStream::GraphKeys::GLOBAL_VARIABLES)
     end
-
-    private
-
-    def _variable_scope
-      return OpenStruct.new(name: '', reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
-      scope = Thread.current[:tensor_stream_variable_scope].last
-      scope
-    end
   end
 end
data/lib/tensor_stream/variable_scope.rb
ADDED
@@ -0,0 +1,15 @@
+module TensorStream
+  class VariableScope
+    attr_accessor :name, :reuse, :initializer
+
+    def initialize(name: '', reuse: nil, initializer: nil)
+      @name = name
+      @reuse = reuse
+      @initializer = initializer
+    end
+
+    def get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
+      TensorStream::Variable.new(dtype || :float32, nil, shape, self, collections: collections, name: name, initializer: initializer, trainable: trainable)
+    end
+  end
+end
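Since Variable#initialize now receives the owning VariableScope explicitly, the new class can also be used on its own; a minimal sketch, not an officially documented pattern:

scope = TensorStream::VariableScope.new(name: 'custom', initializer: TensorStream.zeros_initializer)
w = scope.get_variable('w', shape: [3])   # Variable named "custom/w", falling back to the scope's initializer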