tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,156 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_SessionOptions handle created with
    # TF_NewSessionOptions and released by a finalizer on garbage collection.
    class SessionOptions
      # NOTE(review): not read anywhere in this file; presumably meant to be
      # forwarded to the native options (e.g. via TF_SetTarget) — confirm.
      attr_accessor :target

      # Builds the finalizer proc. It is created in a class method so that it
      # captures only the raw pointer, not the wrapper object itself.
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteSessionOptions(pointer)
        end
      end

      def initialize
        @pointer = FFI.TF_NewSessionOptions
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native handle; lets this object be passed where FFI expects a pointer.
      def to_ptr
        @pointer
      end
    end

    # Executes operations of a graph through the TensorFlow C API
    # (TF_NewSession / TF_SessionRun / TF_CloseSession / TF_DeleteSession).
    class Session
      attr_accessor :graph, :options

      # Opens a session on +graph+ with default options, yields it, and closes
      # it afterwards, including when the block raises or exits early.
      #
      # @param graph the graph whose operations will be run
      # @yieldparam session [Session]
      # @return the block's return value
      def self.run(graph)
        session = new(graph, SessionOptions.new)
        begin
          yield session
        ensure
          session.close
        end
      end

      # Builds the finalizer proc that releases the native session. It captures
      # only the raw pointer (not the Session), so the object can be collected.
      # TF_DeleteSession takes a TF_Status* as its second argument.
      def self.finalize(pointer)
        proc do
          Status.check do |status|
            FFI.TF_DeleteSession(pointer, status)
          end
        end
      end

      # @param graph the graph to run
      # @param options [SessionOptions] native session options
      # @raise whatever Status.check raises if TF_NewSession reports an error
      def initialize(graph, options)
        @graph = graph
        # Remember the options so the +options+ reader returns them.
        @options = options
        Status.check do |status|
          @pointer = FFI.TF_NewSession(graph, options, status)
        end
        # Registered only after a successful TF_NewSession so that a failed
        # construction never schedules a delete of an invalid handle.
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native handle; lets this object be passed where FFI expects a pointer.
      def to_ptr
        @pointer
      end

      # Runs the requested operations, feeding +feed_dict+.
      #
      # @param operations [Operation, Variable, OperationOutput, Array] items to
      #   evaluate; nested arrays are flattened and nils dropped
      # @param feed_dict [Hash] maps fed operations to Tensors, or to raw values
      #   that are converted using the key's single output dtype
      # @return the single value when exactly one item producing exactly one
      #   value was requested; otherwise an Array in which each item contributes
      #   all of its output values (concatenated), or nil if it has no outputs
      # @raise [Error::UnimplementedError] for unsupported item types
      def run(operations, feed_dict = {})
        operations = Array(operations).flatten.compact

        # Inputs: every output of every fed operation, plus the fed tensors.
        key_outputs = feed_dict.keys.map(&:outputs).flatten
        keys_ptr = FFI::Output.array_to_ptr(key_outputs.map(&:output))

        values = values_to_tensors(feed_dict)
        values_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
        values_ptr.write_array_of_pointer(values)

        # For each requested item: the outputs to fetch and the op to run.
        plan = operations.map { |operation| fetches_and_target(operation) }
        fetches = plan.map(&:first)
        targets = plan.map(&:last)

        outputs = fetches.flatten
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
        results_ptr = ::FFI::MemoryPointer.new(:pointer, outputs.length)

        targets_ptr = ::FFI::MemoryPointer.new(:pointer, targets.length)
        targets_ptr.write_array_of_pointer(targets)

        run_options = nil
        metadata = nil

        Status.check do |status|
          FFI.TF_SessionRun(self, run_options,
                            # Inputs
                            keys_ptr, values_ptr, feed_dict.keys.length,
                            # Outputs
                            outputs_ptr, results_ptr, outputs.length,
                            # Targets
                            targets_ptr, targets.length,
                            metadata,
                            status)
        end

        results = results_ptr.read_array_of_pointer(outputs.length).map do |pointer|
          Tensor.from_pointer(pointer).value
        end

        # Regroup the flat value list per requested item: items without outputs
        # contribute nil, the others contribute all of their values.
        start = 0
        grouped = fetches.each_with_object([]) do |fetched, array|
          if fetched.empty?
            array << nil
          else
            array.concat(results[start, fetched.length])
            start += fetched.length
          end
        end

        operations.length == 1 && results.length == 1 ? grouped.first : grouped
      end

      # Closes the native session.
      def close
        Status.check do |status|
          FFI.TF_CloseSession(self, status)
        end
      end

      # Converts feed_dict values to Tensors. Tensor values pass through; any
      # other value is wrapped in a Tensor using the dtype of the key's output.
      #
      # @raise [Error::UnknownError] when a non-Tensor value is fed to a key
      #   that does not have exactly one output (its dtype is ambiguous)
      def values_to_tensors(values)
        values.map do |key, value|
          case value
          when Tensor
            value
          else
            # The value dtype needs to match the key dtype
            raise(Error::UnknownError, "Cannot determine dtype: #{key}") if key.num_outputs != 1
            dtype = key.output_types.first
            Tensor.new(value, dtype: dtype)
          end
        end
      end

      private

      # Returns [outputs to fetch, operation to run as a target] for one item.
      def fetches_and_target(operation)
        case operation
        when Operation, Variable
          [operation.outputs, operation]
        when OperationOutput
          [[operation], operation.operation]
        else
          raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
        end
      end
    end
  end
