tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,156 @@
1
+ module Tensorflow
2
+ module Graph
3
# Owns a native TF_SessionOptions handle. The handle is released by a
# finalizer once this Ruby object has been garbage collected.
class SessionOptions
  attr_accessor :target

  # Returns the finalizer proc for +pointer+. It is built in a class-level
  # method so the proc captures only the raw pointer and not the instance
  # (a finalizer that references its own object keeps it from being collected).
  def self.finalize(pointer)
    proc { FFI.TF_DeleteSessionOptions(pointer) }
  end

  def initialize
    handle = FFI.TF_NewSessionOptions
    @pointer = handle
    ObjectSpace.define_finalizer(self, self.class.finalize(handle))
  end

  # @return the native TF_SessionOptions pointer, for use as an FFI argument
  def to_ptr
    @pointer
  end
end
20
+
21
# A TensorFlow session bound to a single graph, wrapping a native TF_Session.
class Session
  attr_accessor :graph, :options

  # Opens a session on +graph+ with default SessionOptions, yields it and
  # closes it afterwards.
  #
  # The close happens in an +ensure+, so an exception raised inside the block
  # no longer leaves the native session open.
  #
  # @param graph the graph to run
  # @yieldparam session [Session]
  # @return the block's return value
  def self.run(graph)
    session = new(graph, SessionOptions.new)
    begin
      yield session
    ensure
      session.close
    end
  end

  # Builds a finalizer proc that deletes the native session. #initialize does
  # not register it.
  # NOTE(review): the TF C API declares TF_DeleteSession(TF_Session*, TF_Status*);
  # confirm the FFI binding's arity before registering this finalizer.
  def self.finalize(pointer)
    proc do
      FFI.TF_DeleteSession(pointer)
    end
  end

  # Creates the native session.
  #
  # @param graph the graph this session runs
  # @param options [SessionOptions] also kept in #options (previously dropped)
  # Errors from TF_NewSession are surfaced through Status.check.
  def initialize(graph, options)
    @graph = graph
    @options = options
    Status.check do |status|
      @pointer = FFI.TF_NewSession(graph, options, status)
    end
  end

  # @return the native TF_Session pointer, for use as an FFI argument
  def to_ptr
    @pointer
  end

  # Runs the graph. It fetches every output of each requested operation and
  # uses each operation (or the operation owning an OperationOutput) as a
  # target.
  #
  # @param operations an Operation, Variable or OperationOutput, or an Array of
  #   them (nested arrays are flattened and nils removed)
  # @param feed_dict [Hash] maps graph nodes to values fed for them. Non-Tensor
  #   values are converted by #values_to_tensors.
  # @return the fetched values, grouped per operation. An operation with several
  #   outputs contributes each of them; one with no outputs contributes nil. A
  #   single operation yielding a single result is returned unwrapped.
  # @raise [Error::UnimplementedError] for an unsupported fetch type
  # Errors reported by TF_SessionRun are surfaced through Status.check.
  def run(operations, feed_dict = {})
    operations = Array(operations).flatten.compact

    # Inputs: one entry per output of each fed key, plus the fed tensors.
    key_outputs = feed_dict.keys.map(&:outputs).flatten
    keys_ptr = FFI::Output.array_to_ptr(key_outputs.map(&:output))

    values = values_to_tensors(feed_dict)
    values_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
    values_ptr.write_array_of_pointer(values)

    # Fetches: every output of every requested operation.
    outputs = fetch_outputs(operations)
    outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
    results_ptr = ::FFI::MemoryPointer.new(:pointer, outputs.length)

    # Targets: the operations themselves (or the owner of each OperationOutput).
    targets = fetch_targets(operations)
    targets_ptr = ::FFI::MemoryPointer.new(:pointer, targets.length)
    targets_ptr.write_array_of_pointer(targets)

    run_options = nil
    metadata = nil

    Status.check do |status|
      FFI.TF_SessionRun(self, run_options,
                        # Inputs
                        keys_ptr, values_ptr, feed_dict.keys.length,
                        # Outputs
                        outputs_ptr, results_ptr, outputs.length,
                        # Targets
                        targets_ptr, operations.length,
                        metadata,
                        status)
    end

    results = results_ptr.read_array_of_pointer(outputs.length).map do |pointer|
      Tensor.from_pointer(pointer).value
    end

    grouped = group_results(operations, results)
    if operations.length == 1 && results.length == 1
      grouped.first
    else
      grouped
    end
  end

  # Closes the native session. Errors are surfaced through Status.check.
  def close
    Status.check do |status|
      FFI.TF_CloseSession(self, status)
    end
  end

  # Converts feed_dict values to Tensors. Tensors pass through unchanged. Any
  # other value is wrapped using the dtype of its key's single output.
  #
  # @raise [Error::UnknownError] when a non-Tensor value's key does not have
  #   exactly one output (its dtype cannot be determined)
  def values_to_tensors(values)
    values.map do |key, value|
      case value
      when Tensor
        value
      else
        # The value dtype needs to match the key dtype
        raise(Error::UnknownError, "Cannot determine dtype: #{key}") if key.num_outputs != 1
        dtype = key.output_types.first
        Tensor.new(value, dtype: dtype)
      end
    end
  end

  private

  # Flattened list of the outputs to fetch for +operations+.
  def fetch_outputs(operations)
    operations.map do |operation|
      case operation
      when Operation, Variable
        operation.outputs
      when OperationOutput
        operation
      else
        raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
      end
    end.flatten
  end

  # The operation to run as a target for each entry of +operations+.
  def fetch_targets(operations)
    operations.map do |operation|
      case operation
      when Operation, Variable
        operation
      when OperationOutput
        operation.operation
      else
        raise(Error::UnimplementedError, "Unsupported target: #{operation}")
      end
    end
  end

  # Number of fetched results that belong to +operation+.
  def output_count(operation)
    case operation
    when Operation, Variable
      operation.outputs.length
    when OperationOutput
      1
    else
      raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
    end
  end

  # Walks +results+ in fetch order and emits each operation's results. An
  # operation with no outputs contributes a nil placeholder.
  def group_results(operations, results)
    start = 0
    operations.each_with_object([]) do |operation, grouped|
      length = output_count(operation)
      if length.zero?
        grouped << nil
      else
        grouped.concat(results[start, length])
        start += length
      end
    end
  end
