tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,156 @@
1
module Tensorflow
  module Graph
    # Owns a native TF_SessionOptions handle that is freed by a GC finalizer.
    class SessionOptions
      attr_accessor :target

      # Builds the finalizer proc. It is a class method so the proc closes over
      # only the raw pointer, not the SessionOptions instance itself.
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteSessionOptions(pointer)
        end
      end

      def initialize
        @pointer = FFI.TF_NewSessionOptions
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native handle handed to FFI calls.
      def to_ptr
        @pointer
      end
    end

    # Wraps a native TF_Session bound to a graph and runs fetches/targets on it.
    class Session
      attr_accessor :graph, :options

      # Opens a session on +graph+ with default SessionOptions, yields it and
      # returns the block's value. The session is closed even if the block
      # raises, so the native session is not left open on error.
      def self.run(graph)
        session = new(graph, SessionOptions.new)
        begin
          yield session
        ensure
          session.close
        end
      end

      # Builds a finalizer proc that deletes the native session.
      # NOTE(review): this is not registered anywhere in this file, and the TF C
      # API declares TF_DeleteSession(session, status) — confirm the FFI
      # binding's arity before wiring it up.
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteSession(pointer)
        end
      end

      # @param graph [Graph] graph the session executes
      # @param options [SessionOptions] options passed to TF_NewSession
      # Raises (via Status.check) when TF_NewSession reports a failure status.
      def initialize(graph, options)
        @graph = graph
        @options = options
        Status.check do |status|
          @pointer = FFI.TF_NewSession(graph, options, status)
        end
      end

      def to_ptr
        @pointer
      end

      # Runs the graph, feeding +feed_dict+ and fetching +operations+.
      #
      # @param operations [Operation, Variable, OperationOutput, Array] items to
      #   fetch; nested arrays are flattened and nils dropped. The operation
      #   behind every item is also passed as a run target.
      # @param feed_dict [Hash] feed key => value; non-Tensor values are wrapped
      #   in a Tensor using the key's single output dtype.
      # @return the lone value when one item is requested and exactly one value
      #   is fetched; otherwise an Array of every item's values concatenated in
      #   order, with nil for an item that has no outputs.
      # @raise [Error::UnimplementedError] for an unsupported item type
      def run(operations, feed_dict = {})
        operations = Array(operations).flatten.compact

        # Inputs: every output of every feed key, plus the tensors fed to them.
        # NOTE(review): the input count passed below is feed_dict.keys.length,
        # which only matches keys_ptr when each key has exactly one output.
        key_outputs = feed_dict.keys.map(&:outputs).flatten
        keys_ptr = FFI::Output.array_to_ptr(key_outputs.map(&:output))

        values = values_to_tensors(feed_dict)
        values_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
        values_ptr.write_array_of_pointer(values)

        # Fetches: kept grouped per requested item so the flat results can be
        # split back per item afterwards.
        output_groups = operations.map { |operation| outputs_for(operation) }
        outputs = output_groups.flatten
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
        results_ptr = ::FFI::MemoryPointer.new(:pointer, outputs.length)

        # Targets: the operation behind every requested item.
        targets = operations.map { |operation| target_for(operation) }
        targets_ptr = ::FFI::MemoryPointer.new(:pointer, targets.length)
        targets_ptr.write_array_of_pointer(targets)

        run_options = nil
        metadata = nil

        Status.check do |status|
          FFI.TF_SessionRun(self, run_options,
                            # Inputs
                            keys_ptr, values_ptr, feed_dict.keys.length,
                            # Outputs
                            outputs_ptr, results_ptr, outputs.length,
                            # Targets
                            targets_ptr, targets.length,
                            metadata,
                            status)
        end

        fetched = results_ptr.read_array_of_pointer(outputs.length).map do |pointer|
          Tensor.from_pointer(pointer).value
        end

        result = regroup(fetched, output_groups)
        if operations.length == 1 && fetched.length == 1
          result.first
        else
          result
        end
      end

      # Closes the native session; raises (via Status.check) on failure.
      def close
        Status.check do |status|
          FFI.TF_CloseSession(self, status)
        end
      end

      # Converts feed values to Tensors. Tensors pass through unchanged; other
      # values take the dtype of the key's only output.
      #
      # @raise [Error::UnknownError] when a non-Tensor value's key does not have
      #   exactly one output, so no dtype can be chosen
      def values_to_tensors(values)
        values.map do |key, value|
          case value
          when Tensor
            value
          else
            # The value dtype needs to match the key dtype
            raise(Error::UnknownError, "Cannot determine dtype: #{key}") if key.num_outputs != 1
            dtype = key.output_types.first
            Tensor.new(value, dtype: dtype)
          end
        end
      end

      private

      # Outputs fetched for one requested item, always as an Array.
      def outputs_for(operation)
        case operation
        when Operation, Variable
          operation.outputs
        when OperationOutput
          [operation]
        else
          raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
        end
      end

      # Operation run as a target for one requested item.
      def target_for(operation)
        case operation
        when Operation, Variable
          operation
        when OperationOutput
          operation.operation
        else
          raise(Error::UnimplementedError, "Unsupported target: #{operation}")
        end
      end

      # Splits the flat fetched values back into consecutive per-item slices,
      # appending nil for an item that has no outputs.
      def regroup(fetched, output_groups)
        start = 0
        output_groups.each_with_object([]) do |group, result|
          if group.empty?
            result << nil
          else
            result.concat(fetched[start, group.length])
            start += group.length
          end
        end
      end
    end
  end
