tensor_stream 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.circleci/config.yml +57 -0
- data/README.md +2 -0
- data/lib/tensor_stream.rb +74 -10
- data/lib/tensor_stream/control_flow.rb +2 -2
- data/lib/tensor_stream/device.rb +8 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +104 -40
- data/lib/tensor_stream/graph.rb +53 -5
- data/lib/tensor_stream/graph_keys.rb +1 -0
- data/lib/tensor_stream/graph_serializers/graphml.rb +91 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +71 -0
- data/lib/tensor_stream/helpers/op_helper.rb +7 -1
- data/lib/tensor_stream/initializer.rb +16 -0
- data/lib/tensor_stream/math_gradients.rb +37 -30
- data/lib/tensor_stream/nn/nn_ops.rb +17 -0
- data/lib/tensor_stream/operation.rb +92 -31
- data/lib/tensor_stream/ops.rb +87 -53
- data/lib/tensor_stream/placeholder.rb +1 -1
- data/lib/tensor_stream/session.rb +26 -4
- data/lib/tensor_stream/tensor.rb +29 -33
- data/lib/tensor_stream/tensor_shape.rb +52 -2
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +1 -4
- data/lib/tensor_stream/variable.rb +23 -7
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/logistic_regression.rb +76 -0
- data/tensor_stream.gemspec +3 -0
- metadata +50 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f0d4f44991feb4af9b68cd9cd5ee6cd202bb2670
|
4
|
+
data.tar.gz: 64801dfe3a55d76d0d7535d24867f94fc4dd0861
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d68686258355308d9dbdd2fbd632a02fd98668765ab0ef3161494d28d32263aa6b88f4dbdc05da5b473c9f20cf8466bc64e93834442559cd2835a26ffe771c1c
|
7
|
+
data.tar.gz: 7517729b8aad2ffc2e557c3f55c09bf149fa84c4757b979842a3e56adb3e121331a73fc07667c5f8f639dd3c20b1c2f2f4b55311cf930f9d80cdffcea6a886e1
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Ruby CircleCI 2.0 configuration file
|
2
|
+
#
|
3
|
+
# Check https://circleci.com/docs/2.0/language-ruby/ for more details
|
4
|
+
#
|
5
|
+
version: 2
|
6
|
+
jobs:
|
7
|
+
build:
|
8
|
+
docker:
|
9
|
+
# specify the version you desire here
|
10
|
+
- image: circleci/ruby:2.4.1-node-browsers
|
11
|
+
|
12
|
+
# Specify service dependencies here if necessary
|
13
|
+
# CircleCI maintains a library of pre-built images
|
14
|
+
# documented at https://circleci.com/docs/2.0/circleci-images/
|
15
|
+
# - image: circleci/postgres:9.4
|
16
|
+
|
17
|
+
working_directory: ~/repo
|
18
|
+
|
19
|
+
steps:
|
20
|
+
- checkout
|
21
|
+
|
22
|
+
# Download and cache dependencies
|
23
|
+
- restore_cache:
|
24
|
+
keys:
|
25
|
+
- v1-dependencies-{{ checksum "Gemfile.lock" }}
|
26
|
+
# fallback to using the latest cache if no exact match is found
|
27
|
+
- v1-dependencies-
|
28
|
+
|
29
|
+
- run:
|
30
|
+
name: install dependencies
|
31
|
+
command: |
|
32
|
+
bundle install --jobs=4 --retry=3 --path vendor/bundle
|
33
|
+
|
34
|
+
- save_cache:
|
35
|
+
paths:
|
36
|
+
- ./vendor/bundle
|
37
|
+
key: v1-dependencies-{{ checksum "Gemfile.lock" }}
|
38
|
+
|
39
|
+
# run tests!
|
40
|
+
- run:
|
41
|
+
name: run tests
|
42
|
+
command: |
|
43
|
+
mkdir /tmp/test-results
|
44
|
+
TEST_FILES="$(circleci tests glob "spec/**/*_spec.rb" | circleci tests split --split-by=timings)"
|
45
|
+
|
46
|
+
bundle exec rspec -r rspec_junit_formatter --format progress \
|
47
|
+
--format RspecJunitFormatter \
|
48
|
+
--out /tmp/test-results/rspec.xml \
|
49
|
+
--format progress \
|
50
|
+
$TEST_FILES
|
51
|
+
|
52
|
+
# collect reports
|
53
|
+
- store_test_results:
|
54
|
+
path: /tmp/test-results
|
55
|
+
- store_artifacts:
|
56
|
+
path: /tmp/test-results
|
57
|
+
destination: test-results
|
data/README.md
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
[](https://badge.fury.io/rb/tensor_stream)
|
2
2
|
|
3
|
+
[](https://circleci.com/gh/jedld/tensor_stream)
|
4
|
+
|
3
5
|
# TensorStream
|
4
6
|
|
5
7
|
A reimplementation of TensorFlow for ruby. This is a ground up implementation with no dependency on TensorFlow. Effort has been made to make the programming style as near to TensorFlow as possible, comes with a pure ruby evaluator by default as well with support for an opencl evaluator.
|
data/lib/tensor_stream.rb
CHANGED
@@ -3,9 +3,11 @@ require 'deep_merge'
|
|
3
3
|
require 'matrix'
|
4
4
|
require 'concurrent'
|
5
5
|
require 'tensor_stream/helpers/op_helper'
|
6
|
+
require 'tensor_stream/initializer'
|
6
7
|
require 'tensor_stream/graph_keys'
|
7
8
|
require 'tensor_stream/types'
|
8
9
|
require 'tensor_stream/graph'
|
10
|
+
require 'tensor_stream/device'
|
9
11
|
require 'tensor_stream/session'
|
10
12
|
require 'tensor_stream/tensor_shape'
|
11
13
|
require 'tensor_stream/tensor'
|
@@ -16,6 +18,8 @@ require 'tensor_stream/control_flow'
|
|
16
18
|
require 'tensor_stream/trainer'
|
17
19
|
require 'tensor_stream/nn/nn_ops'
|
18
20
|
require 'tensor_stream/evaluator/evaluator'
|
21
|
+
require 'tensor_stream/graph_serializers/pbtext'
|
22
|
+
require 'tensor_stream/graph_serializers/graphml'
|
19
23
|
# require 'tensor_stream/libraries/layers'
|
20
24
|
require 'tensor_stream/monkey_patches/integer'
|
21
25
|
require 'tensor_stream/ops'
|
@@ -29,6 +33,10 @@ module TensorStream
|
|
29
33
|
Types.float32
|
30
34
|
end
|
31
35
|
|
36
|
+
def self.graph
|
37
|
+
TensorStream::Graph.new
|
38
|
+
end
|
39
|
+
|
32
40
|
def self.get_default_graph
|
33
41
|
TensorStream::Graph.get_default_graph
|
34
42
|
end
|
@@ -49,22 +57,60 @@ module TensorStream
|
|
49
57
|
TensorStream::Graph.get_default_graph.executing_eagerly?
|
50
58
|
end
|
51
59
|
|
52
|
-
def self.variable(value,
|
60
|
+
def self.variable(value, name: nil, initializer: nil, graph: nil, dtype: nil, trainable: true)
|
53
61
|
common_options = {
|
54
|
-
initializer: Operation.new(:assign, nil, value),
|
55
|
-
name:
|
62
|
+
initializer: initializer || Operation.new(:assign, nil, value),
|
63
|
+
name: name,
|
64
|
+
graph: graph,
|
65
|
+
dtype: dtype,
|
66
|
+
trainable: trainable
|
56
67
|
}
|
57
68
|
if value.is_a?(String)
|
58
|
-
TensorStream::Variable.new(
|
69
|
+
TensorStream::Variable.new(dtype || :string, 0, [], common_options)
|
59
70
|
elsif value.is_a?(Integer)
|
60
|
-
TensorStream::Variable.new(
|
71
|
+
TensorStream::Variable.new(dtype || :int32, 0, [], common_options)
|
61
72
|
elsif value.is_a?(Float)
|
62
|
-
TensorStream::Variable.new(
|
73
|
+
TensorStream::Variable.new(dtype || :float32, 0, [], common_options)
|
63
74
|
else
|
64
|
-
TensorStream::Variable.new(
|
75
|
+
TensorStream::Variable.new(dtype || :float32, 0, nil, common_options)
|
76
|
+
end
|
77
|
+
end
|
78
|
+
|
79
|
+
def self.variable_scope(scope = nil, reuse: nil, initializer: nil)
|
80
|
+
Thread.current[:tensor_stream_variable_scope] ||= []
|
81
|
+
Thread.current[:tensor_stream_variable_scope] << OpenStruct.new(name: scope, reuse: reuse, initializer: initializer)
|
82
|
+
scope_name = __v_scope_name
|
83
|
+
begin
|
84
|
+
if block_given?
|
85
|
+
TensorStream.get_default_graph.name_scope(scope) do
|
86
|
+
yield(scope_name)
|
87
|
+
end
|
88
|
+
end
|
89
|
+
ensure
|
90
|
+
Thread.current[:tensor_stream_variable_scope].pop
|
65
91
|
end
|
66
92
|
end
|
67
93
|
|
94
|
+
def self.name_scope(name, default: nil, values: nil)
|
95
|
+
if values
|
96
|
+
graph_count = values.select { |v| v.is_a?(Tensor) }.map(&:graph).map(&:object_id).uniq.size
|
97
|
+
raise "values are not on the same graph" if graph_count > 1
|
98
|
+
end
|
99
|
+
|
100
|
+
get_default_graph.name_scope(name || default) do |scope|
|
101
|
+
yield scope if block_given?
|
102
|
+
end
|
103
|
+
end
|
104
|
+
|
105
|
+
def self.get_variable_scope
|
106
|
+
return nil unless Thread.current[:tensor_stream_variable_scope]
|
107
|
+
__v_scope_name
|
108
|
+
end
|
109
|
+
|
110
|
+
def self.__v_scope_name
|
111
|
+
Thread.current[:tensor_stream_variable_scope].map(&:name).compact.reject(&:empty?).join('/')
|
112
|
+
end
|
113
|
+
|
68
114
|
def self.session(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
|
69
115
|
session = TensorStream::Session.new(evaluator, thread_pool_class: thread_pool_class)
|
70
116
|
yield session if block_given?
|
@@ -108,8 +154,8 @@ module TensorStream
|
|
108
154
|
TensorStream::ControlFlow.new(:group, inputs)
|
109
155
|
end
|
110
156
|
|
111
|
-
def self.get_variable(name,
|
112
|
-
TensorStream::Variable.new(
|
157
|
+
def self.get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
|
158
|
+
TensorStream::Variable.new(dtype || :float32, nil, shape, collections: collections, name: name, initializer: initializer, trainable: trainable)
|
113
159
|
end
|
114
160
|
|
115
161
|
def self.get_collection(name, options = {})
|
@@ -128,10 +174,28 @@ module TensorStream
|
|
128
174
|
TensorStream::Trainer
|
129
175
|
end
|
130
176
|
|
177
|
+
def self.trainable_variables
|
178
|
+
TensorStream.get_default_graph.get_collection(TensorStream::GraphKeys::TRAINABLE_VARIABLES)
|
179
|
+
end
|
180
|
+
|
181
|
+
def self.set_random_seed(seed)
|
182
|
+
TensorStream.get_default_graph.random_seed = seed
|
183
|
+
end
|
184
|
+
|
185
|
+
def self.convert_to_tensor(value, dtype: nil, name: nil, preferred_dtype: nil)
|
186
|
+
return convert_to_tensor(value.call) if value.is_a?(Proc)
|
187
|
+
|
188
|
+
if !value.is_a?(Tensor)
|
189
|
+
i_cons(value, dtype: dtype || Tensor.detect_type(value), name: name)
|
190
|
+
else
|
191
|
+
value
|
192
|
+
end
|
193
|
+
end
|
194
|
+
|
131
195
|
def self.check_allowed_types(input, types)
|
132
196
|
return input unless input.is_a?(Tensor)
|
133
197
|
return input if input.data_type.nil?
|
134
198
|
|
135
|
-
raise "Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.map(&:to_sym).include?(input.data_type)
|
199
|
+
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.map(&:to_sym).include?(input.data_type)
|
136
200
|
end
|
137
201
|
end
|
@@ -8,10 +8,10 @@ module TensorStream
|
|
8
8
|
|
9
9
|
@operation = :"flow_#{flow_type}"
|
10
10
|
@items = items
|
11
|
-
@name = set_name
|
11
|
+
@name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
|
12
12
|
@ops = ops
|
13
13
|
@source = format_source(caller_locations)
|
14
|
-
|
14
|
+
@shape = TensorShape.new([items.size])
|
15
15
|
@graph.add_node(self)
|
16
16
|
end
|
17
17
|
|
@@ -1,6 +1,7 @@
|
|
1
1
|
require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
|
2
2
|
require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
|
3
3
|
require 'tensor_stream/math_gradients'
|
4
|
+
require 'distribution'
|
4
5
|
|
5
6
|
module TensorStream
|
6
7
|
module Evaluator
|
@@ -28,19 +29,22 @@ module TensorStream
|
|
28
29
|
include TensorStream::OpHelper
|
29
30
|
include TensorStream::ArrayOpsHelper
|
30
31
|
|
31
|
-
def initialize(session, context, thread_pool: nil)
|
32
|
+
def initialize(session, context, thread_pool: nil, log_intermediates: false)
|
32
33
|
@session = session
|
33
34
|
@context = context
|
35
|
+
@log_intermediates = log_intermediates
|
34
36
|
@retain = context[:retain] || []
|
35
37
|
@thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
|
38
|
+
|
39
|
+
@context[:compute_history] = [] if log_intermediates
|
36
40
|
end
|
37
41
|
|
38
42
|
def run(tensor, execution_context)
|
39
43
|
return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array)
|
40
44
|
|
41
45
|
return tensor if retain.include?(tensor) # if var is in retain don't eval to value
|
42
|
-
|
43
|
-
tensor = tensor.call
|
46
|
+
|
47
|
+
tensor = tensor.call if tensor.is_a?(Proc)
|
44
48
|
|
45
49
|
child_context = execution_context.dup
|
46
50
|
res = if tensor.is_a?(Operation)
|
@@ -71,7 +75,9 @@ module TensorStream
|
|
71
75
|
protected
|
72
76
|
|
73
77
|
def eval_variable(tensor, child_context)
|
74
|
-
|
78
|
+
if tensor.value.nil?
|
79
|
+
raise "variable #{tensor.name} not initalized"
|
80
|
+
end
|
75
81
|
eval_tensor(tensor.value, child_context).tap do |val|
|
76
82
|
child_context[:returns] ||= {}
|
77
83
|
child_context[:returns][:vars] ||= []
|
@@ -86,6 +92,8 @@ module TensorStream
|
|
86
92
|
b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]
|
87
93
|
|
88
94
|
case tensor.operation
|
95
|
+
when :const
|
96
|
+
complete_eval(a, child_context)
|
89
97
|
when :argmax
|
90
98
|
a = complete_eval(a, child_context)
|
91
99
|
axis = tensor.options[:axis] || 0
|
@@ -150,6 +158,8 @@ module TensorStream
|
|
150
158
|
when :concat
|
151
159
|
values = complete_eval(a, child_context)
|
152
160
|
concat_array(values, tensor.options[:axis])
|
161
|
+
when :round
|
162
|
+
call_op(:round, a, child_context, ->(t, _b) { t.round })
|
153
163
|
when :abs
|
154
164
|
call_op(:abs, a, child_context, ->(t, _b) { t.abs })
|
155
165
|
when :tanh
|
@@ -162,27 +172,57 @@ module TensorStream
|
|
162
172
|
call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
|
163
173
|
when :cos
|
164
174
|
call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
|
175
|
+
when :log1p
|
176
|
+
call_op(:log1p, a, child_context, ->(t, _b) { Distribution::MathExtension::Log.log1p(t) })
|
165
177
|
when :log
|
166
178
|
call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
|
167
179
|
when :exp
|
168
180
|
call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
|
181
|
+
when :sigmoid
|
182
|
+
call_op(:sigmoid, a, child_context, ->(t, _b) { 1 / (1 + Math.exp(-t)) })
|
169
183
|
when :sqrt
|
170
184
|
call_op(:exp, a, child_context, ->(t, _b) { Math.sqrt(t) })
|
171
185
|
when :square
|
172
186
|
call_op(:square, a, child_context, ->(t, _b) { t * t })
|
187
|
+
when :reciprocal
|
188
|
+
call_op(:square, a, child_context, ->(t, _b) { 1 / t })
|
173
189
|
when :stop_gradient
|
174
190
|
run(a, child_context)
|
175
191
|
when :random_uniform
|
176
192
|
maxval = tensor.options.fetch(:maxval, 1)
|
177
193
|
minval = tensor.options.fetch(:minval, 0)
|
194
|
+
seed = tensor.options[:seed]
|
178
195
|
|
179
|
-
|
180
|
-
|
196
|
+
random = _get_randomizer(tensor, seed)
|
197
|
+
generator = -> { random.rand * (maxval - minval) + minval }
|
198
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
199
|
+
generate_vector(shape, generator: generator)
|
181
200
|
when :random_normal
|
182
|
-
|
201
|
+
random = _get_randomizer(tensor, seed)
|
202
|
+
r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
|
203
|
+
random = _get_randomizer(tensor, seed)
|
183
204
|
generator = -> { r.rand }
|
184
|
-
|
185
|
-
generate_vector(
|
205
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
206
|
+
generate_vector(shape, generator: generator)
|
207
|
+
when :glorot_uniform
|
208
|
+
random = _get_randomizer(tensor, seed)
|
209
|
+
|
210
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
211
|
+
fan_in, fan_out = if shape.size == 0
|
212
|
+
[1, 1]
|
213
|
+
elsif shape.size == 1
|
214
|
+
[1, shape[0]]
|
215
|
+
else
|
216
|
+
[shape[0], shape.last]
|
217
|
+
end
|
218
|
+
|
219
|
+
limit = Math.sqrt(6.0 / (fan_in + fan_out))
|
220
|
+
|
221
|
+
minval = -limit
|
222
|
+
maxval = limit
|
223
|
+
|
224
|
+
generator = -> { random.rand * (maxval - minval) + minval }
|
225
|
+
generate_vector(shape, generator: generator)
|
186
226
|
when :flow_group
|
187
227
|
tensor.items.collect { |item| run(item, child_context) }
|
188
228
|
when :assign
|
@@ -314,8 +354,8 @@ module TensorStream
|
|
314
354
|
rank_a = get_rank(matrix_a)
|
315
355
|
rank_b = get_rank(matrix_b)
|
316
356
|
|
317
|
-
raise "#{
|
318
|
-
raise "#{
|
357
|
+
raise "#{tensor.items[0].name} rank must be greater than 1" if rank_a < 2
|
358
|
+
raise "#{tensor.items[1].name} rank must be greater than 1" if rank_b < 2
|
319
359
|
|
320
360
|
matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
|
321
361
|
matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
|
@@ -353,9 +393,9 @@ module TensorStream
|
|
353
393
|
flat_arr = arr.flatten
|
354
394
|
return flat_arr[0] if new_shape.size.zero? && flat_arr.size == 1
|
355
395
|
|
356
|
-
new_shape = fix_inferred_elements(new_shape, flat_arr.size)
|
396
|
+
new_shape = TensorShape.fix_inferred_elements(new_shape, flat_arr.size)
|
357
397
|
|
358
|
-
reshape(flat_arr, new_shape)
|
398
|
+
TensorShape.reshape(flat_arr, new_shape)
|
359
399
|
when :pad
|
360
400
|
a = complete_eval(a, child_context)
|
361
401
|
p = complete_eval(tensor.options[:paddings], child_context)
|
@@ -380,19 +420,35 @@ module TensorStream
|
|
380
420
|
|
381
421
|
tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
|
382
422
|
end
|
423
|
+
if @log_intermediates
|
424
|
+
@context[:compute_history] << {
|
425
|
+
name: tensor.name,
|
426
|
+
type: tensor.data_type,
|
427
|
+
shape: shape_eval(result),
|
428
|
+
source: tensor.source,
|
429
|
+
description: tensor.to_math(true, 1),
|
430
|
+
value: result
|
431
|
+
}
|
432
|
+
end
|
383
433
|
@context[tensor.name] = result
|
384
434
|
end
|
385
435
|
rescue EvaluatorExcecutionException => e
|
386
436
|
raise e
|
387
437
|
rescue StandardError => e
|
438
|
+
puts e.message
|
439
|
+
puts e.backtrace.join("\n")
|
440
|
+
shape_a = a.shape.shape if a
|
441
|
+
shape_b = b.shape.shape if b
|
442
|
+
dtype_a = a.data_type if a
|
443
|
+
dtype_b = b.data_type if b
|
388
444
|
a = complete_eval(a, child_context)
|
389
445
|
b = complete_eval(b, child_context)
|
390
446
|
puts "name: #{tensor.given_name}"
|
391
447
|
puts "op: #{tensor.to_math(true, 1)}"
|
392
|
-
puts "A: #{a}" if a
|
393
|
-
puts "B: #{b}" if b
|
448
|
+
puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
449
|
+
puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
450
|
+
dump_intermediates if @log_intermediates
|
394
451
|
|
395
|
-
puts e.backtrace.join("\n")
|
396
452
|
raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true,1)} defined at #{tensor.source}"
|
397
453
|
end
|
398
454
|
|
@@ -515,30 +571,6 @@ module TensorStream
|
|
515
571
|
end
|
516
572
|
end
|
517
573
|
|
518
|
-
def fix_inferred_elements(shape, total_size)
|
519
|
-
return shape if shape.empty?
|
520
|
-
|
521
|
-
current_size = shape.inject(1) { |product, n| n > 0 ? product * n : product }
|
522
|
-
inferred_size = total_size / current_size
|
523
|
-
shape.map { |s| s == -1 ? inferred_size : s }
|
524
|
-
end
|
525
|
-
|
526
|
-
def reshape(arr, new_shape)
|
527
|
-
return arr if new_shape.empty?
|
528
|
-
|
529
|
-
s = new_shape.shift
|
530
|
-
|
531
|
-
if new_shape.size.zero?
|
532
|
-
raise "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s
|
533
|
-
return arr
|
534
|
-
end
|
535
|
-
|
536
|
-
dim = (arr.size / s)
|
537
|
-
arr.each_slice(dim).collect do |slice|
|
538
|
-
reshape(slice, new_shape.dup)
|
539
|
-
end
|
540
|
-
end
|
541
|
-
|
542
574
|
def call_op(op, a, child_context, func)
|
543
575
|
a = complete_eval(a, child_context)
|
544
576
|
process_function_op(a, child_context, func)
|
@@ -723,6 +755,38 @@ module TensorStream
|
|
723
755
|
generator.call
|
724
756
|
end
|
725
757
|
end
|
758
|
+
|
759
|
+
def _get_randomizer(tensor, seed)
|
760
|
+
if tensor.graph.random_seed && seed
|
761
|
+
Random.new(tensor.graph.random_seed ^ seed)
|
762
|
+
elsif tensor.graph.random_seed
|
763
|
+
@session.randomizer[tensor.graph.object_id] ||= Random.new(tensor.graph.random_seed)
|
764
|
+
@session.randomizer[tensor.graph.object_id]
|
765
|
+
elsif seed
|
766
|
+
@session.randomizer[tensor.operation] ||= Random.new(seed)
|
767
|
+
@session.randomizer[tensor.operation]
|
768
|
+
else
|
769
|
+
Random.new
|
770
|
+
end
|
771
|
+
end
|
772
|
+
|
773
|
+
def dump_intermediates
|
774
|
+
arr = []
|
775
|
+
arr << "============== start ==================="
|
776
|
+
@context[:compute_history].each_with_index do |history, index|
|
777
|
+
arr << "------------------------------------"
|
778
|
+
arr << history[:name]
|
779
|
+
arr << "#{history[:type]} #{history[:shape]}"
|
780
|
+
arr << history[:source]
|
781
|
+
arr << history[:description]
|
782
|
+
arr << ""
|
783
|
+
arr << history[:value].to_json
|
784
|
+
arr << "------------------------------------"
|
785
|
+
end
|
786
|
+
arr << "============== end ====================="
|
787
|
+
str = arr.join("\n")
|
788
|
+
File.write("/tmp/intermediates.txt", str)
|
789
|
+
end
|
726
790
|
end
|
727
791
|
end
|
728
792
|
end
|