end
155
+ end
156
+ end
@@ -0,0 +1,32 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# Boston housing regression dataset (features "x", targets "y") as hosted for Keras.
module BostonHousing
  # Downloads (or reuses a cached copy of) boston_housing.npz. It shuffles the
  # rows deterministically with +seed+ and splits them into train/test sets.
  #
  # @param path [String] file name passed to Utils.get_file
  # @param test_split [Float] fraction of rows placed in the test set
  # @param seed [Integer] shuffle seed, so the split is reproducible
  # @return [Array] [[x_train, y_train], [x_test, y_test]]
  def self.load_data(path: "boston_housing.npz", test_split: 0.2, seed: 113)
    file = Utils.get_file(
      path,
      "https://storage.googleapis.com/tensorflow/tf-keras-datasets/boston_housing.npz",
      file_hash: "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
    )

    data = Npy.load_npz(file)

    x = data["x"]
    y = data["y"]

    len = x.shape[0]
    indices = (0...len).to_a.shuffle(random: ::Random.new(seed))
    x = x[indices, true]
    y = y[indices]

    # Compute the split point once, as an explicit (truncated) Integer row
    # index, like Keras' int(len * (1 - test_split)). This replaces a Float
    # expression repeated in every slice.
    split = (len * (1 - test_split)).to_i

    x_train = x[0...split, true]
    y_train = y[0...split]
    x_test = x[split..-1, true]
    y_test = y[split..-1]

    [[x_train, y_train], [x_test, y_test]]
  end