end
@@ -0,0 +1,32 @@
1
module Tensorflow
  module Keras
    module Datasets
      module BostonHousing
        # Fetches the Boston housing archive via Utils.get_file, shuffles it and
        # splits it into train and test sets.
        #
        # @param path [String] cache file name passed to Utils.get_file
        # @param test_split [Float] fraction of samples placed in the test set
        # @param seed [Integer] seed for the shuffle applied before splitting
        # @return [Array] [[x_train, y_train], [x_test, y_test]]
        def self.load_data(path: "boston_housing.npz", test_split: 0.2, seed: 113)
          file = Utils.get_file(
            path,
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/boston_housing.npz",
            file_hash: "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
          )

          data = Npy.load_npz(file)

          x = data["x"]
          y = data["y"]

          # Apply the same permutation to x and y so samples stay paired.
          # NOTE(review): Ruby's shuffle presumably does not reproduce the Python
          # Keras ordering for the same seed.
          len = x.shape[0]
          indices = (0...len).to_a.shuffle(random: ::Random.new(seed))
          x = x[indices, true]
          y = y[indices]

          # Integer boundary truncated like Keras' int(len(x) * (1 - test_split)),
          # rather than slicing with a Float range.
          split = (len * (1 - test_split)).to_i

          x_train = x[0...split, true]
          y_train = y[0...split]
          x_test = x[split..-1, true]
          y_test = y[split..-1]

          [[x_train, y_train], [x_test, y_test]]
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Tensorflow
  module Keras
    module Datasets
      # Placeholder for the CIFAR-10 loader.
      #
      # load_data is intentionally absent: the npy gem can't read pickle files
      # and Numo can't store objects, so the upstream archive can't be decoded.
      module CIFAR10
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Tensorflow
  module Keras
    module Datasets
      # Placeholder for the CIFAR-100 loader.
      #
      # load_data is intentionally absent: the npy gem can't read pickle files
      # and Numo can't store objects, so the upstream archive can't be decoded.
      module CIFAR100
      end
    end
  end