end
|
@@ -0,0 +1,32 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Datasets
      module BostonHousing
        # Loads the Boston Housing dataset, shuffles the samples with a seeded
        # RNG and splits them into a training part and a test part.
        #
        # @param path [String] local cache file name passed to Utils.get_file
        # @param test_split [Float] fraction of samples reserved for the test set
        # @param seed [Integer] seed for the shuffle, so splits are reproducible
        # @return [Array] [[x_train, y_train], [x_test, y_test]]
        def self.load_data(path: "boston_housing.npz", test_split: 0.2, seed: 113)
          file = Utils.get_file(
            path,
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/boston_housing.npz",
            file_hash: "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
          )

          data = Npy.load_npz(file)

          x = data["x"]
          y = data["y"]

          len = x.shape[0]
          indices = (0...len).to_a.shuffle(random: ::Random.new(seed))
          x = x[indices, true]
          y = y[indices]

          # Integer split point computed once (like Keras'
          # int(len(x) * (1 - test_split))) rather than slicing with a Float.
          split = (len * (1 - test_split)).to_i

          x_train = x[0...split, true]
          y_train = y[0...split]
          x_test = x[split..-1, true]
          y_test = y[split..-1]

          [[x_train, y_train], [x_test, y_test]]
        end
      end
    end
  end