end
30
+ end
31
+ end
32
+ end
@@ -0,0 +1,11 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# Placeholder for the CIFAR-10 loader. No load_data is defined yet: npy
# cannot read the pickle-based archive, and Numo has no way to store
# arbitrary objects.
module CIFAR10
end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,11 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# Placeholder for the CIFAR-100 loader. No load_data is defined yet: npy
# cannot read the pickle-based archive, and Numo has no way to store
# arbitrary objects.
module CIFAR100
end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,44 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# Fashion-MNIST, read from the gzipped IDX files hosted for Keras.
module FashionMNIST
  # Downloads (or reuses cached copies of) the four Fashion-MNIST archives and
  # decodes them.
  #
  # @return [Array] [[x_train, y_train], [x_test, y_test]]. Labels are
  #   Numo::UInt8 vectors. Images are Numo::UInt8 arrays shaped
  #   [count, 28, 28], where count is the length of the matching label vector.
  def self.load_data
    base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets"
    files = [
      "train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
      "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"
    ]

    paths = files.map do |file|
      Utils.get_file(file, "#{base_url}/#{file}", cache_subdir: "datasets/fashion-mnist")
    end

    y_train = read_idx(paths[0], 8)
    x_train = read_idx(paths[1], 16, [y_train.shape[0], 28, 28])
    y_test = read_idx(paths[2], 8)
    x_test = read_idx(paths[3], 16, [y_test.shape[0], 28, 28])

    [[x_train, y_train], [x_test, y_test]]
  end

  # Opens the gzipped file at +path+ and skips +header_size+ bytes: 8 for the
  # label files and 16 for the image files, as used above. It decodes the
  # remaining bytes as Numo::UInt8, reshaped to +shape+ when one is given.
  def self.read_idx(path, header_size, shape = nil)
    Zlib::GzipReader.open(path) do |gz|
      gz.read(header_size) # move to offset
      payload = gz.read
      shape ? Numo::UInt8.from_string(payload, shape) : Numo::UInt8.from_string(payload)
    end
  end
  private_class_method :read_idx
end
42
+ end
43
+ end
44
+ end
@@ -0,0 +1,30 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# IMDB movie-review dataset helpers.
module IMDB
  # load_data is not implemented: npy can't read pickle and numo can't store
  # objects. The draft below is kept for reference.
  # def self.load_data(path: "imdb.npz", seed: 113)
  #   data = Utils.load_dataset(
  #     path,
  #     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz",
  #     "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
  #   )

  #   x_train = data["x_train"]
  #   labels_train = data["y_train"]
  #   x_test = data["x_test"]
  #   labels_test = data["y_test"]
  # end

  # Downloads (or reuses a cached copy of) the IMDB word index and parses it.
  #
  # The file is read explicitly as UTF-8. With the default external encoding
  # on a non-UTF-8 locale, JSON.parse would transcode the bytes and mangle
  # non-ASCII words.
  #
  # @param path [String] cache file name passed to Utils.get_file
  # @return [Hash] the parsed JSON object (presumably word => index; confirm)
  def self.get_word_index(path: "imdb_word_index.json")
    file = Utils.get_file(
      path,
      "https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json",
      file_hash: "bfafd718b763782e994055a2d397834f"
    )
    JSON.parse(File.read(file, encoding: "UTF-8"))
  end
end
28
+ end
29
+ end
30
+ end
@@ -0,0 +1,18 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# MNIST handwritten digits, loaded from the mnist.npz archive hosted for Keras.
module MNIST
  # Fetches (or reuses a cached copy of) mnist.npz, verified against its
  # SHA-256 digest, and returns its four arrays.
  #
  # @param path [String] cache file name passed to Utils.get_file
  # @return [Array] [[x_train, y_train], [x_test, y_test]]
  def self.load_data(path: "mnist.npz")
    url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
    digest = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
    archive = Npy.load_npz(Utils.get_file(path, url, file_hash: digest))

    train = %w[x_train y_train].map { |key| archive[key] }
    test = %w[x_test y_test].map { |key| archive[key] }
    [train, test]
  end
end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,28 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Datasets
4
# Reuters newswire dataset helpers.
module Reuters
  # load_data is not implemented: npy can't read pickle and numo can't store
  # objects. The draft below is kept for reference.
  # def self.load_data(path: "reuters.npz", test_split: 0.2, seed: 113)
  #   data = Utils.load_dataset(
  #     path,
  #     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz",
  #     "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
  #   )

  #   xs = data["x"]
  #   labels = data["y"]
  # end

  # Downloads (or reuses a cached copy of) the Reuters word index and parses it.
  #
  # The file is read explicitly as UTF-8. With the default external encoding
  # on a non-UTF-8 locale, JSON.parse would transcode the bytes and mangle
  # non-ASCII words.
  #
  # @param path [String] cache file name passed to Utils.get_file
  # @return [Hash] the parsed JSON object (presumably word => index; confirm)
  def self.get_word_index(path: "reuters_word_index.json")
    file = Utils.get_file(
      path,
      "https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters_word_index.json",
      file_hash: "4d44cc38712099c9e383dc6e5f11a921"
    )
    JSON.parse(File.read(file, encoding: "UTF-8"))
  end
