tensor_stream 0.1.5 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +5 -5
- data/CHANGELOG.md +13 -0
- data/README.md +34 -0
- data/lib/tensor_stream.rb +7 -3
- data/lib/tensor_stream/control_flow.rb +1 -2
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +44 -3
- data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +9 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +70 -36
- data/lib/tensor_stream/graph.rb +15 -7
- data/lib/tensor_stream/graph_serializers/graphml.rb +183 -35
- data/lib/tensor_stream/graph_serializers/pbtext.rb +81 -14
- data/lib/tensor_stream/graph_serializers/serializer.rb +13 -0
- data/lib/tensor_stream/helpers/string_helper.rb +12 -0
- data/lib/tensor_stream/math_gradients.rb +203 -161
- data/lib/tensor_stream/operation.rb +30 -16
- data/lib/tensor_stream/ops.rb +29 -19
- data/lib/tensor_stream/placeholder.rb +2 -3
- data/lib/tensor_stream/session.rb +7 -13
- data/lib/tensor_stream/tensor.rb +22 -5
- data/lib/tensor_stream/tensor_shape.rb +2 -0
- data/lib/tensor_stream/trainer.rb +6 -1
- data/lib/tensor_stream/variable.rb +4 -3
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/gradient_sample.graphml +1255 -0
- data/samples/linear_regression.rb +1 -1
- data/samples/logistic_regression.rb +9 -2
- data/tensor_stream.gemspec +1 -1
- data/test_samples/error.graphml +120 -0
- data/test_samples/gradient_sample.graphml +1255 -0
- data/{samples → test_samples}/iris.rb +0 -0
- data/{samples → test_samples}/raw_neural_net_sample.rb +0 -0
- data/{samples → test_samples}/test.py +2 -0
- data/test_samples/test2.py +41 -0
- metadata +41 -47
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
|
-
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: a5a6dce7a4317dee2e9fba536056d0d346ac9fd4a2763396a5735204f77b90c6
|
4
|
+
data.tar.gz: 4a6d2973badfa0f2ac20850f6885181c4f19b6fe00e416e9f8394f89ac267f77
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 285c9deb129680a9050a2afa9811de1d3cbcf2250cb78b4b6eab06f225bf8a5532d301a4cad7c0e9d39b7b16bcffb0527c7b6cadaad8b840a78cadc8fcf92c80
|
7
|
+
data.tar.gz: 2021fc72c95a8aad4e8b7f8a5b25fa7ce64e715ae08be9d3c325f058155f1a2ea9d58be38aae8f4849dabe067c75e140f174efe498503a86c7a73820dc367e5d
|
data/CHANGELOG.md
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
# Changelog
|
2
|
+
All notable changes to this project will be documented in this file.
|
3
|
+
|
4
|
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
|
5
|
+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
6
|
+
|
7
|
+
## [0.2.0] - 2018-05-27
|
8
|
+
### Added
|
9
|
+
- working logistic regression sample
|
10
|
+
- reworked auto differentiation, fix a number of bugs related to auto differentiation, smaller derivative programs
|
11
|
+
- alpha support for saving to pbtext format, added graphml generation
|
12
|
+
- significant number of ops added
|
13
|
+
- ops that support broadcasting now work better
|
data/README.md
CHANGED
@@ -102,6 +102,10 @@ You can take a look at spec/tensor_stream/operation_spec.rb for a list of suppor
|
|
102
102
|
sliver of what TensorFlow can do, so feel free to file a PR to add requested
|
103
103
|
ops and test cases.
|
104
104
|
|
105
|
+
Other working samples can also be seen under tensor_stream/samples.
|
106
|
+
|
107
|
+
Samples that are used for development and are still being made to work can be found under test_samples
|
108
|
+
|
105
109
|
## Python to Ruby guide
|
106
110
|
|
107
111
|
Not all ops are available. Available ops are defined in lib/tensor_stream/ops.rb, corresponding gradients are found at lib/tensor_stream/math_gradients.
|
@@ -172,6 +176,36 @@ f = tf.matmul(a, b).breakpoint! { |tensor, a, b, result_value| binding.pry }
|
|
172
176
|
tf.session.run(f)
|
173
177
|
```
|
174
178
|
|
179
|
+
# Visualization
|
180
|
+
|
181
|
+
tensorstream does not support tensorboard yet, but a graphml generator is included:
|
182
|
+
|
183
|
+
```ruby
|
184
|
+
tf = TensorStream
|
185
|
+
a = tf.constant(1.0)
|
186
|
+
b = tf.constant(2.0)
|
187
|
+
result = a + b
|
188
|
+
sess = tf.session
|
189
|
+
sess.run(result)
|
190
|
+
|
191
|
+
File.write('gradients.graphml', TensorStream::Graphml.new.get_string(result)) # dump graph only
|
192
|
+
File.write('gradients.graphml', TensorStream::Graphml.new.get_string(result, sess)) # dump with values from session
|
193
|
+
```
|
194
|
+
|
195
|
+
the resulting graphml is designed to work with yED, after loading the graph change layout to "Flowchart" for best results
|
196
|
+
|
197
|
+
## Exporting to TensorFlow
|
198
|
+
|
199
|
+
Still in alpha but tensorstream supports TensorFlows as_graph_def serialization method:
|
200
|
+
|
201
|
+
```ruby
|
202
|
+
tf = TensorStream
|
203
|
+
a = tf.constant(1.0)
|
204
|
+
b = tf.constant(2.0)
|
205
|
+
result = a + b
|
206
|
+
File.write("model.pbtext", result.graph.as_graph_def)
|
207
|
+
```
|
208
|
+
|
175
209
|
## Roadmap
|
176
210
|
|
177
211
|
- Docs
|
data/lib/tensor_stream.rb
CHANGED
@@ -3,6 +3,7 @@ require 'deep_merge'
|
|
3
3
|
require 'matrix'
|
4
4
|
require 'concurrent'
|
5
5
|
require 'tensor_stream/helpers/op_helper'
|
6
|
+
require 'tensor_stream/helpers/string_helper'
|
6
7
|
require 'tensor_stream/initializer'
|
7
8
|
require 'tensor_stream/graph_keys'
|
8
9
|
require 'tensor_stream/types'
|
@@ -18,8 +19,11 @@ require 'tensor_stream/control_flow'
|
|
18
19
|
require 'tensor_stream/trainer'
|
19
20
|
require 'tensor_stream/nn/nn_ops'
|
20
21
|
require 'tensor_stream/evaluator/evaluator'
|
22
|
+
require 'tensor_stream/graph_serializers/serializer'
|
21
23
|
require 'tensor_stream/graph_serializers/pbtext'
|
22
24
|
require 'tensor_stream/graph_serializers/graphml'
|
25
|
+
require 'tensor_stream/math_gradients'
|
26
|
+
|
23
27
|
# require 'tensor_stream/libraries/layers'
|
24
28
|
require 'tensor_stream/monkey_patches/integer'
|
25
29
|
require 'tensor_stream/ops'
|
@@ -150,8 +154,8 @@ module TensorStream
|
|
150
154
|
end
|
151
155
|
end
|
152
156
|
|
153
|
-
def self.group(inputs)
|
154
|
-
TensorStream::ControlFlow.new(:group, inputs)
|
157
|
+
def self.group(inputs, name: nil)
|
158
|
+
TensorStream::ControlFlow.new(:group, inputs, nil, name: name)
|
155
159
|
end
|
156
160
|
|
157
161
|
def self.get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
|
@@ -196,6 +200,6 @@ module TensorStream
|
|
196
200
|
return input unless input.is_a?(Tensor)
|
197
201
|
return input if input.data_type.nil?
|
198
202
|
|
199
|
-
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.
|
203
|
+
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.include?(input.data_type.to_sym)
|
200
204
|
end
|
201
205
|
end
|
@@ -4,13 +4,12 @@ module TensorStream
|
|
4
4
|
attr_accessor :ops
|
5
5
|
|
6
6
|
def initialize(flow_type, items, ops = nil, options = {})
|
7
|
-
|
7
|
+
setup_initial_state(options)
|
8
8
|
|
9
9
|
@operation = :"flow_#{flow_type}"
|
10
10
|
@items = items
|
11
11
|
@name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
|
12
12
|
@ops = ops
|
13
|
-
@source = format_source(caller_locations)
|
14
13
|
@shape = TensorShape.new([items.size])
|
15
14
|
@graph.add_node(self)
|
16
15
|
end
|
@@ -1,6 +1,27 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# varoius utility functions for array processing
|
3
3
|
module ArrayOpsHelper
|
4
|
+
def slice_tensor(input, start, size)
|
5
|
+
start_index = start.shift
|
6
|
+
dimen_size = start_index + size.shift
|
7
|
+
|
8
|
+
input[start_index...dimen_size].collect do |item|
|
9
|
+
if item.is_a?(Array)
|
10
|
+
slice_tensor(item, start.dup, size.dup)
|
11
|
+
else
|
12
|
+
item
|
13
|
+
end
|
14
|
+
end
|
15
|
+
end
|
16
|
+
|
17
|
+
def truncate(input, target_shape)
|
18
|
+
rank = get_rank(input)
|
19
|
+
return input if rank.zero?
|
20
|
+
|
21
|
+
start = Array.new(rank) { 0 }
|
22
|
+
slice_tensor(input, start, target_shape)
|
23
|
+
end
|
24
|
+
|
4
25
|
def broadcast(input_a, input_b)
|
5
26
|
sa = shape_eval(input_a)
|
6
27
|
sb = shape_eval(input_b)
|
@@ -24,7 +45,7 @@ module TensorStream
|
|
24
45
|
input_b = broadcast_dimensions(input_b, target_shape)
|
25
46
|
else
|
26
47
|
target_shape = shape_diff(sb, sa)
|
27
|
-
raise "Incompatible shapes for op #{shape_eval(input_a)} vs #{shape_eval(
|
48
|
+
raise "Incompatible shapes for op #{shape_eval(input_a)} vs #{shape_eval(input_b)}" if target_shape.nil?
|
28
49
|
|
29
50
|
input_a = broadcast_dimensions(input_a, target_shape)
|
30
51
|
end
|
@@ -52,7 +73,7 @@ module TensorStream
|
|
52
73
|
end
|
53
74
|
|
54
75
|
# handle 2 tensor math operations
|
55
|
-
def vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false)
|
76
|
+
def vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false, safe = true)
|
56
77
|
if get_rank(vector) < get_rank(vector2) # upgrade rank of A
|
57
78
|
duplicated = Array.new(vector2.size) do
|
58
79
|
vector
|
@@ -65,6 +86,10 @@ module TensorStream
|
|
65
86
|
vector.each_with_index.collect do |item, index|
|
66
87
|
next vector_op(item, vector2, op, switch) if item.is_a?(Array) && get_rank(vector) > get_rank(vector2)
|
67
88
|
|
89
|
+
if safe && vector2.is_a?(Array)
|
90
|
+
next nil if vector2.size != 1 && index >= vector2.size
|
91
|
+
end
|
92
|
+
|
68
93
|
z = if vector2.is_a?(Array)
|
69
94
|
if index < vector2.size
|
70
95
|
vector2[index]
|
@@ -81,7 +106,7 @@ module TensorStream
|
|
81
106
|
else
|
82
107
|
switch ? op.call(z, item) : op.call(item, z)
|
83
108
|
end
|
84
|
-
end
|
109
|
+
end.compact
|
85
110
|
end
|
86
111
|
|
87
112
|
def shape_diff(shape_a, shape_b)
|
@@ -96,5 +121,21 @@ module TensorStream
|
|
96
121
|
s - reversed_b[index]
|
97
122
|
end.reverse
|
98
123
|
end
|
124
|
+
|
125
|
+
def tile_arr(input, dimen, multiples)
|
126
|
+
t = multiples[dimen]
|
127
|
+
if dimen == multiples.size - 1
|
128
|
+
return nil if t.zero?
|
129
|
+
input * t # ruby array dup
|
130
|
+
else
|
131
|
+
new_arr = input.collect do |sub|
|
132
|
+
tile_arr(sub, dimen + 1, multiples)
|
133
|
+
end.compact
|
134
|
+
|
135
|
+
return nil if new_arr.empty?
|
136
|
+
|
137
|
+
new_arr * t
|
138
|
+
end
|
139
|
+
end
|
99
140
|
end
|
100
141
|
end
|
@@ -1,6 +1,6 @@
|
|
1
1
|
require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
|
2
2
|
require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
|
3
|
-
require 'tensor_stream/
|
3
|
+
require 'tensor_stream/evaluator/operation_helpers/math_helper'
|
4
4
|
require 'distribution'
|
5
5
|
|
6
6
|
module TensorStream
|
@@ -28,6 +28,7 @@ module TensorStream
|
|
28
28
|
|
29
29
|
include TensorStream::OpHelper
|
30
30
|
include TensorStream::ArrayOpsHelper
|
31
|
+
include TensorStream::MathHelper
|
31
32
|
|
32
33
|
def initialize(session, context, thread_pool: nil, log_intermediates: false)
|
33
34
|
@session = session
|
@@ -179,7 +180,9 @@ module TensorStream
|
|
179
180
|
when :exp
|
180
181
|
call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
|
181
182
|
when :sigmoid
|
182
|
-
call_op(:sigmoid, a, child_context, ->(t, _b) {
|
183
|
+
call_op(:sigmoid, a, child_context, ->(t, _b) { sigmoid(t) })
|
184
|
+
when :sigmoid_grad
|
185
|
+
call_vector_op(:sigmoid_grad, a, b, child_context, ->(t, u) { u * sigmoid(t) * (1 - sigmoid(t)) })
|
183
186
|
when :sqrt
|
184
187
|
call_op(:exp, a, child_context, ->(t, _b) { Math.sqrt(t) })
|
185
188
|
when :square
|
@@ -208,7 +211,7 @@ module TensorStream
|
|
208
211
|
random = _get_randomizer(tensor, seed)
|
209
212
|
|
210
213
|
shape = tensor.options[:shape] || tensor.shape.shape
|
211
|
-
fan_in, fan_out = if shape.size
|
214
|
+
fan_in, fan_out = if shape.size.zero?
|
212
215
|
[1, 1]
|
213
216
|
elsif shape.size == 1
|
214
217
|
[1, shape[0]]
|
@@ -235,7 +238,7 @@ module TensorStream
|
|
235
238
|
when :assign_sub
|
236
239
|
tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
|
237
240
|
tensor.items[0].value
|
238
|
-
when :
|
241
|
+
when :mean
|
239
242
|
c = fp_type?(tensor.data_type) ? 0.0 : 0
|
240
243
|
func = lambda do |arr|
|
241
244
|
return c if arr.nil?
|
@@ -249,7 +252,7 @@ module TensorStream
|
|
249
252
|
end
|
250
253
|
|
251
254
|
reduction(child_context, tensor, func)
|
252
|
-
when :
|
255
|
+
when :sum
|
253
256
|
c = fp_type?(tensor.data_type) ? 0.0 : 0
|
254
257
|
func = lambda do |arr|
|
255
258
|
reduced_val = arr[0]
|
@@ -260,7 +263,10 @@ module TensorStream
|
|
260
263
|
end
|
261
264
|
|
262
265
|
reduction(child_context, tensor, func)
|
263
|
-
when :
|
266
|
+
when :tanh_grad
|
267
|
+
x = complete_eval(a, child_context)
|
268
|
+
call_op(:tanh_grad, x, child_context, ->(t, _b) { 1 - Math.tanh(t) * Math.tanh(t) })
|
269
|
+
when :prod
|
264
270
|
c = fp_type?(tensor.data_type) ? 1.0 : 1
|
265
271
|
func = lambda do |arr|
|
266
272
|
return c if arr.nil?
|
@@ -342,7 +348,15 @@ module TensorStream
|
|
342
348
|
func.call
|
343
349
|
else
|
344
350
|
shape = [shape.to_i] unless shape.is_a?(Array)
|
345
|
-
|
351
|
+
|
352
|
+
cache_key = "#{tensor.operation}_#{shape.to_s}"
|
353
|
+
if @context[:_cache].key?(cache_key)
|
354
|
+
return @context[:_cache][cache_key]
|
355
|
+
else
|
356
|
+
generate_vector(shape, generator: func).tap do |v|
|
357
|
+
@context[:_cache][cache_key] = v
|
358
|
+
end
|
359
|
+
end
|
346
360
|
end
|
347
361
|
when :shape
|
348
362
|
input = complete_eval(a, child_context)
|
@@ -374,6 +388,10 @@ module TensorStream
|
|
374
388
|
a = complete_eval(a, child_context)
|
375
389
|
b = complete_eval(b, child_context)
|
376
390
|
broadcast(a, b)
|
391
|
+
when :truncate
|
392
|
+
a = complete_eval(a, child_context)
|
393
|
+
b = complete_eval(b, child_context)
|
394
|
+
truncate(a, b)
|
377
395
|
when :identity
|
378
396
|
complete_eval(a, child_context)
|
379
397
|
when :print
|
@@ -390,6 +408,8 @@ module TensorStream
|
|
390
408
|
arr = complete_eval(a, child_context)
|
391
409
|
new_shape = complete_eval(b, child_context)
|
392
410
|
|
411
|
+
arr = [arr] unless arr.is_a?(Array)
|
412
|
+
|
393
413
|
flat_arr = arr.flatten
|
394
414
|
return flat_arr[0] if new_shape.size.zero? && flat_arr.size == 1
|
395
415
|
|
@@ -411,6 +431,28 @@ module TensorStream
|
|
411
431
|
b = complete_eval(b, child_context)
|
412
432
|
|
413
433
|
get_broadcast_gradient_args(a, b)
|
434
|
+
when :reduced_shape
|
435
|
+
input_shape = complete_eval(a, child_context)
|
436
|
+
axes = complete_eval(b, child_context)
|
437
|
+
|
438
|
+
return [] if axes.nil? # reduce to scalar
|
439
|
+
axes = [ axes ] unless axes.is_a?(Array)
|
440
|
+
return input_shape if axes.empty?
|
441
|
+
|
442
|
+
axes.each do |dimen|
|
443
|
+
input_shape[dimen] = 1
|
444
|
+
end
|
445
|
+
input_shape
|
446
|
+
when :tile
|
447
|
+
input = complete_eval(a, child_context)
|
448
|
+
multiples = complete_eval(b, child_context)
|
449
|
+
|
450
|
+
rank = get_rank(input)
|
451
|
+
raise '1D or higher tensor required' if rank.zero?
|
452
|
+
raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
|
453
|
+
|
454
|
+
tile = tile_arr(input, 0, multiples)
|
455
|
+
tile.nil? ? [] : tile
|
414
456
|
else
|
415
457
|
raise "unknown op #{tensor.operation}"
|
416
458
|
end.tap do |result|
|
@@ -437,18 +479,21 @@ module TensorStream
|
|
437
479
|
rescue StandardError => e
|
438
480
|
puts e.message
|
439
481
|
puts e.backtrace.join("\n")
|
482
|
+
|
440
483
|
shape_a = a.shape.shape if a
|
441
484
|
shape_b = b.shape.shape if b
|
442
485
|
dtype_a = a.data_type if a
|
443
486
|
dtype_b = b.data_type if b
|
444
487
|
a = complete_eval(a, child_context)
|
445
488
|
b = complete_eval(b, child_context)
|
446
|
-
puts "name: #{tensor.given_name}"
|
447
|
-
puts "op: #{tensor.to_math(true, 1)}"
|
448
|
-
puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
449
|
-
puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
450
|
-
dump_intermediates if @log_intermediates
|
451
|
-
|
489
|
+
# puts "name: #{tensor.given_name}"
|
490
|
+
# # puts "op: #{tensor.to_math(true, 1)}"
|
491
|
+
# puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
492
|
+
# puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
493
|
+
# dump_intermediates if @log_intermediates
|
494
|
+
# File.write('/home/jedld/workspace/tensor_stream/samples/error.graphml', TensorStream::Graphml.new.get_string(tensor, @session))
|
495
|
+
|
496
|
+
# File.write('/Users/josephemmanueldayo/workspace/gradients.graphml', TensorStream::Graphml.new.get_string(tensor, @session))
|
452
497
|
raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true,1)} defined at #{tensor.source}"
|
453
498
|
end
|
454
499
|
|
@@ -504,7 +549,7 @@ module TensorStream
|
|
504
549
|
|
505
550
|
def reduction(child_context, tensor, func)
|
506
551
|
val = complete_eval(tensor.items[0], child_context)
|
507
|
-
axis = complete_eval(tensor.
|
552
|
+
axis = complete_eval(tensor.items[1], child_context)
|
508
553
|
keep_dims = complete_eval(tensor.options[:keepdims], child_context)
|
509
554
|
rank = get_rank(val)
|
510
555
|
return val if axis && axis.is_a?(Array) && axis.empty?
|
@@ -547,19 +592,6 @@ module TensorStream
|
|
547
592
|
end
|
548
593
|
end
|
549
594
|
|
550
|
-
def slice_tensor(input, start, size)
|
551
|
-
start_index = start.shift
|
552
|
-
dimen_size = start_index + size.shift
|
553
|
-
|
554
|
-
input[start_index...dimen_size].collect do |item|
|
555
|
-
if item.is_a?(Array)
|
556
|
-
slice_tensor(item, start.dup, size.dup)
|
557
|
-
else
|
558
|
-
item
|
559
|
-
end
|
560
|
-
end
|
561
|
-
end
|
562
|
-
|
563
595
|
def matmul_const_transform(mat, mat_b, tensor)
|
564
596
|
if !mat.is_a?(Array)
|
565
597
|
compat_shape = shape_eval(mat_b).reverse
|
@@ -591,15 +623,17 @@ module TensorStream
|
|
591
623
|
raise FullEvalNotPossible.new, "full eval not possible for #{a.name}" if eval_a.is_a?(Tensor) || eval_b.is_a?(Tensor)
|
592
624
|
|
593
625
|
# ruby scalar
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
626
|
+
eval_a, eval_b = broadcast(eval_a, eval_b)
|
627
|
+
vector_op(eval_a, eval_b, op)
|
628
|
+
# if get_rank(eval_a).zero?
|
629
|
+
# if get_rank(eval_b).zero?
|
630
|
+
# op.call(eval_a, eval_b)
|
631
|
+
# else
|
632
|
+
# vector_op(eval_b, eval_a, op, true)
|
633
|
+
# end
|
634
|
+
# else
|
635
|
+
# vector_op(eval_a, eval_b, op)
|
636
|
+
# end
|
603
637
|
end
|
604
638
|
|
605
639
|
# determine possible reduction axis to be used
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -62,12 +62,14 @@ module TensorStream
|
|
62
62
|
raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
|
63
63
|
|
64
64
|
node.name = if @nodes[node.name]
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
65
|
+
uniqunify(node.name)
|
66
|
+
else
|
67
|
+
node.name
|
68
|
+
end
|
69
69
|
|
70
70
|
@nodes[node.name] = node
|
71
|
+
|
72
|
+
node.send(:propagate_outputs)
|
71
73
|
node.send(:propagate_consumer, node)
|
72
74
|
node.value = node.eval if @eager_execution
|
73
75
|
end
|
@@ -89,14 +91,12 @@ module TensorStream
|
|
89
91
|
scope = _variable_scope
|
90
92
|
|
91
93
|
raise "duplicate variable detected #{node.name} and reuse=false in current scope" if @nodes[node.name] && !scope.reuse
|
92
|
-
|
93
94
|
return @nodes[node.name] if @nodes[node.name]
|
94
|
-
|
95
95
|
raise "shape is not declared for #{node.name}" if node.shape.nil?
|
96
96
|
|
97
97
|
if !options[:collections].nil? && !options[:collections].empty?
|
98
98
|
options[:collections] = [options[:collections]] unless options[:collections].is_a?(Array)
|
99
|
-
options[:collections].each { |coll| add_to_collection(coll, node) }
|
99
|
+
options[:collections].each { |coll| add_to_collection(coll, node) }
|
100
100
|
end
|
101
101
|
|
102
102
|
add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
|
@@ -162,6 +162,14 @@ module TensorStream
|
|
162
162
|
graph_thread_storage[:current_scope].join('/')
|
163
163
|
end
|
164
164
|
|
165
|
+
def as_graph_def
|
166
|
+
TensorStream::Pbtext.new.get_string(self)
|
167
|
+
end
|
168
|
+
|
169
|
+
def graph_def_versions
|
170
|
+
"producer: 26"
|
171
|
+
end
|
172
|
+
|
165
173
|
protected
|
166
174
|
|
167
175
|
def _variable_scope
|