end
@@ -0,0 +1,44 @@
1
module Tensorflow
  module Keras
    module Datasets
      module FashionMNIST
        # Fetches the four gzipped Fashion-MNIST IDX files via Utils.get_file and
        # decodes them into Numo::UInt8 arrays.
        #
        # @return [Array] [[x_train, y_train], [x_test, y_test]]; label arrays
        #   are flat and image arrays are shaped [label_count, 28, 28]
        def self.load_data
          base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets"
          files = [
            "train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
            "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"
          ]

          paths = files.map do |file|
            Utils.get_file(file, "#{base_url}/#{file}", cache_subdir: "datasets/fashion-mnist")
          end

          # Label files carry an 8-byte header, image files a 16-byte header.
          y_train = read_idx(paths[0], 8)
          x_train = read_idx(paths[1], 16, [y_train.shape[0], 28, 28])
          y_test = read_idx(paths[2], 8)
          x_test = read_idx(paths[3], 16, [y_test.shape[0], 28, 28])

          [[x_train, y_train], [x_test, y_test]]
        end

        # Skips +header_size+ bytes of the gzipped file at +path+ and returns the
        # remaining bytes as a Numo::UInt8, reshaped to +shape+ when given.
        def self.read_idx(path, header_size, shape = nil)
          Zlib::GzipReader.open(path) do |gz|
            gz.read(header_size) # move to offset
            payload = gz.read
            shape ? Numo::UInt8.from_string(payload, shape) : Numo::UInt8.from_string(payload)
          end
        end
        private_class_method :read_idx
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Tensorflow
  module Keras
    module Datasets
      module IMDB
        # load_data is not available yet: the npy gem can't read pickle and Numo
        # can't store objects. The intended port is kept below for reference.
        # NOTE(review): Utils.load_dataset is not called by any live code here;
        # confirm it exists before enabling this.
        #
        # def self.load_data(path: "imdb.npz", seed: 113)
        #   data = Utils.load_dataset(
        #     path,
        #     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz",
        #     "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
        #   )
        #
        #   x_train = data["x_train"]
        #   labels_train = data["y_train"]
        #   x_test = data["x_test"]
        #   labels_test = data["y_test"]
        # end

        # Fetches the IMDB word-index file via Utils.get_file and parses it.
        #
        # @param path [String] cache file name passed to Utils.get_file
        # @return the parsed JSON document
        def self.get_word_index(path: "imdb_word_index.json")
          origin = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json"
          checksum = "bfafd718b763782e994055a2d397834f"
          cached_path = Utils.get_file(path, origin, file_hash: checksum)
          raw_json = File.read(cached_path)
          JSON.parse(raw_json)
        end
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Tensorflow
  module Keras
    module Datasets
      module MNIST
        # Fetches the MNIST .npz archive via Utils.get_file and unpacks it.
        #
        # @param path [String] cache file name passed to Utils.get_file
        # @return [Array] [[x_train, y_train], [x_test, y_test]] as read by Npy
        def self.load_data(path: "mnist.npz")
          origin = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
          checksum = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
          archive = Npy.load_npz(Utils.get_file(path, origin, file_hash: checksum))

          train = %w[x_train y_train].map { |key| archive[key] }
          test = %w[x_test y_test].map { |key| archive[key] }
          [train, test]
        end
      end
    end
  end
end
@@ -0,0 +1,28 @@
1
module Tensorflow
  module Keras
    module Datasets
      module Reuters
        # load_data is not available yet: the npy gem can't read pickle and Numo
        # can't store objects. The intended port is kept below for reference.
        # NOTE(review): Utils.load_dataset is not called by any live code here;
        # confirm it exists before enabling this.
        #
        # def self.load_data(path: "reuters.npz", test_split: 0.2, seed: 113)
        #   data = Utils.load_dataset(
        #     path,
        #     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz",
        #     "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
        #   )
        #
        #   xs = data["x"]
        #   labels = data["y"]
        # end

        # Fetches the Reuters word-index file via Utils.get_file and parses it.
        #
        # @param path [String] cache file name passed to Utils.get_file
        # @return the parsed JSON document
        def self.get_word_index(path: "reuters_word_index.json")
          origin = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters_word_index.json"
          checksum = "4d44cc38712099c9e383dc6e5f11a921"
          cached_path = Utils.get_file(path, origin, file_hash: checksum)
          raw_json = File.read(cached_path)
          JSON.parse(raw_json)
        end
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module Tensorflow
  module Keras
    module Layers
      # Base class for convolution layers. In this version it only records its
      # configuration; it defines no build or call.
      class Conv
        # @param rank [Integer] convolution rank (Conv2D passes 2)
        # @param filters [Integer] number of filters
        # @param kernel_size kernel size as given by the caller
        # @param activation [String, nil] activation name
        def initialize(rank, filters, kernel_size, activation: nil)
          @rank = rank
          # Stored under the parameter's name (previously @filter).
          @filters = filters
          @kernel_size = kernel_size
          @activation = activation
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Tensorflow
  module Keras
    module Layers
      # Two-dimensional convolution layer: a Conv whose rank is fixed at 2.
      class Conv2D < Conv
        # @param filters forwarded to Conv
        # @param kernel_size forwarded to Conv
        # @param activation forwarded to Conv
        def initialize(filters, kernel_size, activation: nil)
          spatial_rank = 2
          super(spatial_rank, filters, kernel_size, activation: activation)
        end
      end
    end
  end