end
26
+ end
27
+ end
28
+ end
@@ -0,0 +1,14 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Layers
4
# Base class for convolution layers. It only records the layer configuration;
# no build/call is defined here.
class Conv
  attr_reader :rank, :filters, :kernel_size, :activation

  # @param rank [Integer] number of spatial dimensions (Conv2D passes 2)
  # @param filters [Integer] presumably the number of output filters, as in Keras
  # @param kernel_size size of the convolution window
  # @param activation activation name, or nil
  def initialize(rank, filters, kernel_size, activation: nil)
    @rank = rank
    # Previously stored as @filter, which did not match the parameter name.
    @filters = filters
    @kernel_size = kernel_size
    @activation = activation
  end
end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,11 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Layers
4
# Two-dimensional convolution layer: a Conv whose rank is fixed at 2.
class Conv2D < Conv
  def initialize(filters, kernel_size, activation: nil)
    spatial_rank = 2
    super(spatial_rank, filters, kernel_size, activation: activation)
  end
end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,68 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Layers
4
# Densely-connected layer: activation(cast(inputs) x kernel + bias).
class Dense
  # @param units [Integer] size of the last output dimension
  # @param activation ["relu", "softmax", nil] applied after the affine map
  # @param use_bias [Boolean] whether a bias vector is created and added
  # @param kernel_initializer initializer name passed to Utils.add_weight
  # @param bias_initializer initializer name passed to Utils.add_weight
  # @param dtype dtype of the weights; inputs are cast to it in #call
  def initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform", bias_initializer: "zeros", dtype: :float)
    @units = units
    @activation = activation
    @use_bias = use_bias
    @kernel_initializer = kernel_initializer
    @bias_initializer = bias_initializer
    @dtype = dtype
    @built = false
  end

  # Creates a [input_shape.last, units] kernel and, when use_bias, a [units]
  # bias. It also records the output shape.
  def build(input_shape)
    last_dim = input_shape.last
    @kernel = Utils.add_weight(name: "kernel", shape: [last_dim, @units], initializer: @kernel_initializer, dtype: @dtype)

    if @use_bias
      @bias = Utils.add_weight(name: "bias", shape: [@units], initializer: @bias_initializer, dtype: @dtype)
    else
      @bias = nil
    end

    # The kernel maps the last axis to +units+; leading axes (e.g. the batch
    # axis) pass through unchanged. Previously this was [last_dim, units].
    @output_shape = input_shape.to_a[0...-1] + [@units]

    @built = true
  end

  def output_shape
    @output_shape
  end

  # Number of weight scalars: the kernel entries plus, only when use_bias, one
  # bias per unit. Requires #build to have run.
  def count_params
    kernel_params = @kernel.shape.inject(&:*)
    @use_bias ? kernel_params + @units : kernel_params
  end

  # Applies the layer, building it from inputs.shape on first use.
  #
  # @raise [Error::UnimplementedError] for inputs of rank > 2
  # @raise [RuntimeError] for an unknown activation name
  def call(inputs)
    build(inputs.shape) unless @built

    rank = inputs.shape.size
    raise Error::UnimplementedError, "Rank > 2 not supported yet" if rank > 2

    inputs = Tensorflow.cast(inputs, @dtype)
    outputs = Tensorflow.matmul(inputs, @kernel)
    outputs = NN.bias_add(outputs, @bias) if @use_bias

    case @activation
    when "relu"
      NN.relu(outputs)
    when "softmax"
      NN.softmax(outputs)
    when nil
      outputs
    else
      raise "Unknown activation: #{@activation}"
    end
  end
end
66
+ end
67
+ end
68
+ end
@@ -0,0 +1,27 @@
1
+ module Tensorflow
2
+ module Keras
3
+ module Layers
4
# Dropout layer. For now #call is a pass-through (Tensorflow.identity), and
# the rate is recorded but not yet applied.
class Dropout
  attr_reader :rate

  # @param rate fraction of inputs to drop; stored (it was previously
  #   discarded) but not yet used by #call
  def initialize(rate)
    @rate = rate
  end

  # Dropout does not change the shape of its input.
  def build(input_shape)
    @output_shape = input_shape
  end

  # TODO implement: currently the identity, which matches Keras' inference-time
  # behavior but not its training-time behavior.
  def call(inputs)
    Tensorflow.identity(inputs)
  end

  def output_shape
    @output_shape
  end

  # Dropout creates no weights.
  def count_params
    0
  end
end
25
+ end
26
+ end
27
+ end