tensor_stream 0.1.5 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -5
- data/CHANGELOG.md +13 -0
- data/README.md +34 -0
- data/lib/tensor_stream.rb +7 -3
- data/lib/tensor_stream/control_flow.rb +1 -2
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +44 -3
- data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +9 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +70 -36
- data/lib/tensor_stream/graph.rb +15 -7
- data/lib/tensor_stream/graph_serializers/graphml.rb +183 -35
- data/lib/tensor_stream/graph_serializers/pbtext.rb +81 -14
- data/lib/tensor_stream/graph_serializers/serializer.rb +13 -0
- data/lib/tensor_stream/helpers/string_helper.rb +12 -0
- data/lib/tensor_stream/math_gradients.rb +203 -161
- data/lib/tensor_stream/operation.rb +30 -16
- data/lib/tensor_stream/ops.rb +29 -19
- data/lib/tensor_stream/placeholder.rb +2 -3
- data/lib/tensor_stream/session.rb +7 -13
- data/lib/tensor_stream/tensor.rb +22 -5
- data/lib/tensor_stream/tensor_shape.rb +2 -0
- data/lib/tensor_stream/trainer.rb +6 -1
- data/lib/tensor_stream/variable.rb +4 -3
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/gradient_sample.graphml +1255 -0
- data/samples/linear_regression.rb +1 -1
- data/samples/logistic_regression.rb +9 -2
- data/tensor_stream.gemspec +1 -1
- data/test_samples/error.graphml +120 -0
- data/test_samples/gradient_sample.graphml +1255 -0
- data/{samples → test_samples}/iris.rb +0 -0
- data/{samples → test_samples}/raw_neural_net_sample.rb +0 -0
- data/{samples → test_samples}/test.py +2 -0
- data/test_samples/test2.py +41 -0
- metadata +41 -47
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
|
-
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: a5a6dce7a4317dee2e9fba536056d0d346ac9fd4a2763396a5735204f77b90c6
|
4
|
+
data.tar.gz: 4a6d2973badfa0f2ac20850f6885181c4f19b6fe00e416e9f8394f89ac267f77
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 285c9deb129680a9050a2afa9811de1d3cbcf2250cb78b4b6eab06f225bf8a5532d301a4cad7c0e9d39b7b16bcffb0527c7b6cadaad8b840a78cadc8fcf92c80
|
7
|
+
data.tar.gz: 2021fc72c95a8aad4e8b7f8a5b25fa7ce64e715ae08be9d3c325f058155f1a2ea9d58be38aae8f4849dabe067c75e140f174efe498503a86c7a73820dc367e5d
|
data/CHANGELOG.md
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
# Changelog
|
2
|
+
All notable changes to this project will be documented in this file.
|
3
|
+
|
4
|
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
|
5
|
+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
6
|
+
|
7
|
+
## [0.2.0] - 2018-05-27
|
8
|
+
### Added
|
9
|
+
- working logistic regression sample
|
10
|
+
- reworked auto differentiation, fixed a number of bugs related to auto differentiation, and produced smaller derivative programs
|
11
|
+
- alpha support for saving to pbtext format, added graphml generation
|
12
|
+
- significant number of ops added
|
13
|
+
- ops that support broadcasting now work better
|
data/README.md
CHANGED
@@ -102,6 +102,10 @@ You can take a look at spec/tensor_stream/operation_spec.rb for a list of suppor
|
|
102
102
|
sliver of what TensorFlow can do, so feel free to file a PR to add requested
|
103
103
|
ops and test cases.
|
104
104
|
|
105
|
+
Other working samples can also be seen under tensor_stream/samples.
|
106
|
+
|
107
|
+
Samples that are used for development and are still being made to work can be found under test_samples.
|
108
|
+
|
105
109
|
## Python to Ruby guide
|
106
110
|
|
107
111
|
Not all ops are available. Available ops are defined in lib/tensor_stream/ops.rb; the corresponding gradients are found in lib/tensor_stream/math_gradients.rb.
|
@@ -172,6 +176,36 @@ f = tf.matmul(a, b).breakpoint! { |tensor, a, b, result_value| binding.pry }
|
|
172
176
|
tf.session.run(f)
|
173
177
|
```
|
174
178
|
|
179
|
+
# Visualization
|
180
|
+
|
181
|
+
TensorStream does not support TensorBoard yet, but a GraphML generator is included:
|
182
|
+
|
183
|
+
```ruby
|
184
|
+
tf = TensorStream
|
185
|
+
a = tf.constant(1.0)
|
186
|
+
b = tf.constant(2.0)
|
187
|
+
result = a + b
|
188
|
+
sess = tf.session
|
189
|
+
sess.run(result)
|
190
|
+
|
191
|
+
File.write('gradients.graphml', TensorStream::Graphml.new.get_string(result)) # dump graph only
|
192
|
+
File.write('gradients.graphml', TensorStream::Graphml.new.get_string(result, sess)) # dump with values from session
|
193
|
+
```
|
194
|
+
|
195
|
+
The resulting GraphML is designed to work with yEd; after loading the graph, change the layout to "Flowchart" for best results.
|
196
|
+
|
197
|
+
## Exporting to TensorFlow
|
198
|
+
|
199
|
+
Still in alpha, but TensorStream supports TensorFlow's as_graph_def serialization method:
|
200
|
+
|
201
|
+
```ruby
|
202
|
+
tf = TensorStream
|
203
|
+
a = tf.constant(1.0)
|
204
|
+
b = tf.constant(2.0)
|
205
|
+
result = a + b
|
206
|
+
File.write("model.pbtext", result.graph.as_graph_def)
|
207
|
+
```
|
208
|
+
|
175
209
|
## Roadmap
|
176
210
|
|
177
211
|
- Docs
|
data/lib/tensor_stream.rb
CHANGED
@@ -3,6 +3,7 @@ require 'deep_merge'
|
|
3
3
|
require 'matrix'
|
4
4
|
require 'concurrent'
|
5
5
|
require 'tensor_stream/helpers/op_helper'
|
6
|
+
require 'tensor_stream/helpers/string_helper'
|
6
7
|
require 'tensor_stream/initializer'
|
7
8
|
require 'tensor_stream/graph_keys'
|
8
9
|
require 'tensor_stream/types'
|
@@ -18,8 +19,11 @@ require 'tensor_stream/control_flow'
|
|
18
19
|
require 'tensor_stream/trainer'
|
19
20
|
require 'tensor_stream/nn/nn_ops'
|
20
21
|
require 'tensor_stream/evaluator/evaluator'
|
22
|
+
require 'tensor_stream/graph_serializers/serializer'
|
21
23
|
require 'tensor_stream/graph_serializers/pbtext'
|
22
24
|
require 'tensor_stream/graph_serializers/graphml'
|
25
|
+
require 'tensor_stream/math_gradients'
|
26
|
+
|
23
27
|
# require 'tensor_stream/libraries/layers'
|
24
28
|
require 'tensor_stream/monkey_patches/integer'
|
25
29
|
require 'tensor_stream/ops'
|
@@ -150,8 +154,8 @@ module TensorStream
|
|
150
154
|
end
|
151
155
|
end
|
152
156
|
|
153
|
-
def self.group(inputs)
|
154
|
-
TensorStream::ControlFlow.new(:group, inputs)
|
157
|
+
def self.group(inputs, name: nil)
|
158
|
+
TensorStream::ControlFlow.new(:group, inputs, nil, name: name)
|
155
159
|
end
|
156
160
|
|
157
161
|
def self.get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
|
@@ -196,6 +200,6 @@ module TensorStream
|
|
196
200
|
return input unless input.is_a?(Tensor)
|
197
201
|
return input if input.data_type.nil?
|
198
202
|
|
199
|
-
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.
|
203
|
+
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.include?(input.data_type.to_sym)
|
200
204
|
end
|
201
205
|
end
|
@@ -4,13 +4,12 @@ module TensorStream
|
|
4
4
|
attr_accessor :ops
|
5
5
|
|
6
6
|
def initialize(flow_type, items, ops = nil, options = {})
|
7
|
-
|
7
|
+
setup_initial_state(options)
|
8
8
|
|
9
9
|
@operation = :"flow_#{flow_type}"
|
10
10
|
@items = items
|
11
11
|
@name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
|
12
12
|
@ops = ops
|
13
|
-
@source = format_source(caller_locations)
|
14
13
|
@shape = TensorShape.new([items.size])
|
15
14
|
@graph.add_node(self)
|
16
15
|
end
|
@@ -1,6 +1,27 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# various utility functions for array processing
|
3
3
|
module ArrayOpsHelper
|
4
|
+
def slice_tensor(input, start, size)
|
5
|
+
start_index = start.shift
|
6
|
+
dimen_size = start_index + size.shift
|
7
|
+
|
8
|
+
input[start_index...dimen_size].collect do |item|
|
9
|
+
if item.is_a?(Array)
|
10
|
+
slice_tensor(item, start.dup, size.dup)
|
11
|
+
else
|
12
|
+
item
|
13
|
+
end
|
14
|
+
end
|
15
|
+
end
|
16
|
+
|
17
|
+
def truncate(input, target_shape)
|
18
|
+
rank = get_rank(input)
|
19
|
+
return input if rank.zero?
|
20
|
+
|
21
|
+
start = Array.new(rank) { 0 }
|
22
|
+
slice_tensor(input, start, target_shape)
|
23
|
+
end
|
24
|
+
|
4
25
|
def broadcast(input_a, input_b)
|
5
26
|
sa = shape_eval(input_a)
|
6
27
|
sb = shape_eval(input_b)
|
@@ -24,7 +45,7 @@ module TensorStream
|
|
24
45
|
input_b = broadcast_dimensions(input_b, target_shape)
|
25
46
|
else
|
26
47
|
target_shape = shape_diff(sb, sa)
|
27
|
-
raise "Incompatible shapes for op #{shape_eval(input_a)} vs #{shape_eval(
|
48
|
+
raise "Incompatible shapes for op #{shape_eval(input_a)} vs #{shape_eval(input_b)}" if target_shape.nil?
|
28
49
|
|
29
50
|
input_a = broadcast_dimensions(input_a, target_shape)
|
30
51
|
end
|
@@ -52,7 +73,7 @@ module TensorStream
|
|
52
73
|
end
|
53
74
|
|
54
75
|
# handle 2 tensor math operations
|
55
|
-
def vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false)
|
76
|
+
def vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false, safe = true)
|
56
77
|
if get_rank(vector) < get_rank(vector2) # upgrade rank of A
|
57
78
|
duplicated = Array.new(vector2.size) do
|
58
79
|
vector
|
@@ -65,6 +86,10 @@ module TensorStream
|
|
65
86
|
vector.each_with_index.collect do |item, index|
|
66
87
|
next vector_op(item, vector2, op, switch) if item.is_a?(Array) && get_rank(vector) > get_rank(vector2)
|
67
88
|
|
89
|
+
if safe && vector2.is_a?(Array)
|
90
|
+
next nil if vector2.size != 1 && index >= vector2.size
|
91
|
+
end
|
92
|
+
|
68
93
|
z = if vector2.is_a?(Array)
|
69
94
|
if index < vector2.size
|
70
95
|
vector2[index]
|
@@ -81,7 +106,7 @@ module TensorStream
|
|
81
106
|
else
|
82
107
|
switch ? op.call(z, item) : op.call(item, z)
|
83
108
|
end
|
84
|
-
end
|
109
|
+
end.compact
|
85
110
|
end
|
86
111
|
|
87
112
|
def shape_diff(shape_a, shape_b)
|
@@ -96,5 +121,21 @@ module TensorStream
|
|
96
121
|
s - reversed_b[index]
|
97
122
|
end.reverse
|
98
123
|
end
|
124
|
+
|
125
|
+
def tile_arr(input, dimen, multiples)
|
126
|
+
t = multiples[dimen]
|
127
|
+
if dimen == multiples.size - 1
|
128
|
+
return nil if t.zero?
|
129
|
+
input * t # ruby array dup
|
130
|
+
else
|
131
|
+
new_arr = input.collect do |sub|
|
132
|
+
tile_arr(sub, dimen + 1, multiples)
|
133
|
+
end.compact
|
134
|
+
|
135
|
+
return nil if new_arr.empty?
|
136
|
+
|
137
|
+
new_arr * t
|
138
|
+
end
|
139
|
+
end
|
99
140
|
end
|
100
141
|
end
|
@@ -1,6 +1,6 @@
|
|
1
1
|
require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
|
2
2
|
require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
|
3
|
-
require 'tensor_stream/
|
3
|
+
require 'tensor_stream/evaluator/operation_helpers/math_helper'
|
4
4
|
require 'distribution'
|
5
5
|
|
6
6
|
module TensorStream
|
@@ -28,6 +28,7 @@ module TensorStream
|
|
28
28
|
|
29
29
|
include TensorStream::OpHelper
|
30
30
|
include TensorStream::ArrayOpsHelper
|
31
|
+
include TensorStream::MathHelper
|
31
32
|
|
32
33
|
def initialize(session, context, thread_pool: nil, log_intermediates: false)
|
33
34
|
@session = session
|
@@ -179,7 +180,9 @@ module TensorStream
|
|
179
180
|
when :exp
|
180
181
|
call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
|
181
182
|
when :sigmoid
|
182
|
-
call_op(:sigmoid, a, child_context, ->(t, _b) {
|
183
|
+
call_op(:sigmoid, a, child_context, ->(t, _b) { sigmoid(t) })
|
184
|
+
when :sigmoid_grad
|
185
|
+
call_vector_op(:sigmoid_grad, a, b, child_context, ->(t, u) { u * sigmoid(t) * (1 - sigmoid(t)) })
|
183
186
|
when :sqrt
|
184
187
|
call_op(:exp, a, child_context, ->(t, _b) { Math.sqrt(t) })
|
185
188
|
when :square
|
@@ -208,7 +211,7 @@ module TensorStream
|
|
208
211
|
random = _get_randomizer(tensor, seed)
|
209
212
|
|
210
213
|
shape = tensor.options[:shape] || tensor.shape.shape
|
211
|
-
fan_in, fan_out = if shape.size
|
214
|
+
fan_in, fan_out = if shape.size.zero?
|
212
215
|
[1, 1]
|
213
216
|
elsif shape.size == 1
|
214
217
|
[1, shape[0]]
|
@@ -235,7 +238,7 @@ module TensorStream
|
|
235
238
|
when :assign_sub
|
236
239
|
tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
|
237
240
|
tensor.items[0].value
|
238
|
-
when :
|
241
|
+
when :mean
|
239
242
|
c = fp_type?(tensor.data_type) ? 0.0 : 0
|
240
243
|
func = lambda do |arr|
|
241
244
|
return c if arr.nil?
|
@@ -249,7 +252,7 @@ module TensorStream
|
|
249
252
|
end
|
250
253
|
|
251
254
|
reduction(child_context, tensor, func)
|
252
|
-
when :
|
255
|
+
when :sum
|
253
256
|
c = fp_type?(tensor.data_type) ? 0.0 : 0
|
254
257
|
func = lambda do |arr|
|
255
258
|
reduced_val = arr[0]
|
@@ -260,7 +263,10 @@ module TensorStream
|
|
260
263
|
end
|
261
264
|
|
262
265
|
reduction(child_context, tensor, func)
|
263
|
-
when :
|
266
|
+
when :tanh_grad
|
267
|
+
x = complete_eval(a, child_context)
|
268
|
+
call_op(:tanh_grad, x, child_context, ->(t, _b) { 1 - Math.tanh(t) * Math.tanh(t) })
|
269
|
+
when :prod
|
264
270
|
c = fp_type?(tensor.data_type) ? 1.0 : 1
|
265
271
|
func = lambda do |arr|
|
266
272
|
return c if arr.nil?
|
@@ -342,7 +348,15 @@ module TensorStream
|
|
342
348
|
func.call
|
343
349
|
else
|
344
350
|
shape = [shape.to_i] unless shape.is_a?(Array)
|
345
|
-
|
351
|
+
|
352
|
+
cache_key = "#{tensor.operation}_#{shape.to_s}"
|
353
|
+
if @context[:_cache].key?(cache_key)
|
354
|
+
return @context[:_cache][cache_key]
|
355
|
+
else
|
356
|
+
generate_vector(shape, generator: func).tap do |v|
|
357
|
+
@context[:_cache][cache_key] = v
|
358
|
+
end
|
359
|
+
end
|
346
360
|
end
|
347
361
|
when :shape
|
348
362
|
input = complete_eval(a, child_context)
|
@@ -374,6 +388,10 @@ module TensorStream
|
|
374
388
|
a = complete_eval(a, child_context)
|
375
389
|
b = complete_eval(b, child_context)
|
376
390
|
broadcast(a, b)
|
391
|
+
when :truncate
|
392
|
+
a = complete_eval(a, child_context)
|
393
|
+
b = complete_eval(b, child_context)
|
394
|
+
truncate(a, b)
|
377
395
|
when :identity
|
378
396
|
complete_eval(a, child_context)
|
379
397
|
when :print
|
@@ -390,6 +408,8 @@ module TensorStream
|
|
390
408
|
arr = complete_eval(a, child_context)
|
391
409
|
new_shape = complete_eval(b, child_context)
|
392
410
|
|
411
|
+
arr = [arr] unless arr.is_a?(Array)
|
412
|
+
|
393
413
|
flat_arr = arr.flatten
|
394
414
|
return flat_arr[0] if new_shape.size.zero? && flat_arr.size == 1
|
395
415
|
|
@@ -411,6 +431,28 @@ module TensorStream
|
|
411
431
|
b = complete_eval(b, child_context)
|
412
432
|
|
413
433
|
get_broadcast_gradient_args(a, b)
|
434
|
+
when :reduced_shape
|
435
|
+
input_shape = complete_eval(a, child_context)
|
436
|
+
axes = complete_eval(b, child_context)
|
437
|
+
|
438
|
+
return [] if axes.nil? # reduce to scalar
|
439
|
+
axes = [ axes ] unless axes.is_a?(Array)
|
440
|
+
return input_shape if axes.empty?
|
441
|
+
|
442
|
+
axes.each do |dimen|
|
443
|
+
input_shape[dimen] = 1
|
444
|
+
end
|
445
|
+
input_shape
|
446
|
+
when :tile
|
447
|
+
input = complete_eval(a, child_context)
|
448
|
+
multiples = complete_eval(b, child_context)
|
449
|
+
|
450
|
+
rank = get_rank(input)
|
451
|
+
raise '1D or higher tensor required' if rank.zero?
|
452
|
+
raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
|
453
|
+
|
454
|
+
tile = tile_arr(input, 0, multiples)
|
455
|
+
tile.nil? ? [] : tile
|
414
456
|
else
|
415
457
|
raise "unknown op #{tensor.operation}"
|
416
458
|
end.tap do |result|
|
@@ -437,18 +479,21 @@ module TensorStream
|
|
437
479
|
rescue StandardError => e
|
438
480
|
puts e.message
|
439
481
|
puts e.backtrace.join("\n")
|
482
|
+
|
440
483
|
shape_a = a.shape.shape if a
|
441
484
|
shape_b = b.shape.shape if b
|
442
485
|
dtype_a = a.data_type if a
|
443
486
|
dtype_b = b.data_type if b
|
444
487
|
a = complete_eval(a, child_context)
|
445
488
|
b = complete_eval(b, child_context)
|
446
|
-
puts "name: #{tensor.given_name}"
|
447
|
-
puts "op: #{tensor.to_math(true, 1)}"
|
448
|
-
puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
449
|
-
puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
450
|
-
dump_intermediates if @log_intermediates
|
451
|
-
|
489
|
+
# puts "name: #{tensor.given_name}"
|
490
|
+
# # puts "op: #{tensor.to_math(true, 1)}"
|
491
|
+
# puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
492
|
+
# puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
493
|
+
# dump_intermediates if @log_intermediates
|
494
|
+
# File.write('/home/jedld/workspace/tensor_stream/samples/error.graphml', TensorStream::Graphml.new.get_string(tensor, @session))
|
495
|
+
|
496
|
+
# File.write('/Users/josephemmanueldayo/workspace/gradients.graphml', TensorStream::Graphml.new.get_string(tensor, @session))
|
452
497
|
raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true,1)} defined at #{tensor.source}"
|
453
498
|
end
|
454
499
|
|
@@ -504,7 +549,7 @@ module TensorStream
|
|
504
549
|
|
505
550
|
def reduction(child_context, tensor, func)
|
506
551
|
val = complete_eval(tensor.items[0], child_context)
|
507
|
-
axis = complete_eval(tensor.
|
552
|
+
axis = complete_eval(tensor.items[1], child_context)
|
508
553
|
keep_dims = complete_eval(tensor.options[:keepdims], child_context)
|
509
554
|
rank = get_rank(val)
|
510
555
|
return val if axis && axis.is_a?(Array) && axis.empty?
|
@@ -547,19 +592,6 @@ module TensorStream
|
|
547
592
|
end
|
548
593
|
end
|
549
594
|
|
550
|
-
def slice_tensor(input, start, size)
|
551
|
-
start_index = start.shift
|
552
|
-
dimen_size = start_index + size.shift
|
553
|
-
|
554
|
-
input[start_index...dimen_size].collect do |item|
|
555
|
-
if item.is_a?(Array)
|
556
|
-
slice_tensor(item, start.dup, size.dup)
|
557
|
-
else
|
558
|
-
item
|
559
|
-
end
|
560
|
-
end
|
561
|
-
end
|
562
|
-
|
563
595
|
def matmul_const_transform(mat, mat_b, tensor)
|
564
596
|
if !mat.is_a?(Array)
|
565
597
|
compat_shape = shape_eval(mat_b).reverse
|
@@ -591,15 +623,17 @@ module TensorStream
|
|
591
623
|
raise FullEvalNotPossible.new, "full eval not possible for #{a.name}" if eval_a.is_a?(Tensor) || eval_b.is_a?(Tensor)
|
592
624
|
|
593
625
|
# ruby scalar
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
626
|
+
eval_a, eval_b = broadcast(eval_a, eval_b)
|
627
|
+
vector_op(eval_a, eval_b, op)
|
628
|
+
# if get_rank(eval_a).zero?
|
629
|
+
# if get_rank(eval_b).zero?
|
630
|
+
# op.call(eval_a, eval_b)
|
631
|
+
# else
|
632
|
+
# vector_op(eval_b, eval_a, op, true)
|
633
|
+
# end
|
634
|
+
# else
|
635
|
+
# vector_op(eval_a, eval_b, op)
|
636
|
+
# end
|
603
637
|
end
|
604
638
|
|
605
639
|
# determine possible reduction axis to be used
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -62,12 +62,14 @@ module TensorStream
|
|
62
62
|
raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
|
63
63
|
|
64
64
|
node.name = if @nodes[node.name]
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
65
|
+
uniqunify(node.name)
|
66
|
+
else
|
67
|
+
node.name
|
68
|
+
end
|
69
69
|
|
70
70
|
@nodes[node.name] = node
|
71
|
+
|
72
|
+
node.send(:propagate_outputs)
|
71
73
|
node.send(:propagate_consumer, node)
|
72
74
|
node.value = node.eval if @eager_execution
|
73
75
|
end
|
@@ -89,14 +91,12 @@ module TensorStream
|
|
89
91
|
scope = _variable_scope
|
90
92
|
|
91
93
|
raise "duplicate variable detected #{node.name} and reuse=false in current scope" if @nodes[node.name] && !scope.reuse
|
92
|
-
|
93
94
|
return @nodes[node.name] if @nodes[node.name]
|
94
|
-
|
95
95
|
raise "shape is not declared for #{node.name}" if node.shape.nil?
|
96
96
|
|
97
97
|
if !options[:collections].nil? && !options[:collections].empty?
|
98
98
|
options[:collections] = [options[:collections]] unless options[:collections].is_a?(Array)
|
99
|
-
options[:collections].each { |coll| add_to_collection(coll, node) }
|
99
|
+
options[:collections].each { |coll| add_to_collection(coll, node) }
|
100
100
|
end
|
101
101
|
|
102
102
|
add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
|
@@ -162,6 +162,14 @@ module TensorStream
|
|
162
162
|
graph_thread_storage[:current_scope].join('/')
|
163
163
|
end
|
164
164
|
|
165
|
+
def as_graph_def
|
166
|
+
TensorStream::Pbtext.new.get_string(self)
|
167
|
+
end
|
168
|
+
|
169
|
+
def graph_def_versions
|
170
|
+
"producer: 26"
|
171
|
+
end
|
172
|
+
|
165
173
|
protected
|
166
174
|
|
167
175
|
def _variable_scope
|