end
|
@@ -0,0 +1,44 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Datasets
      module FashionMNIST
        # Magic numbers (big-endian uint32) that open IDX label/image files.
        LABELS_MAGIC = 2049
        IMAGES_MAGIC = 2051
        private_constant :LABELS_MAGIC, :IMAGES_MAGIC

        # Downloads (or reuses cached copies of) the Fashion-MNIST files and
        # decodes them.
        #
        # @return [Array] [[x_train, y_train], [x_test, y_test]]; x_* are
        #   Numo::UInt8 arrays shaped [count, rows, cols] taken from the file
        #   header, y_* are Numo::UInt8 label vectors
        # @raise [Error] if a downloaded file does not carry the IDX magic number
        def self.load_data
          base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets"
          files = [
            "train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
            "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"
          ]

          paths = files.map do |file|
            Utils.get_file(file, "#{base_url}/#{file}", cache_subdir: "datasets/fashion-mnist")
          end

          y_train = read_labels(paths[0])
          x_train = read_images(paths[1])
          y_test = read_labels(paths[2])
          x_test = read_images(paths[3])

          [[x_train, y_train], [x_test, y_test]]
        end

        # Decodes a gzipped IDX1 label file: 8-byte header (magic, count)
        # followed by one unsigned byte per label.
        def self.read_labels(path)
          Zlib::GzipReader.open(path) do |gz|
            magic, _count = gz.read(8).to_s.unpack("N2")
            check_magic(magic, LABELS_MAGIC, path)
            Numo::UInt8.from_string(gz.read)
          end
        end

        # Decodes a gzipped IDX3 image file: 16-byte header (magic, count,
        # rows, cols) followed by count * rows * cols unsigned bytes.
        def self.read_images(path)
          Zlib::GzipReader.open(path) do |gz|
            magic, count, rows, cols = gz.read(16).to_s.unpack("N4")
            check_magic(magic, IMAGES_MAGIC, path)
            Numo::UInt8.from_string(gz.read, [count, rows, cols])
          end
        end

        # Raises when a file does not start with the expected IDX magic number
        # (e.g. a truncated download or an HTML error page).
        def self.check_magic(actual, expected, path)
          return if actual == expected

          raise Error, "Unexpected magic number #{actual.inspect} in #{path} (expected #{expected})"
        end

        private_class_method :read_labels, :read_images, :check_magic
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Datasets
      module IMDB
        # load_data is not provided: npy cannot read the pickled data inside
        # imdb.npz and numo cannot hold arbitrary objects. For reference, the
        # archive lives at
        #   https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
        # (hash 69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f)
        # and holds the arrays "x_train", "y_train", "x_test" and "y_test".

        # Fetches (or reuses a cached copy of) the IMDB word index JSON file
        # and returns it parsed.
        #
        # @param path [String] local cache file name passed to Utils.get_file
        # @return [Hash] the parsed JSON object (presumably word => index)
        def self.get_word_index(path: "imdb_word_index.json")
          source = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json"
          local_copy = Utils.get_file(path, source, file_hash: "bfafd718b763782e994055a2d397834f")
          raw_json = File.read(local_copy)
          JSON.parse(raw_json)
        end
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Datasets
      module MNIST
        # Fetches (or reuses a cached copy of) mnist.npz and returns the four
        # arrays stored in it, grouped as train and test pairs.
        #
        # @param path [String] local cache file name passed to Utils.get_file
        # @return [Array] [[x_train, y_train], [x_test, y_test]]
        def self.load_data(path: "mnist.npz")
          origin = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
          checksum = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
          archive = Npy.load_npz(Utils.get_file(path, origin, file_hash: checksum))

          [%w[x_train y_train], %w[x_test y_test]].map do |pair|
            pair.map { |name| archive[name] }
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Datasets
      module Reuters
        # load_data is not provided: npy cannot read the pickled data inside
        # reuters.npz and numo cannot hold arbitrary objects. For reference,
        # the archive lives at
        #   https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz
        # (hash d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916)
        # and holds the arrays "x" and "y".

        # Fetches (or reuses a cached copy of) the Reuters word index JSON file
        # and returns it parsed.
        #
        # @param path [String] local cache file name passed to Utils.get_file
        # @return [Hash] the parsed JSON object (presumably word => index)
        def self.get_word_index(path: "reuters_word_index.json")
          source = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters_word_index.json"
          local_copy = Utils.get_file(path, source, file_hash: "4d44cc38712099c9e383dc6e5f11a921")
          raw_json = File.read(local_copy)
          JSON.parse(raw_json)
        end
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Layers
      # Fully connected layer: activation(matmul(inputs, kernel) + bias).
      class Dense
        # @param units [Integer] size of the last output dimension
        # @param activation [String, nil] "relu", "softmax", or nil for linear
        # @param use_bias [Boolean] whether a bias vector is created and added
        # @param kernel_initializer [String] initializer passed to Utils.add_weight
        # @param bias_initializer [String] initializer passed to Utils.add_weight
        # @param dtype [Symbol] dtype of the weights; inputs are cast to it
        def initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform",
                       bias_initializer: "zeros", dtype: :float)
          @units = units
          @activation = activation
          @use_bias = use_bias
          @kernel_initializer = kernel_initializer
          @bias_initializer = bias_initializer
          @dtype = dtype
          @built = false
        end

        # Creates the weights for inputs whose last dimension is
        # input_shape.last: a [last_dim, units] kernel and, if enabled, a
        # [units] bias.
        def build(input_shape)
          last_dim = input_shape.last
          @kernel = Utils.add_weight(name: "kernel", shape: [last_dim, @units], initializer: @kernel_initializer, dtype: @dtype)

          @bias =
            if @use_bias
              Utils.add_weight(name: "bias", shape: [@units], initializer: @bias_initializer, dtype: @dtype)
            end

          # All leading dimensions are kept; the last one becomes +units+.
          @output_shape = input_shape[0...-1] + [@units]

          @built = true
        end

        def output_shape
          @output_shape
        end

        # Number of trainable parameters: kernel entries plus, when a bias is
        # used, one bias entry per unit. Requires the layer to be built.
        def count_params
          kernel_params = @kernel.shape.inject(&:*)
          @use_bias ? kernel_params + @units : kernel_params
        end

        # Applies the layer, building it from inputs.shape on first use.
        #
        # @raise [Error] for inputs of rank > 2
        # @raise [RuntimeError] for an unknown activation name
        def call(inputs)
          build(inputs.shape) unless @built

          rank = inputs.shape.size

          if rank > 2
            raise Error, "Rank > 2 not supported yet"
          else
            inputs = Tensorflow.cast(inputs, @dtype)
            outputs = Tensorflow.matmul(inputs, @kernel)
          end

          outputs = NN.bias_add(outputs, @bias) if @use_bias

          case @activation
          when "relu"
            NN.relu(outputs)
          when "softmax"
            NN.softmax(outputs)
          when nil
            outputs
          else
            raise "Unknown activation: #{@activation}"
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Layers
      # Dropout layer. NOTE: dropout is not applied yet; +call+ returns its
      # input unchanged (via Tensorflow.identity) regardless of +rate+.
      class Dropout
        # Fraction of input units the layer is meant to drop.
        attr_reader :rate

        # @param rate [Numeric] fraction of units to drop (presumably in [0, 1))
        def initialize(rate)
          @rate = rate
        end

        # The output shape equals the input shape.
        def build(input_shape)
          @output_shape = input_shape
        end

        # TODO implement dropout; currently an identity op.
        def call(inputs)
          Tensorflow.identity(inputs)
        end

        def output_shape
          @output_shape
        end

        # Dropout has no trainable weights.
        def count_params
          0
        end
      end
    end
  end
end
|