end
@@ -0,0 +1,68 @@
1
module Tensorflow
  module Keras
    module Layers
      # Fully-connected layer computing activation(matmul(inputs, kernel) + bias).
      # Weights are created by #build, which #call runs on first use.
      class Dense
        # @param units [Integer] size of the layer's output dimension
        # @param activation [String, nil] "relu", "softmax", or nil (linear)
        # @param use_bias [Boolean] whether to create and add a bias vector
        # @param kernel_initializer [String] initializer handed to Utils.add_weight
        # @param bias_initializer [String] initializer handed to Utils.add_weight
        # @param dtype [Symbol] weight dtype; #call casts inputs to it
        def initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform", bias_initializer: "zeros", dtype: :float)
          @units = units
          @activation = activation
          @use_bias = use_bias
          @kernel_initializer = kernel_initializer
          @bias_initializer = bias_initializer
          @dtype = dtype
          @built = false
        end

        # Creates the [last_dim, units] kernel and, when use_bias is set, the
        # [units] bias.
        #
        # @param input_shape [Array] input shape; only its last entry is used
        def build(input_shape)
          last_dim = input_shape.last
          @kernel = Utils.add_weight(name: "kernel", shape: [last_dim, @units], initializer: @kernel_initializer, dtype: @dtype)

          @bias =
            if @use_bias
              Utils.add_weight(name: "bias", shape: [@units], initializer: @bias_initializer, dtype: @dtype)
            end

          # NOTE(review): Keras reports input_shape[0...-1] + [units] here; this
          # records the kernel shape instead. Readers of .last see units either
          # way — confirm how Sequential uses it before changing.
          @output_shape = [last_dim, @units]

          @built = true
        end

        def output_shape
          @output_shape
        end

        # Number of weights: every kernel entry, plus one bias per unit only
        # when a bias exists. Requires #build to have run.
        def count_params
          kernel_params = @kernel.shape.inject(&:*)
          @use_bias ? kernel_params + @units : kernel_params
        end

        # Applies the layer, building it from inputs.shape on first use.
        #
        # @raise [Error] for inputs of rank greater than 2
        # @raise [RuntimeError] for an unrecognized activation
        def call(inputs)
          build(inputs.shape) unless @built

          raise Error, "Rank > 2 not supported yet" if inputs.shape.size > 2

          outputs = Tensorflow.matmul(Tensorflow.cast(inputs, @dtype), @kernel)
          outputs = NN.bias_add(outputs, @bias) if @use_bias

          case @activation
          when "relu" then NN.relu(outputs)
          when "softmax" then NN.softmax(outputs)
          when nil then outputs
          else raise "Unknown activation: #{@activation}"
          end
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module Tensorflow
  module Keras
    module Layers
      # Dropout layer. Dropout itself is not implemented yet: #call returns
      # Tensorflow.identity(inputs).
      class Dropout
        # Rate given to the constructor. It is kept for the future
        # implementation and is currently unused by #call.
        attr_reader :rate

        def initialize(rate)
          @rate = rate
        end

        # Records the input shape as the output shape (shape is unchanged).
        def build(input_shape)
          @output_shape = input_shape
        end

        def call(inputs)
          # TODO implement
          Tensorflow.identity(inputs)
        end

        # Shape recorded by #build; nil before #build has run.
        def output_shape
          @output_shape
        end

        # Dropout has no weights.
        def count_params
          0
        end
      end
    end
  end
end