tensor_stream 1.0.0 → 1.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.rubocop.yml +1 -0
- data/Gemfile +1 -1
- data/LICENSE.txt +1 -1
- data/README.md +34 -34
- data/Rakefile +3 -3
- data/USAGE_GUIDE.md +235 -0
- data/bin/stubgen +20 -0
- data/exe/model_utils +2 -2
- data/lib/tensor_stream.rb +45 -44
- data/lib/tensor_stream/constant.rb +2 -2
- data/lib/tensor_stream/control_flow.rb +1 -1
- data/lib/tensor_stream/debugging/debugging.rb +2 -2
- data/lib/tensor_stream/dynamic_stitch.rb +2 -2
- data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
- data/lib/tensor_stream/evaluator/buffer.rb +1 -1
- data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
- data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
- data/lib/tensor_stream/exceptions.rb +1 -1
- data/lib/tensor_stream/generated_stub/ops.rb +691 -0
- data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
- data/lib/tensor_stream/graph.rb +18 -18
- data/lib/tensor_stream/graph_builder.rb +17 -17
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
- data/lib/tensor_stream/graph_keys.rb +3 -3
- data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
- data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
- data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
- data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
- data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
- data/lib/tensor_stream/helpers/op_helper.rb +8 -9
- data/lib/tensor_stream/helpers/string_helper.rb +15 -15
- data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
- data/lib/tensor_stream/images.rb +1 -1
- data/lib/tensor_stream/initializer.rb +1 -1
- data/lib/tensor_stream/math_gradients.rb +28 -187
- data/lib/tensor_stream/monkey_patches/array.rb +1 -1
- data/lib/tensor_stream/monkey_patches/float.rb +1 -1
- data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
- data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
- data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
- data/lib/tensor_stream/nn/nn_ops.rb +17 -15
- data/lib/tensor_stream/op_maker.rb +180 -0
- data/lib/tensor_stream/operation.rb +17 -17
- data/lib/tensor_stream/ops.rb +95 -384
- data/lib/tensor_stream/ops/add.rb +23 -0
- data/lib/tensor_stream/ops/argmax.rb +14 -0
- data/lib/tensor_stream/ops/argmin.rb +14 -0
- data/lib/tensor_stream/ops/case.rb +17 -0
- data/lib/tensor_stream/ops/cast.rb +15 -0
- data/lib/tensor_stream/ops/ceil.rb +15 -0
- data/lib/tensor_stream/ops/const.rb +0 -0
- data/lib/tensor_stream/ops/cos.rb +10 -0
- data/lib/tensor_stream/ops/div.rb +21 -0
- data/lib/tensor_stream/ops/equal.rb +15 -0
- data/lib/tensor_stream/ops/expand_dims.rb +17 -0
- data/lib/tensor_stream/ops/fill.rb +19 -0
- data/lib/tensor_stream/ops/floor.rb +15 -0
- data/lib/tensor_stream/ops/floor_div.rb +15 -0
- data/lib/tensor_stream/ops/greater.rb +11 -0
- data/lib/tensor_stream/ops/greater_equal.rb +11 -0
- data/lib/tensor_stream/ops/less_equal.rb +15 -0
- data/lib/tensor_stream/ops/log.rb +14 -0
- data/lib/tensor_stream/ops/mat_mul.rb +60 -0
- data/lib/tensor_stream/ops/max.rb +15 -0
- data/lib/tensor_stream/ops/min.rb +15 -0
- data/lib/tensor_stream/ops/mod.rb +23 -0
- data/lib/tensor_stream/ops/mul.rb +21 -0
- data/lib/tensor_stream/ops/negate.rb +14 -0
- data/lib/tensor_stream/ops/ones_like.rb +19 -0
- data/lib/tensor_stream/ops/pow.rb +25 -0
- data/lib/tensor_stream/ops/prod.rb +60 -0
- data/lib/tensor_stream/ops/random_uniform.rb +18 -0
- data/lib/tensor_stream/ops/range.rb +20 -0
- data/lib/tensor_stream/ops/rank.rb +13 -0
- data/lib/tensor_stream/ops/reshape.rb +24 -0
- data/lib/tensor_stream/ops/round.rb +15 -0
- data/lib/tensor_stream/ops/shape.rb +14 -0
- data/lib/tensor_stream/ops/sigmoid.rb +10 -0
- data/lib/tensor_stream/ops/sign.rb +12 -0
- data/lib/tensor_stream/ops/sin.rb +10 -0
- data/lib/tensor_stream/ops/size.rb +16 -0
- data/lib/tensor_stream/ops/sub.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +27 -0
- data/lib/tensor_stream/ops/tan.rb +12 -0
- data/lib/tensor_stream/ops/tanh.rb +10 -0
- data/lib/tensor_stream/ops/tile.rb +19 -0
- data/lib/tensor_stream/ops/zeros.rb +15 -0
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/profile/report_tool.rb +3 -3
- data/lib/tensor_stream/session.rb +36 -38
- data/lib/tensor_stream/tensor.rb +2 -2
- data/lib/tensor_stream/tensor_shape.rb +4 -4
- data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
- data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
- data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
- data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
- data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
- data/lib/tensor_stream/train/optimizer.rb +9 -9
- data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
- data/lib/tensor_stream/train/saver.rb +14 -14
- data/lib/tensor_stream/train/slot_creator.rb +6 -6
- data/lib/tensor_stream/train/utils.rb +12 -12
- data/lib/tensor_stream/trainer.rb +10 -10
- data/lib/tensor_stream/types.rb +1 -1
- data/lib/tensor_stream/utils.rb +33 -32
- data/lib/tensor_stream/utils/freezer.rb +5 -5
- data/lib/tensor_stream/variable.rb +5 -5
- data/lib/tensor_stream/variable_scope.rb +1 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/{iris.data → datasets/iris.data} +0 -0
- data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
- data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
- data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
- data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
- data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
- data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
- data/samples/regression/linear_regression.rb +63 -0
- data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
- data/tensor_stream.gemspec +9 -8
- metadata +89 -19
- data/data_1.json +0 -4764
- data/data_2.json +0 -4764
- data/data_actual.json +0 -28
- data/data_expected.json +0 -28
- data/data_input.json +0 -28
- data/samples/error.graphml +0 -2755
- data/samples/gradient_sample.graphml +0 -1255
- data/samples/linear_regression.rb +0 -69
- data/samples/multigpu.rb +0 -73
- data/samples/raw_neural_net_sample.rb +0 -112
@@ -6,7 +6,7 @@ module TensorStream
|
|
6
6
|
# Utility class to convert variables to constants for production deployment
|
7
7
|
#
|
8
8
|
def convert(session, checkpoint_folder, output_file)
|
9
|
-
model_file = File.join(checkpoint_folder,
|
9
|
+
model_file = File.join(checkpoint_folder, "model.yaml")
|
10
10
|
TensorStream.graph.as_default do |current_graph|
|
11
11
|
YamlLoader.new.load_from_string(File.read(model_file))
|
12
12
|
saver = TensorStream::Train::Saver.new
|
@@ -15,7 +15,7 @@ module TensorStream
|
|
15
15
|
# collect all assign ops and remove them from the graph
|
16
16
|
remove_nodes = Set.new(current_graph.nodes.values.select { |op| op.is_a?(TensorStream::Operation) && op.operation == :assign }.map { |op| op.consumers.to_a }.flatten.uniq)
|
17
17
|
|
18
|
-
output_buffer = TensorStream::Yaml.new.get_string(current_graph)
|
18
|
+
output_buffer = TensorStream::Yaml.new.get_string(current_graph) { |graph, node_key|
|
19
19
|
node = graph.get_tensor_by_name(node_key)
|
20
20
|
case node.operation
|
21
21
|
when :variable_v2
|
@@ -23,7 +23,7 @@ module TensorStream
|
|
23
23
|
options = {
|
24
24
|
value: value,
|
25
25
|
data_type: node.data_type,
|
26
|
-
shape: shape_eval(value)
|
26
|
+
shape: shape_eval(value),
|
27
27
|
}
|
28
28
|
const_op = TensorStream::Operation.new(current_graph, inputs: [], options: options)
|
29
29
|
const_op.name = node.name
|
@@ -37,9 +37,9 @@ module TensorStream
|
|
37
37
|
else
|
38
38
|
remove_nodes.include?(node.name) ? nil : node
|
39
39
|
end
|
40
|
-
|
40
|
+
}
|
41
41
|
File.write(output_file, output_buffer)
|
42
42
|
end
|
43
43
|
end
|
44
44
|
end
|
45
|
-
end
|
45
|
+
end
|
@@ -18,7 +18,7 @@ module TensorStream
|
|
18
18
|
|
19
19
|
scope_name = variable_scope ? variable_scope.name : nil
|
20
20
|
variable_scope_initializer = variable_scope ? variable_scope.initializer : nil
|
21
|
-
@name = [scope_name, options[:name] || build_name].compact.reject(&:empty?).join(
|
21
|
+
@name = [scope_name, options[:name] || build_name].compact.reject(&:empty?).join("/")
|
22
22
|
@initalizer_tensor = options[:initializer] || variable_scope_initializer || TensorStream.glorot_uniform_initializer
|
23
23
|
shape = @initalizer_tensor.shape.shape if shape.nil? && @initalizer_tensor && @initalizer_tensor.shape
|
24
24
|
|
@@ -45,7 +45,7 @@ module TensorStream
|
|
45
45
|
end
|
46
46
|
|
47
47
|
def assign(value, name: nil, use_locking: false)
|
48
|
-
|
48
|
+
TensorStream.check_data_types(self, value)
|
49
49
|
_op(:assign, self, value, name: name)
|
50
50
|
end
|
51
51
|
|
@@ -55,7 +55,7 @@ module TensorStream
|
|
55
55
|
end
|
56
56
|
|
57
57
|
def assign_add(value, name: nil)
|
58
|
-
|
58
|
+
TensorStream.check_data_types(self, value)
|
59
59
|
_op(:assign_add, self, value, data_type: data_type, name: name)
|
60
60
|
end
|
61
61
|
|
@@ -64,7 +64,7 @@ module TensorStream
|
|
64
64
|
end
|
65
65
|
|
66
66
|
def assign_sub(value)
|
67
|
-
|
67
|
+
TensorStream.check_data_types(self, value)
|
68
68
|
_op(:assign_sub, self, value)
|
69
69
|
end
|
70
70
|
|
@@ -77,7 +77,7 @@ module TensorStream
|
|
77
77
|
end
|
78
78
|
|
79
79
|
def inspect
|
80
|
-
"Variable(#{@name} shape: #{@shape ||
|
80
|
+
"Variable(#{@name} shape: #{@shape || "?"} data_type: #{@data_type})"
|
81
81
|
end
|
82
82
|
|
83
83
|
protected
|
File without changes
|
@@ -0,0 +1,463 @@
|
|
1
|
+
{
|
2
|
+
"cells": [
|
3
|
+
{
|
4
|
+
"cell_type": "markdown",
|
5
|
+
"metadata": {},
|
6
|
+
"source": [
|
7
|
+
"Notebook showing linear regression with tensor_stream"
|
8
|
+
]
|
9
|
+
},
|
10
|
+
{
|
11
|
+
"cell_type": "code",
|
12
|
+
"execution_count": 1,
|
13
|
+
"metadata": {},
|
14
|
+
"outputs": [
|
15
|
+
{
|
16
|
+
"data": {
|
17
|
+
"text/plain": [
|
18
|
+
"true"
|
19
|
+
]
|
20
|
+
},
|
21
|
+
"execution_count": 1,
|
22
|
+
"metadata": {},
|
23
|
+
"output_type": "execute_result"
|
24
|
+
}
|
25
|
+
],
|
26
|
+
"source": [
|
27
|
+
"require 'tensor_stream'"
|
28
|
+
]
|
29
|
+
},
|
30
|
+
{
|
31
|
+
"cell_type": "markdown",
|
32
|
+
"metadata": {},
|
33
|
+
"source": [
|
34
|
+
"Setup leraning parameters"
|
35
|
+
]
|
36
|
+
},
|
37
|
+
{
|
38
|
+
"cell_type": "code",
|
39
|
+
"execution_count": 2,
|
40
|
+
"metadata": {},
|
41
|
+
"outputs": [
|
42
|
+
{
|
43
|
+
"data": {
|
44
|
+
"text/plain": [
|
45
|
+
"50"
|
46
|
+
]
|
47
|
+
},
|
48
|
+
"execution_count": 2,
|
49
|
+
"metadata": {},
|
50
|
+
"output_type": "execute_result"
|
51
|
+
}
|
52
|
+
],
|
53
|
+
"source": [
|
54
|
+
"ts = TensorStream # assign module to var so we don't have to type so much\n",
|
55
|
+
"\n",
|
56
|
+
"learning_rate = 0.01\n",
|
57
|
+
"momentum = 0.5\n",
|
58
|
+
"training_epochs = 1000\n",
|
59
|
+
"display_step = 50"
|
60
|
+
]
|
61
|
+
},
|
62
|
+
{
|
63
|
+
"cell_type": "markdown",
|
64
|
+
"metadata": {},
|
65
|
+
"source": [
|
66
|
+
"Prepare training and test data"
|
67
|
+
]
|
68
|
+
},
|
69
|
+
{
|
70
|
+
"cell_type": "code",
|
71
|
+
"execution_count": 3,
|
72
|
+
"metadata": {},
|
73
|
+
"outputs": [
|
74
|
+
{
|
75
|
+
"data": {
|
76
|
+
"text/plain": [
|
77
|
+
"17"
|
78
|
+
]
|
79
|
+
},
|
80
|
+
"execution_count": 3,
|
81
|
+
"metadata": {},
|
82
|
+
"output_type": "execute_result"
|
83
|
+
}
|
84
|
+
],
|
85
|
+
"source": [
|
86
|
+
"# Setup training data\n",
|
87
|
+
"\n",
|
88
|
+
"train_X = [3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,\n",
|
89
|
+
"7.042,10.791,5.313,7.997,5.654,9.27,3.1]\n",
|
90
|
+
"train_Y = [1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,\n",
|
91
|
+
"2.827,3.465,1.65,2.904,2.42,2.94,1.3]\n",
|
92
|
+
"\n",
|
93
|
+
"# Store number of samples\n",
|
94
|
+
"n_samples = train_X.size"
|
95
|
+
]
|
96
|
+
},
|
97
|
+
{
|
98
|
+
"cell_type": "markdown",
|
99
|
+
"metadata": {},
|
100
|
+
"source": [
|
101
|
+
"Build the odel with two scalar variables w and b"
|
102
|
+
]
|
103
|
+
},
|
104
|
+
{
|
105
|
+
"cell_type": "code",
|
106
|
+
"execution_count": 4,
|
107
|
+
"metadata": {},
|
108
|
+
"outputs": [
|
109
|
+
{
|
110
|
+
"data": {
|
111
|
+
"text/plain": [
|
112
|
+
"Op(div name: div shape: ? data_type: float32)"
|
113
|
+
]
|
114
|
+
},
|
115
|
+
"execution_count": 4,
|
116
|
+
"metadata": {},
|
117
|
+
"output_type": "execute_result"
|
118
|
+
}
|
119
|
+
],
|
120
|
+
"source": [
|
121
|
+
"X = Float.placeholder\n",
|
122
|
+
"Y = Float.placeholder\n",
|
123
|
+
"\n",
|
124
|
+
"# Set model weights\n",
|
125
|
+
"\n",
|
126
|
+
"W = rand.t.var name: \"weight\"\n",
|
127
|
+
"b = rand.t.var name: \"bias\"\n",
|
128
|
+
"\n",
|
129
|
+
"# Construct a linear model\n",
|
130
|
+
"pred = X * W + b\n",
|
131
|
+
"\n",
|
132
|
+
"# Mean squared error\n",
|
133
|
+
"cost = ((pred - Y) ** 2).reduce / ( 2 * n_samples)"
|
134
|
+
]
|
135
|
+
},
|
136
|
+
{
|
137
|
+
"cell_type": "markdown",
|
138
|
+
"metadata": {},
|
139
|
+
"source": [
|
140
|
+
"Set Optimizer as SGD"
|
141
|
+
]
|
142
|
+
},
|
143
|
+
{
|
144
|
+
"cell_type": "code",
|
145
|
+
"execution_count": 5,
|
146
|
+
"metadata": {},
|
147
|
+
"outputs": [
|
148
|
+
{
|
149
|
+
"data": {
|
150
|
+
"text/plain": [
|
151
|
+
"Op(flow_group name: GradientDescent/flow_group shape: TensorShape([Dimension(2)]) data_type: )"
|
152
|
+
]
|
153
|
+
},
|
154
|
+
"execution_count": 5,
|
155
|
+
"metadata": {},
|
156
|
+
"output_type": "execute_result"
|
157
|
+
}
|
158
|
+
],
|
159
|
+
"source": [
|
160
|
+
"optimizer = TensorStream::Train::GradientDescentOptimizer.new(learning_rate).minimize(cost)"
|
161
|
+
]
|
162
|
+
},
|
163
|
+
{
|
164
|
+
"cell_type": "code",
|
165
|
+
"execution_count": 9,
|
166
|
+
"metadata": {
|
167
|
+
"scrolled": true
|
168
|
+
},
|
169
|
+
"outputs": [
|
170
|
+
{
|
171
|
+
"data": {
|
172
|
+
"text/plain": [
|
173
|
+
"#<TensorStream::Session:0x000055d21c13dbe0 @thread_pool=#<Concurrent::ImmediateExecutor:0x000055d21c13db18 @stopped=#<Concurrent::Event:0x000055d21c13daa0 @__Lock__=#<Thread::Mutex:0x000055d21c13da50>, @__Condition__=#<Thread::ConditionVariable:0x000055d21c13da28>, @set=false, @iteration=0>>, @closed=false, @session_cache={}, @randomizer={}, @log_device_placement=false, @evaluator_options={:profile_enabled=>false}, @evaluator_classes=[TensorStream::Evaluator::RubyEvaluator], @evaluators={}>"
|
174
|
+
]
|
175
|
+
},
|
176
|
+
"execution_count": 9,
|
177
|
+
"metadata": {},
|
178
|
+
"output_type": "execute_result"
|
179
|
+
}
|
180
|
+
],
|
181
|
+
"source": [
|
182
|
+
"sess = ts.session"
|
183
|
+
]
|
184
|
+
},
|
185
|
+
{
|
186
|
+
"cell_type": "markdown",
|
187
|
+
"metadata": {},
|
188
|
+
"source": [
|
189
|
+
"Defi"
|
190
|
+
]
|
191
|
+
},
|
192
|
+
{
|
193
|
+
"cell_type": "code",
|
194
|
+
"execution_count": 10,
|
195
|
+
"metadata": {},
|
196
|
+
"outputs": [
|
197
|
+
{
|
198
|
+
"data": {
|
199
|
+
"text/plain": [
|
200
|
+
"Op(flow_group name: /flow_group_1 shape: TensorShape([Dimension(2)]) data_type: )"
|
201
|
+
]
|
202
|
+
},
|
203
|
+
"execution_count": 10,
|
204
|
+
"metadata": {},
|
205
|
+
"output_type": "execute_result"
|
206
|
+
}
|
207
|
+
],
|
208
|
+
"source": [
|
209
|
+
"# Initialize the variables (i.e. assign their default value)\n",
|
210
|
+
"init = ts.global_variables_initializer()"
|
211
|
+
]
|
212
|
+
},
|
213
|
+
{
|
214
|
+
"cell_type": "code",
|
215
|
+
"execution_count": 12,
|
216
|
+
"metadata": {},
|
217
|
+
"outputs": [
|
218
|
+
{
|
219
|
+
"name": "stdout",
|
220
|
+
"output_type": "stream",
|
221
|
+
"text": [
|
222
|
+
"Epoch:\n",
|
223
|
+
"0050\n",
|
224
|
+
"cost=\n",
|
225
|
+
"0.08361721120294868\n",
|
226
|
+
"W=\n",
|
227
|
+
"0.29528699916245393\n",
|
228
|
+
"b=\n",
|
229
|
+
"0.4727514024197094\n",
|
230
|
+
"Epoch:\n",
|
231
|
+
"0100\n",
|
232
|
+
"cost=\n",
|
233
|
+
"0.08284862164544735\n",
|
234
|
+
"W=\n",
|
235
|
+
"0.29256938855488984\n",
|
236
|
+
"b=\n",
|
237
|
+
"0.49230169833258675\n",
|
238
|
+
"Epoch:\n",
|
239
|
+
"0150\n",
|
240
|
+
"cost=\n",
|
241
|
+
"0.08216896362778606\n",
|
242
|
+
"W=\n",
|
243
|
+
"0.29001340685813926\n",
|
244
|
+
"b=\n",
|
245
|
+
"0.5106892475755477\n",
|
246
|
+
"Epoch:\n",
|
247
|
+
"0200\n",
|
248
|
+
"cost=\n",
|
249
|
+
"0.08156795987766097\n",
|
250
|
+
"W=\n",
|
251
|
+
"0.28760944123449095\n",
|
252
|
+
"b=\n",
|
253
|
+
"0.5279832040774313\n",
|
254
|
+
"Epoch:\n",
|
255
|
+
"0250\n",
|
256
|
+
"cost=\n",
|
257
|
+
"0.08103652004507544\n",
|
258
|
+
"W=\n",
|
259
|
+
"0.28534845058380326\n",
|
260
|
+
"b=\n",
|
261
|
+
"0.5442486088650512\n",
|
262
|
+
"Epoch:\n",
|
263
|
+
"0300\n",
|
264
|
+
"cost=\n",
|
265
|
+
"0.08056660367120239\n",
|
266
|
+
"W=\n",
|
267
|
+
"0.28322193152204567\n",
|
268
|
+
"b=\n",
|
269
|
+
"0.5595466346740284\n",
|
270
|
+
"Epoch:\n",
|
271
|
+
"0350\n",
|
272
|
+
"cost=\n",
|
273
|
+
"0.08015109897338235\n",
|
274
|
+
"W=\n",
|
275
|
+
"0.28122188640083184\n",
|
276
|
+
"b=\n",
|
277
|
+
"0.5739348160139869\n",
|
278
|
+
"Epoch:\n",
|
279
|
+
"0400\n",
|
280
|
+
"cost=\n",
|
281
|
+
"0.07978371562246367\n",
|
282
|
+
"W=\n",
|
283
|
+
"0.2793407932289771\n",
|
284
|
+
"b=\n",
|
285
|
+
"0.5874672655507152\n",
|
286
|
+
"Epoch:\n",
|
287
|
+
"0450\n",
|
288
|
+
"cost=\n",
|
289
|
+
"0.0794588898974416\n",
|
290
|
+
"W=\n",
|
291
|
+
"0.277571577382961\n",
|
292
|
+
"b=\n",
|
293
|
+
"0.6001948776190661\n",
|
294
|
+
"Epoch:\n",
|
295
|
+
"0500\n",
|
296
|
+
"cost=\n",
|
297
|
+
"0.07917170078875599\n",
|
298
|
+
"W=\n",
|
299
|
+
"0.2759075849999002\n",
|
300
|
+
"b=\n",
|
301
|
+
"0.6121655196320117\n",
|
302
|
+
"Epoch:\n",
|
303
|
+
"0550\n",
|
304
|
+
"cost=\n",
|
305
|
+
"0.0789177957865016\n",
|
306
|
+
"W=\n",
|
307
|
+
"0.274342557952963\n",
|
308
|
+
"b=\n",
|
309
|
+
"0.6234242121057137\n",
|
310
|
+
"Epoch:\n",
|
311
|
+
"0600\n",
|
312
|
+
"cost=\n",
|
313
|
+
"0.07869332523566322\n",
|
314
|
+
"W=\n",
|
315
|
+
"0.27287061031511406\n",
|
316
|
+
"b=\n",
|
317
|
+
"0.634013297977656\n",
|
318
|
+
"Epoch:\n",
|
319
|
+
"0650\n",
|
320
|
+
"cost=\n",
|
321
|
+
"0.0784948842695121\n",
|
322
|
+
"W=\n",
|
323
|
+
"0.27148620622266817\n",
|
324
|
+
"b=\n",
|
325
|
+
"0.6439726018546622\n",
|
326
|
+
"Epoch:\n",
|
327
|
+
"0700\n",
|
328
|
+
"cost=\n",
|
329
|
+
"0.07831946144643755\n",
|
330
|
+
"W=\n",
|
331
|
+
"0.27018413905539995\n",
|
332
|
+
"b=\n",
|
333
|
+
"0.6533395797896867\n",
|
334
|
+
"Epoch:\n",
|
335
|
+
"0750\n",
|
336
|
+
"cost=\n",
|
337
|
+
"0.07816439331644526\n",
|
338
|
+
"W=\n",
|
339
|
+
"0.26895951185491046\n",
|
340
|
+
"b=\n",
|
341
|
+
"0.6621494601506834\n",
|
342
|
+
"Epoch:\n",
|
343
|
+
"0800\n",
|
344
|
+
"cost=\n",
|
345
|
+
"0.07802732423286504\n",
|
346
|
+
"W=\n",
|
347
|
+
"0.2678077189076042\n",
|
348
|
+
"b=\n",
|
349
|
+
"0.6704353761113543\n",
|
350
|
+
"Epoch:\n",
|
351
|
+
"0850\n",
|
352
|
+
"cost=\n",
|
353
|
+
"0.07790617080380904\n",
|
354
|
+
"W=\n",
|
355
|
+
"0.2667244284230108\n",
|
356
|
+
"b=\n",
|
357
|
+
"0.6782284902620607\n",
|
358
|
+
"Epoch:\n",
|
359
|
+
"0900\n",
|
360
|
+
"cost=\n",
|
361
|
+
"0.07779909044780774\n",
|
362
|
+
"W=\n",
|
363
|
+
"0.26570556624230784\n",
|
364
|
+
"b=\n",
|
365
|
+
"0.6855581118095395\n",
|
366
|
+
"Epoch:\n",
|
367
|
+
"0950\n",
|
368
|
+
"cost=\n",
|
369
|
+
"0.07770445357986415\n",
|
370
|
+
"W=\n",
|
371
|
+
"0.2647473005157735\n",
|
372
|
+
"b=\n",
|
373
|
+
"0.6924518068062181\n",
|
374
|
+
"Epoch:\n",
|
375
|
+
"1000\n",
|
376
|
+
"cost=\n",
|
377
|
+
"0.07762081900885288\n",
|
378
|
+
"W=\n",
|
379
|
+
"0.2638460272915415\n",
|
380
|
+
"b=\n",
|
381
|
+
"0.6989355018236741\n",
|
382
|
+
"Optimization Finished!\n",
|
383
|
+
"Training cost=\n",
|
384
|
+
"0.07761924952715166\n",
|
385
|
+
"W=\n",
|
386
|
+
"0.263828559523301\n",
|
387
|
+
"b=\n",
|
388
|
+
"0.6990611636971903\n",
|
389
|
+
"\\n\n",
|
390
|
+
"time elapsed \n",
|
391
|
+
"30\n"
|
392
|
+
]
|
393
|
+
}
|
394
|
+
],
|
395
|
+
"source": [
|
396
|
+
"# Train\n",
|
397
|
+
"start_time = Time.now\n",
|
398
|
+
"\n",
|
399
|
+
"sess.run(init)\n",
|
400
|
+
"(0..training_epochs).each do |epoch|\n",
|
401
|
+
" train_X.zip(train_Y).each do |x,y|\n",
|
402
|
+
" sess.run(optimizer, feed_dict: {X => x, Y => y})\n",
|
403
|
+
" end\n",
|
404
|
+
"\n",
|
405
|
+
" if (epoch+1) % display_step == 0\n",
|
406
|
+
" # Save the variables to disk.\n",
|
407
|
+
" c = sess.run(cost, feed_dict: {X => train_X, Y => train_Y})\n",
|
408
|
+
" puts(\"Epoch:\", '%04d' % (epoch+1), \"cost=\", c, \\\n",
|
409
|
+
" \"W=\", sess.run(W), \"b=\", sess.run(b))\n",
|
410
|
+
" end\n",
|
411
|
+
"end\n",
|
412
|
+
"\n",
|
413
|
+
"puts(\"Optimization Finished!\")\n",
|
414
|
+
"training_cost = sess.run(cost, feed_dict: { X => train_X, Y => train_Y})\n",
|
415
|
+
"puts(\"Training cost=\", training_cost, \"W=\", sess.run(W), \"b=\", sess.run(b), '\\n')\n",
|
416
|
+
"puts(\"time elapsed \", Time.now.to_i - start_time.to_i)"
|
417
|
+
]
|
418
|
+
},
|
419
|
+
{
|
420
|
+
"cell_type": "code",
|
421
|
+
"execution_count": 13,
|
422
|
+
"metadata": {},
|
423
|
+
"outputs": [
|
424
|
+
{
|
425
|
+
"data": {
|
426
|
+
"text/plain": [
|
427
|
+
"3.3373467589302"
|
428
|
+
]
|
429
|
+
},
|
430
|
+
"execution_count": 13,
|
431
|
+
"metadata": {},
|
432
|
+
"output_type": "execute_result"
|
433
|
+
}
|
434
|
+
],
|
435
|
+
"source": [
|
436
|
+
"# Predict\n",
|
437
|
+
"y_predict = sess.run(pred, feed_dict: {X => 10.0 })"
|
438
|
+
]
|
439
|
+
},
|
440
|
+
{
|
441
|
+
"cell_type": "code",
|
442
|
+
"execution_count": null,
|
443
|
+
"metadata": {},
|
444
|
+
"outputs": [],
|
445
|
+
"source": []
|
446
|
+
}
|
447
|
+
],
|
448
|
+
"metadata": {
|
449
|
+
"kernelspec": {
|
450
|
+
"display_name": "Ruby 2.4.0",
|
451
|
+
"language": "ruby",
|
452
|
+
"name": "ruby"
|
453
|
+
},
|
454
|
+
"language_info": {
|
455
|
+
"file_extension": ".rb",
|
456
|
+
"mimetype": "application/x-ruby",
|
457
|
+
"name": "ruby",
|
458
|
+
"version": "2.4.0"
|
459
|
+
}
|
460
|
+
},
|
461
|
+
"nbformat": 4,
|
462
|
+
"nbformat_minor": 2
|
463
|
+
}
|