tensor_stream 0.1.4 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.circleci/config.yml +57 -0
- data/README.md +2 -0
- data/lib/tensor_stream.rb +74 -10
- data/lib/tensor_stream/control_flow.rb +2 -2
- data/lib/tensor_stream/device.rb +8 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +104 -40
- data/lib/tensor_stream/graph.rb +53 -5
- data/lib/tensor_stream/graph_keys.rb +1 -0
- data/lib/tensor_stream/graph_serializers/graphml.rb +91 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +71 -0
- data/lib/tensor_stream/helpers/op_helper.rb +7 -1
- data/lib/tensor_stream/initializer.rb +16 -0
- data/lib/tensor_stream/math_gradients.rb +37 -30
- data/lib/tensor_stream/nn/nn_ops.rb +17 -0
- data/lib/tensor_stream/operation.rb +92 -31
- data/lib/tensor_stream/ops.rb +87 -53
- data/lib/tensor_stream/placeholder.rb +1 -1
- data/lib/tensor_stream/session.rb +26 -4
- data/lib/tensor_stream/tensor.rb +29 -33
- data/lib/tensor_stream/tensor_shape.rb +52 -2
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +1 -4
- data/lib/tensor_stream/variable.rb +23 -7
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/logistic_regression.rb +76 -0
- data/tensor_stream.gemspec +3 -0
- metadata +50 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f0d4f44991feb4af9b68cd9cd5ee6cd202bb2670
|
4
|
+
data.tar.gz: 64801dfe3a55d76d0d7535d24867f94fc4dd0861
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d68686258355308d9dbdd2fbd632a02fd98668765ab0ef3161494d28d32263aa6b88f4dbdc05da5b473c9f20cf8466bc64e93834442559cd2835a26ffe771c1c
|
7
|
+
data.tar.gz: 7517729b8aad2ffc2e557c3f55c09bf149fa84c4757b979842a3e56adb3e121331a73fc07667c5f8f639dd3c20b1c2f2f4b55311cf930f9d80cdffcea6a886e1
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Ruby CircleCI 2.0 configuration file
|
2
|
+
#
|
3
|
+
# Check https://circleci.com/docs/2.0/language-ruby/ for more details
|
4
|
+
#
|
5
|
+
version: 2
|
6
|
+
jobs:
|
7
|
+
build:
|
8
|
+
docker:
|
9
|
+
# specify the version you desire here
|
10
|
+
- image: circleci/ruby:2.4.1-node-browsers
|
11
|
+
|
12
|
+
# Specify service dependencies here if necessary
|
13
|
+
# CircleCI maintains a library of pre-built images
|
14
|
+
# documented at https://circleci.com/docs/2.0/circleci-images/
|
15
|
+
# - image: circleci/postgres:9.4
|
16
|
+
|
17
|
+
working_directory: ~/repo
|
18
|
+
|
19
|
+
steps:
|
20
|
+
- checkout
|
21
|
+
|
22
|
+
# Download and cache dependencies
|
23
|
+
- restore_cache:
|
24
|
+
keys:
|
25
|
+
- v1-dependencies-{{ checksum "Gemfile.lock" }}
|
26
|
+
# fallback to using the latest cache if no exact match is found
|
27
|
+
- v1-dependencies-
|
28
|
+
|
29
|
+
- run:
|
30
|
+
name: install dependencies
|
31
|
+
command: |
|
32
|
+
bundle install --jobs=4 --retry=3 --path vendor/bundle
|
33
|
+
|
34
|
+
- save_cache:
|
35
|
+
paths:
|
36
|
+
- ./vendor/bundle
|
37
|
+
key: v1-dependencies-{{ checksum "Gemfile.lock" }}
|
38
|
+
|
39
|
+
# run tests!
|
40
|
+
- run:
|
41
|
+
name: run tests
|
42
|
+
command: |
|
43
|
+
mkdir /tmp/test-results
|
44
|
+
TEST_FILES="$(circleci tests glob "spec/**/*_spec.rb" | circleci tests split --split-by=timings)"
|
45
|
+
|
46
|
+
bundle exec rspec -r rspec_junit_formatter --format progress \
|
47
|
+
--format RspecJunitFormatter \
|
48
|
+
--out /tmp/test-results/rspec.xml \
|
49
|
+
--format progress \
|
50
|
+
$TEST_FILES
|
51
|
+
|
52
|
+
# collect reports
|
53
|
+
- store_test_results:
|
54
|
+
path: /tmp/test-results
|
55
|
+
- store_artifacts:
|
56
|
+
path: /tmp/test-results
|
57
|
+
destination: test-results
|
data/README.md
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
[![Gem Version](https://badge.fury.io/rb/tensor_stream.svg)](https://badge.fury.io/rb/tensor_stream)
|
2
2
|
|
3
|
+
[![CircleCI](https://circleci.com/gh/jedld/tensor_stream.svg?style=svg)](https://circleci.com/gh/jedld/tensor_stream)
|
4
|
+
|
3
5
|
# TensorStream
|
4
6
|
|
5
7
|
A reimplementation of TensorFlow for Ruby. This is a ground-up implementation with no dependency on TensorFlow. Effort has been made to keep the programming style as close to TensorFlow as possible. It comes with a pure Ruby evaluator by default, as well as support for an OpenCL evaluator.
|
data/lib/tensor_stream.rb
CHANGED
@@ -3,9 +3,11 @@ require 'deep_merge'
|
|
3
3
|
require 'matrix'
|
4
4
|
require 'concurrent'
|
5
5
|
require 'tensor_stream/helpers/op_helper'
|
6
|
+
require 'tensor_stream/initializer'
|
6
7
|
require 'tensor_stream/graph_keys'
|
7
8
|
require 'tensor_stream/types'
|
8
9
|
require 'tensor_stream/graph'
|
10
|
+
require 'tensor_stream/device'
|
9
11
|
require 'tensor_stream/session'
|
10
12
|
require 'tensor_stream/tensor_shape'
|
11
13
|
require 'tensor_stream/tensor'
|
@@ -16,6 +18,8 @@ require 'tensor_stream/control_flow'
|
|
16
18
|
require 'tensor_stream/trainer'
|
17
19
|
require 'tensor_stream/nn/nn_ops'
|
18
20
|
require 'tensor_stream/evaluator/evaluator'
|
21
|
+
require 'tensor_stream/graph_serializers/pbtext'
|
22
|
+
require 'tensor_stream/graph_serializers/graphml'
|
19
23
|
# require 'tensor_stream/libraries/layers'
|
20
24
|
require 'tensor_stream/monkey_patches/integer'
|
21
25
|
require 'tensor_stream/ops'
|
@@ -29,6 +33,10 @@ module TensorStream
|
|
29
33
|
Types.float32
|
30
34
|
end
|
31
35
|
|
36
|
+
# Builds a brand-new graph object.
#
# Every call constructs a separate TensorStream::Graph; this method itself
# does not register the result as the default graph.
#
# @return [TensorStream::Graph] a freshly constructed graph
def self.graph
  fresh_graph = TensorStream::Graph.new
  fresh_graph
end
|
39
|
+
|
32
40
|
def self.get_default_graph
|
33
41
|
TensorStream::Graph.get_default_graph
|
34
42
|
end
|
@@ -49,22 +57,60 @@ module TensorStream
|
|
49
57
|
TensorStream::Graph.get_default_graph.executing_eagerly?
|
50
58
|
end
|
51
59
|
|
52
|
-
def self.variable(value,
|
60
|
+
def self.variable(value, name: nil, initializer: nil, graph: nil, dtype: nil, trainable: true)
|
53
61
|
common_options = {
|
54
|
-
initializer: Operation.new(:assign, nil, value),
|
55
|
-
name:
|
62
|
+
initializer: initializer || Operation.new(:assign, nil, value),
|
63
|
+
name: name,
|
64
|
+
graph: graph,
|
65
|
+
dtype: dtype,
|
66
|
+
trainable: trainable
|
56
67
|
}
|
57
68
|
if value.is_a?(String)
|
58
|
-
TensorStream::Variable.new(
|
69
|
+
TensorStream::Variable.new(dtype || :string, 0, [], common_options)
|
59
70
|
elsif value.is_a?(Integer)
|
60
|
-
TensorStream::Variable.new(
|
71
|
+
TensorStream::Variable.new(dtype || :int32, 0, [], common_options)
|
61
72
|
elsif value.is_a?(Float)
|
62
|
-
TensorStream::Variable.new(
|
73
|
+
TensorStream::Variable.new(dtype || :float32, 0, [], common_options)
|
63
74
|
else
|
64
|
-
TensorStream::Variable.new(
|
75
|
+
TensorStream::Variable.new(dtype || :float32, 0, nil, common_options)
|
76
|
+
end
|
77
|
+
end
|
78
|
+
|
79
|
+
# Opens a variable scope for the duration of the given block.
#
# A scope entry (name, reuse flag, initializer) is pushed onto a thread-local
# stack, the block is run inside the default graph's name scope of the same
# name, and the entry is popped again afterwards - even when the block raises.
#
# @param scope [String, nil] name of the scope being entered
# @param reuse [Object, nil] stored on the scope entry; not interpreted here
# @param initializer [Object, nil] stored on the scope entry; not interpreted here
# @yieldparam scope_name [String] '/'-joined path of all open variable scopes
# @return [Object, nil] whatever the graph's name_scope call returns, or nil
#   when no block is given
def self.variable_scope(scope = nil, reuse: nil, initializer: nil)
  Thread.current[:tensor_stream_variable_scope] ||= []
  Thread.current[:tensor_stream_variable_scope] << OpenStruct.new(name: scope, reuse: reuse, initializer: initializer)
  begin
    # Resolve the path inside the protected region: if building it raises,
    # the entry pushed above must still be removed from the stack.
    scope_name = __v_scope_name
    return unless block_given?

    TensorStream.get_default_graph.name_scope(scope) do
      yield(scope_name)
    end
  ensure
    Thread.current[:tensor_stream_variable_scope].pop
  end
end
|
67
93
|
|
94
|
+
# Runs the given block inside a name scope on the default graph.
#
# When +values+ is supplied, every Tensor in it must belong to the same graph
# (compared by graph object identity); non-Tensor entries are ignored.
#
# @param name [String, nil] scope name; +default+ is used when this is nil
# @param default [String, nil] fallback scope name
# @param values [Enumerable, nil] values to check for a common graph
# @yieldparam scope [Object] whatever the graph's name_scope yields
# @raise [RuntimeError] if the tensors in +values+ span more than one graph
def self.name_scope(name, default: nil, values: nil)
  if values
    distinct_graph_ids = values.grep(Tensor).map { |tensor| tensor.graph.object_id }.uniq
    raise "values are not on the same graph" if distinct_graph_ids.size > 1
  end

  get_default_graph.name_scope(name || default) { |scope| yield scope if block_given? }
end
|
104
|
+
|
105
|
+
# Returns the '/'-joined path of the variable scopes currently open on this
# thread, or nil when the thread-local scope stack has never been created.
#
# @return [String, nil]
def self.get_variable_scope
  __v_scope_name if Thread.current[:tensor_stream_variable_scope]
end
|
109
|
+
|
110
|
+
# Builds the '/'-joined path of the variable scopes open on this thread.
#
# Entries whose name is nil or empty are left out of the path. A thread that
# has not opened any variable scope yet has no stack at all; that case is
# treated like an empty stack instead of raising NoMethodError on nil.
#
# @return [String] the scope path, or "" when no named scope is open
def self.__v_scope_name
  scope_stack = Thread.current[:tensor_stream_variable_scope] || []
  scope_stack.map(&:name).compact.reject(&:empty?).join('/')
end
|
113
|
+
|
68
114
|
def self.session(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
|
69
115
|
session = TensorStream::Session.new(evaluator, thread_pool_class: thread_pool_class)
|
70
116
|
yield session if block_given?
|
@@ -108,8 +154,8 @@ module TensorStream
|
|
108
154
|
TensorStream::ControlFlow.new(:group, inputs)
|
109
155
|
end
|
110
156
|
|
111
|
-
def self.get_variable(name,
|
112
|
-
TensorStream::Variable.new(
|
157
|
+
# Constructs a TensorStream::Variable from the given attributes.
#
# @param name [String] variable name, passed through in the options hash
# @param dtype [Symbol, nil] data type; :float32 is used when nil
# @param shape [Array, nil] passed through as the variable's shape argument
# @param initializer [Object, nil] passed through in the options hash
# @param trainable [Boolean] passed through in the options hash
# @param collections [Object, nil] passed through in the options hash
# @return [TensorStream::Variable]
def self.get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
  variable_options = {
    collections: collections,
    name: name,
    initializer: initializer,
    trainable: trainable
  }
  TensorStream::Variable.new(dtype || :float32, nil, shape, variable_options)
end
|
114
160
|
|
115
161
|
def self.get_collection(name, options = {})
|
@@ -128,10 +174,28 @@ module TensorStream
|
|
128
174
|
TensorStream::Trainer
|
129
175
|
end
|
130
176
|
|
177
|
+
# Looks up the default graph's TRAINABLE_VARIABLES collection.
#
# @return [Object] whatever the default graph's get_collection returns for
#   TensorStream::GraphKeys::TRAINABLE_VARIABLES
def self.trainable_variables
  collection_key = TensorStream::GraphKeys::TRAINABLE_VARIABLES
  TensorStream.get_default_graph.get_collection(collection_key)
end
|
180
|
+
|
181
|
+
# Stores +seed+ as the random seed of the default graph.
#
# @param seed [Object] value assigned to the default graph's random_seed
# @return [Object] the seed that was assigned
def self.set_random_seed(seed)
  default_graph = TensorStream.get_default_graph
  default_graph.random_seed = seed
end
|
184
|
+
|
185
|
+
# Converts +value+ into a tensor.
#
# * A Proc is called and its result is converted with the same options.
# * A Tensor is returned unchanged.
# * Anything else is wrapped via i_cons, using +dtype+ or, when that is nil,
#   the type reported by Tensor.detect_type.
#
# @param value [Object, Proc, Tensor] the value to convert
# @param dtype [Symbol, nil] explicit data type for the created constant
# @param name [String, nil] name for the created constant
# @param preferred_dtype [Symbol, nil] accepted for API compatibility; it is
#   not consulted by this method
# @return [Tensor] the converted value
def self.convert_to_tensor(value, dtype: nil, name: nil, preferred_dtype: nil)
  if value.is_a?(Proc)
    # Forward every option: previously the recursive call dropped them, so a
    # Proc-wrapped value silently lost the caller's dtype and name.
    return convert_to_tensor(value.call, dtype: dtype, name: name, preferred_dtype: preferred_dtype)
  end

  return value if value.is_a?(Tensor)

  i_cons(value, dtype: dtype || Tensor.detect_type(value), name: name)
end
|
194
|
+
|
131
195
|
def self.check_allowed_types(input, types)
|
132
196
|
return input unless input.is_a?(Tensor)
|
133
197
|
return input if input.data_type.nil?
|
134
198
|
|
135
|
-
raise "Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.map(&:to_sym).include?(input.data_type)
|
199
|
+
raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}" unless types.map(&:to_sym).include?(input.data_type)
|
136
200
|
end
|
137
201
|
end
|
@@ -8,10 +8,10 @@ module TensorStream
|
|
8
8
|
|
9
9
|
@operation = :"flow_#{flow_type}"
|
10
10
|
@items = items
|
11
|
-
@name = set_name
|
11
|
+
@name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
|
12
12
|
@ops = ops
|
13
13
|
@source = format_source(caller_locations)
|
14
|
-
|
14
|
+
@shape = TensorShape.new([items.size])
|
15
15
|
@graph.add_node(self)
|
16
16
|
end
|
17
17
|
|
@@ -1,6 +1,7 @@
|
|
1
1
|
require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
|
2
2
|
require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
|
3
3
|
require 'tensor_stream/math_gradients'
|
4
|
+
require 'distribution'
|
4
5
|
|
5
6
|
module TensorStream
|
6
7
|
module Evaluator
|
@@ -28,19 +29,22 @@ module TensorStream
|
|
28
29
|
include TensorStream::OpHelper
|
29
30
|
include TensorStream::ArrayOpsHelper
|
30
31
|
|
31
|
-
def initialize(session, context, thread_pool: nil)
|
32
|
+
# Sets up an evaluator bound to a session and an evaluation context.
#
# @param session [Object] session this evaluator belongs to
# @param context [Hash] evaluation context; its :retain entry (if any) is
#   kept as the retain list
# @param thread_pool [Object, nil] executor to use; a new
#   Concurrent::ImmediateExecutor is created when nil
# @param log_intermediates [Boolean] when true, an empty :compute_history
#   list is stored in the context
def initialize(session, context, thread_pool: nil, log_intermediates: false)
  @session, @context = session, context
  @log_intermediates = log_intermediates
  @retain = context[:retain] || []
  @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
  return unless log_intermediates

  @context[:compute_history] = []
end
|
37
41
|
|
38
42
|
def run(tensor, execution_context)
|
39
43
|
return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array)
|
40
44
|
|
41
45
|
return tensor if retain.include?(tensor) # if var is in retain don't eval to value
|
42
|
-
|
43
|
-
tensor = tensor.call
|
46
|
+
|
47
|
+
tensor = tensor.call if tensor.is_a?(Proc)
|
44
48
|
|
45
49
|
child_context = execution_context.dup
|
46
50
|
res = if tensor.is_a?(Operation)
|
@@ -71,7 +75,9 @@ module TensorStream
|
|
71
75
|
protected
|
72
76
|
|
73
77
|
def eval_variable(tensor, child_context)
|
74
|
-
|
78
|
+
if tensor.value.nil?
|
79
|
+
raise "variable #{tensor.name} not initalized"
|
80
|
+
end
|
75
81
|
eval_tensor(tensor.value, child_context).tap do |val|
|
76
82
|
child_context[:returns] ||= {}
|
77
83
|
child_context[:returns][:vars] ||= []
|
@@ -86,6 +92,8 @@ module TensorStream
|
|
86
92
|
b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]
|
87
93
|
|
88
94
|
case tensor.operation
|
95
|
+
when :const
|
96
|
+
complete_eval(a, child_context)
|
89
97
|
when :argmax
|
90
98
|
a = complete_eval(a, child_context)
|
91
99
|
axis = tensor.options[:axis] || 0
|
@@ -150,6 +158,8 @@ module TensorStream
|
|
150
158
|
when :concat
|
151
159
|
values = complete_eval(a, child_context)
|
152
160
|
concat_array(values, tensor.options[:axis])
|
161
|
+
when :round
|
162
|
+
call_op(:round, a, child_context, ->(t, _b) { t.round })
|
153
163
|
when :abs
|
154
164
|
call_op(:abs, a, child_context, ->(t, _b) { t.abs })
|
155
165
|
when :tanh
|
@@ -162,27 +172,57 @@ module TensorStream
|
|
162
172
|
call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
|
163
173
|
when :cos
|
164
174
|
call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
|
175
|
+
when :log1p
|
176
|
+
call_op(:log1p, a, child_context, ->(t, _b) { Distribution::MathExtension::Log.log1p(t) })
|
165
177
|
when :log
|
166
178
|
call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
|
167
179
|
when :exp
|
168
180
|
call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
|
181
|
+
when :sigmoid
|
182
|
+
call_op(:sigmoid, a, child_context, ->(t, _b) { 1 / (1 + Math.exp(-t)) })
|
169
183
|
when :sqrt
|
170
184
|
call_op(:exp, a, child_context, ->(t, _b) { Math.sqrt(t) })
|
171
185
|
when :square
|
172
186
|
call_op(:square, a, child_context, ->(t, _b) { t * t })
|
187
|
+
when :reciprocal
|
188
|
+
call_op(:square, a, child_context, ->(t, _b) { 1 / t })
|
173
189
|
when :stop_gradient
|
174
190
|
run(a, child_context)
|
175
191
|
when :random_uniform
|
176
192
|
maxval = tensor.options.fetch(:maxval, 1)
|
177
193
|
minval = tensor.options.fetch(:minval, 0)
|
194
|
+
seed = tensor.options[:seed]
|
178
195
|
|
179
|
-
|
180
|
-
|
196
|
+
random = _get_randomizer(tensor, seed)
|
197
|
+
generator = -> { random.rand * (maxval - minval) + minval }
|
198
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
199
|
+
generate_vector(shape, generator: generator)
|
181
200
|
when :random_normal
|
182
|
-
|
201
|
+
random = _get_randomizer(tensor, seed)
|
202
|
+
r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
|
203
|
+
random = _get_randomizer(tensor, seed)
|
183
204
|
generator = -> { r.rand }
|
184
|
-
|
185
|
-
generate_vector(
|
205
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
206
|
+
generate_vector(shape, generator: generator)
|
207
|
+
when :glorot_uniform
|
208
|
+
random = _get_randomizer(tensor, seed)
|
209
|
+
|
210
|
+
shape = tensor.options[:shape] || tensor.shape.shape
|
211
|
+
fan_in, fan_out = if shape.size == 0
|
212
|
+
[1, 1]
|
213
|
+
elsif shape.size == 1
|
214
|
+
[1, shape[0]]
|
215
|
+
else
|
216
|
+
[shape[0], shape.last]
|
217
|
+
end
|
218
|
+
|
219
|
+
limit = Math.sqrt(6.0 / (fan_in + fan_out))
|
220
|
+
|
221
|
+
minval = -limit
|
222
|
+
maxval = limit
|
223
|
+
|
224
|
+
generator = -> { random.rand * (maxval - minval) + minval }
|
225
|
+
generate_vector(shape, generator: generator)
|
186
226
|
when :flow_group
|
187
227
|
tensor.items.collect { |item| run(item, child_context) }
|
188
228
|
when :assign
|
@@ -314,8 +354,8 @@ module TensorStream
|
|
314
354
|
rank_a = get_rank(matrix_a)
|
315
355
|
rank_b = get_rank(matrix_b)
|
316
356
|
|
317
|
-
raise "#{
|
318
|
-
raise "#{
|
357
|
+
raise "#{tensor.items[0].name} rank must be greater than 1" if rank_a < 2
|
358
|
+
raise "#{tensor.items[1].name} rank must be greater than 1" if rank_b < 2
|
319
359
|
|
320
360
|
matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
|
321
361
|
matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
|
@@ -353,9 +393,9 @@ module TensorStream
|
|
353
393
|
flat_arr = arr.flatten
|
354
394
|
return flat_arr[0] if new_shape.size.zero? && flat_arr.size == 1
|
355
395
|
|
356
|
-
new_shape = fix_inferred_elements(new_shape, flat_arr.size)
|
396
|
+
new_shape = TensorShape.fix_inferred_elements(new_shape, flat_arr.size)
|
357
397
|
|
358
|
-
reshape(flat_arr, new_shape)
|
398
|
+
TensorShape.reshape(flat_arr, new_shape)
|
359
399
|
when :pad
|
360
400
|
a = complete_eval(a, child_context)
|
361
401
|
p = complete_eval(tensor.options[:paddings], child_context)
|
@@ -380,19 +420,35 @@ module TensorStream
|
|
380
420
|
|
381
421
|
tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
|
382
422
|
end
|
423
|
+
if @log_intermediates
|
424
|
+
@context[:compute_history] << {
|
425
|
+
name: tensor.name,
|
426
|
+
type: tensor.data_type,
|
427
|
+
shape: shape_eval(result),
|
428
|
+
source: tensor.source,
|
429
|
+
description: tensor.to_math(true, 1),
|
430
|
+
value: result
|
431
|
+
}
|
432
|
+
end
|
383
433
|
@context[tensor.name] = result
|
384
434
|
end
|
385
435
|
rescue EvaluatorExcecutionException => e
|
386
436
|
raise e
|
387
437
|
rescue StandardError => e
|
438
|
+
puts e.message
|
439
|
+
puts e.backtrace.join("\n")
|
440
|
+
shape_a = a.shape.shape if a
|
441
|
+
shape_b = b.shape.shape if b
|
442
|
+
dtype_a = a.data_type if a
|
443
|
+
dtype_b = b.data_type if b
|
388
444
|
a = complete_eval(a, child_context)
|
389
445
|
b = complete_eval(b, child_context)
|
390
446
|
puts "name: #{tensor.given_name}"
|
391
447
|
puts "op: #{tensor.to_math(true, 1)}"
|
392
|
-
puts "A: #{a}" if a
|
393
|
-
puts "B: #{b}" if b
|
448
|
+
puts "A #{shape_a} #{dtype_a}: #{a}" if a
|
449
|
+
puts "B #{shape_b} #{dtype_b}: #{b}" if b
|
450
|
+
dump_intermediates if @log_intermediates
|
394
451
|
|
395
|
-
puts e.backtrace.join("\n")
|
396
452
|
raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true,1)} defined at #{tensor.source}"
|
397
453
|
end
|
398
454
|
|
@@ -515,30 +571,6 @@ module TensorStream
|
|
515
571
|
end
|
516
572
|
end
|
517
573
|
|
518
|
-
def fix_inferred_elements(shape, total_size)
|
519
|
-
return shape if shape.empty?
|
520
|
-
|
521
|
-
current_size = shape.inject(1) { |product, n| n > 0 ? product * n : product }
|
522
|
-
inferred_size = total_size / current_size
|
523
|
-
shape.map { |s| s == -1 ? inferred_size : s }
|
524
|
-
end
|
525
|
-
|
526
|
-
def reshape(arr, new_shape)
|
527
|
-
return arr if new_shape.empty?
|
528
|
-
|
529
|
-
s = new_shape.shift
|
530
|
-
|
531
|
-
if new_shape.size.zero?
|
532
|
-
raise "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s
|
533
|
-
return arr
|
534
|
-
end
|
535
|
-
|
536
|
-
dim = (arr.size / s)
|
537
|
-
arr.each_slice(dim).collect do |slice|
|
538
|
-
reshape(slice, new_shape.dup)
|
539
|
-
end
|
540
|
-
end
|
541
|
-
|
542
574
|
def call_op(op, a, child_context, func)
|
543
575
|
a = complete_eval(a, child_context)
|
544
576
|
process_function_op(a, child_context, func)
|
@@ -723,6 +755,38 @@ module TensorStream
|
|
723
755
|
generator.call
|
724
756
|
end
|
725
757
|
end
|
758
|
+
|
759
|
+
# Picks the random number generator to use for a random op.
#
# * graph seed and op seed: a new generator seeded with (graph_seed ^ seed)
# * graph seed only: one generator per graph, cached on the session
# * op seed only: one generator per (operation, seed) pair, cached on the
#   session
# * neither: a new, arbitrarily seeded generator
#
# @param tensor [Object] the op being evaluated; its graph and operation are read
# @param seed [Integer, nil] op-level seed
# @return [Random]
def _get_randomizer(tensor, seed)
  graph_seed = tensor.graph.random_seed

  if graph_seed && seed
    Random.new(graph_seed ^ seed)
  elsif graph_seed
    @session.randomizer[tensor.graph.object_id] ||= Random.new(graph_seed)
  elsif seed
    # The seed is part of the cache key: keyed by operation alone, a second
    # op of the same type with a different seed would silently reuse the
    # generator created for the first op's seed.
    @session.randomizer[[tensor.operation, seed]] ||= Random.new(seed)
  else
    Random.new
  end
end
|
772
|
+
|
773
|
+
# Writes the recorded compute history to /tmp/intermediates.txt.
#
# Each entry of @context[:compute_history] is rendered as a block holding its
# name, "type shape", source, description, a blank line and the JSON form of
# its value, framed by separator lines. The file is overwritten on each call.
#
# @return [Integer] number of bytes written (the result of File.write)
def dump_intermediates
  separator = "------------------------------------"
  lines = ["============== start ==================="]

  @context[:compute_history].each do |history|
    lines << separator
    lines << history[:name]
    lines << "#{history[:type]} #{history[:shape]}"
    lines << history[:source]
    lines << history[:description]
    lines << ""
    lines << history[:value].to_json
    lines << separator
  end

  lines << "============== end ====================="
  File.write("/tmp/intermediates.txt", lines.join("\n"))
end
|
726
790
|
end
|
727
791
|
end
|
728
792
|
end
|