tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,25 @@
1
module Tensorflow
  module Keras
    module Layers
      # Flattens every non-batch dimension into a single one:
      # [batch, d1, d2, ...] becomes [batch, d1 * d2 * ...].
      class Flatten
        # @param input_shape [Array<Integer>, nil] shape of a single example, without the
        #   batch dimension. Required by #output_shape unless #build supplies it.
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # Receives the previous layer's output shape from Sequential#summary. That shape
        # carries a leading batch dimension (summary strips element 0 before printing), so
        # it is dropped here. A nil shape (Flatten is the first layer) keeps the shape
        # given to the constructor.
        def build(input_shape)
          @input_shape = input_shape.drop(1) if input_shape
        end

        # @return [Array<Integer>] [-1, product of the example dims]; -1 stands for the batch.
        #   An example with no dims flattens to a single feature, i.e. [-1, 1] (as in Keras).
        def output_shape
          flattened_dim = @input_shape.reduce(1, :*)
          [-1, flattened_dim]
        end

        # Flatten has no weights.
        def count_params
          0
        end

        # Reshapes +inputs+ to [-1, product of all non-batch dims]. A rank-1 input becomes
        # [-1, 1] instead of failing on an empty product.
        def call(inputs)
          flattened_dim = inputs.shape.drop(1).reduce(1, :*)
          Tensorflow.reshape(inputs, [-1, flattened_dim])
        end
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module Tensorflow
  module Keras
    module Losses
      # Cross-entropy between integer class labels and predicted class probabilities,
      # averaged over every element of the per-example loss.
      class SparseCategoricalCrossentropy
        # @param target class labels; cast to :int64 before use
        # @param output predicted values; the log is taken to turn them into logits for the
        #   fused op (assumes probabilities, not logits — TODO confirm with callers).
        #   NOTE(review): nothing clips +output+ away from 0, so a zero probability gives
        #   log(0) = -inf and a non-finite loss.
        # @return the mean loss (sum of per-example losses divided by their count)
        def call(target, output)
          logits = Math.log(output)
          labels = Tensorflow.cast(target, :int64)
          per_example_loss, _backprop =
            RawOps.sparse_softmax_cross_entropy_with_logits(features: logits, labels: labels)

          total_loss = Math.reduce_sum(per_example_loss)
          element_count = Tensorflow.cast(RawOps.size(input: per_example_loss), :float)
          total_loss / element_count
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Tensorflow
  module Keras
    module Metrics
      # Running mean of every element passed to #update_state. It is kept as a total
      # and an element count in two weights created via Utils.add_weight.
      class Mean
        # @param name [String, nil] accepted for Keras API compatibility; currently unused
        # @param dtype [Symbol] dtype of both accumulators and of the cast inputs
        def initialize(name: nil, dtype: :float)
          @dtype = dtype
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end

        # Alias-style entry point: forwards all arguments to #update_state.
        def call(*args)
          update_state(*args)
        end

        # Adds the sum of +values+ to the running total and their element count to the
        # running count.
        def update_state(values)
          # Cast the incoming values to the accumulator dtype, using the positional
          # Tensorflow.cast(x, dtype) form used elsewhere in the library.
          input = Tensorflow.cast(values, @dtype)
          @total.assign_add(Math.reduce_sum(input))
          @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
        end

        # @return total / count; div_no_nan yields 0 when nothing has been counted yet.
        def result
          # Cast to the accumulator dtype so both operands agree even when dtype != :float.
          RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
        end

        # NOTE(review): currently a no-op. Keras resets total and count to zero here, so
        # values keep accumulating across calls until this is implemented.
        def reset_states
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
module Tensorflow
  module Keras
    module Metrics
      # Running fraction of examples whose arg-max prediction equals the integer label.
      class SparseCategoricalAccuracy < Mean
        # @param y_true class-index labels
        # @param y_pred per-class scores; reduced over the last axis with arg_max
        def update_state(y_true, y_pred)
          y_pred = RawOps.arg_max(y_pred, -1)

          # ArgMax produces int64 indices by default, while labels frequently arrive as
          # uint8/int32/float. Equal needs matching dtypes, so cast the labels to int64 —
          # the same cast SparseCategoricalCrossentropy applies to its targets.
          y_true = Tensorflow.cast(y_true, :int64)

          # Mean casts the boolean matches to its dtype and averages them.
          super(Math.equal(y_true, y_pred))
        end
      end
    end
  end
end
@@ -0,0 +1,6 @@
1
module Tensorflow
  module Keras
    # Base class for Keras models. For now it is an empty placeholder that
    # defines no behavior of its own.
    class Model; end
  end
end
@@ -0,0 +1,56 @@
1
module Tensorflow
  module Keras
    module Models
      # A linear stack of layers. Only #add and #summary do work so far; #compile,
      # #fit and #evaluate raise.
      class Sequential
        # @param layers [Array] layers appended in order via #add
        def initialize(layers = [])
          @layers = []
          layers.each { |layer| add(layer) }
        end

        # Appends +layer+ to the end of the stack; returns the internal layer list.
        def add(layer)
          @layers.push(layer)
        end

        # @raise [RuntimeError] always; not implemented yet
        def compile(optimizer: nil, loss: nil, metrics: nil)
          raise RuntimeError, "Not implemented"
        end

        # @raise [RuntimeError] always; not implemented yet
        def fit(x, y, epochs: nil)
          raise RuntimeError, "Not implemented"
        end

        # @raise [RuntimeError] always; not implemented yet
        def evaluate(x, y)
          raise RuntimeError, "Not implemented"
        end

        # Prints a Keras-style table of layer types, output shapes and parameter counts
        # to stdout and returns nil. Layers that respond to #build first receive the
        # previous layer's output shape (nil for the first layer).
        def summary
          thin_rule = "_________________________________________________________________\n"
          thick_rule = "=================================================================\n"

          # Walk the stack once so each layer can derive its shape from its predecessor.
          incoming_shape = nil
          @layers.each do |layer|
            layer.build(incoming_shape) if layer.respond_to?(:build)
            incoming_shape = layer.output_shape
          end

          total_params = @layers.sum(&:count_params)

          # The leading (batch) entry of each output shape is displayed as nil.
          rows = @layers.map do |layer|
            type_name = layer.class.name.split("::").last
            shape_text = ([nil] + layer.output_shape[1..-1]).inspect
            format("%-28s %-25s %-10s\n", type_name, shape_text, layer.count_params)
          end

          report = [
            "Model: \"sequential\"\n",
            thin_rule,
            "Layer (type) Output Shape Param # \n",
            thick_rule,
            rows.join(thin_rule),
            thick_rule,
            "Total params: #{total_params}\n",
            "Trainable params: #{total_params}\n",
            "Non-trainable params: 0\n",
            thin_rule
          ].join
          puts report
        end
      end
    end
  end
end
@@ -0,0 +1,8 @@
1
module Tensorflow
  module Keras
    module Optimizers
      # Placeholder for the Adam optimizer; no behavior is implemented yet.
      class Adam; end
    end
  end
end
@@ -0,0 +1,22 @@
1
module Tensorflow
  module Keras
    module Preprocessing
      # Keras-style image loading helpers built on MiniMagick.
      module Image
        class << self
          # Opens the image at +path+ with MiniMagick, optionally resizing it.
          #
          # @param path [String] image file path
          # @param target_size [Array, nil] [height, width] as in Keras' load_img; nil
          #   keeps the original size
          # @return [MiniMagick::Image]
          def load_img(path, target_size: nil)
            img = MiniMagick::Image.open(path)
            if target_size
              # Keras orders target_size as (height, width), but ImageMagick geometry is
              # WIDTHxHEIGHT, so the two values must be swapped.
              height, width = target_size.map(&:to_i)
              # "!" forces the exact size, ignoring aspect ratio. The point filter
              # approximates Keras' default nearest-neighbour interpolation.
              # TODO make resize consistent with Python
              img.resize "#{width}x#{height}!", "-filter", "point"
            end
            img
          end

          # Converts +img+'s pixel data (MiniMagick#get_pixels) into a Numo::SFloat
          # array; presumably rows x columns x channels — TODO confirm.
          def img_to_array(img)
            Numo::SFloat.cast(img.get_pixels)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,83 @@
1
require "tmpdir"

module Tensorflow
  module Keras
    module Utils
      class << self
        # Creates a Variable and assigns it an initial value.
        #
        # @param name [String, nil] variable name
        # @param shape [Array<Integer>] variable shape; [] for a scalar
        # @param initializer [String] "zeros" or "glorot_uniform"
        # @param dtype [Symbol] dtype passed to Variable.new
        # @return [Variable]
        # @raise [Error] for an unknown initializer
        def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
          variable = Variable.new(shape: shape, name: name, dtype: dtype)
          initial_value =
            case initializer
            when "zeros"
              # NOTE(review): the fill value is always 0.0 regardless of +dtype+ — confirm
              # that non-float dtypes accept it.
              Tensorflow.fill(shape, 0.0)
            when "glorot_uniform"
              # TODO compute fans
              # Treats shape as [fan_in, fan_out], so this only works for 2-D weights.
              fan_in = shape[0]
              fan_out = shape[1]
              scale = 1.0
              scale /= [1.0, (fan_in + fan_out) / 2.0].max
              # limit = sqrt(6 / (fan_in + fan_out)) whenever the average fan is >= 1
              limit = ::Math.sqrt(3.0 * scale)

              minval = -limit
              maxval = limit

              # Rescale the sample into [minval, maxval), assuming random_uniform draws from [0, 1)
              rnd = RawOps.random_uniform(shape: shape, dtype: :float)
              Math.add(rnd * (maxval - minval), minval)
            else
              raise Error, "Unknown initializer: #{initializer}"
            end

          variable.value = initial_value
          variable
        end

        # Returns the cached path ~/.keras/<cache_subdir>/<fname>, downloading +origin+
        # there first if the file does not exist yet.
        #
        # @param fname [String] file name inside the cache directory
        # @param origin [String] http(s) URL to download from
        # @param file_hash [String, nil] expected hex digest; 32 chars selects MD5,
        #   anything else SHA-256
        # @param cache_subdir [String] subdirectory of ~/.keras
        # @return [String] path of the cached file
        # @raise [Error] on a non-2xx HTTP response or a digest mismatch; nothing is cached then
        def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
          # destination
          # TODO handle this better
          raise "No HOME" unless ENV["HOME"]
          dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"
          FileUtils.mkdir_p(File.dirname(dest))

          return dest if File.exist?(dest)

          temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
          digest = file_hash&.size == 32 ? Digest::MD5.new : Digest::SHA2.new
          uri = URI(origin)

          puts "Downloading data from #{origin}"
          begin
            download(uri, temp_path, digest)
            raise Error, "Bad hash" if file_hash && digest.hexdigest != file_hash

            FileUtils.mv(temp_path, dest)
          ensure
            # Remove partial or rejected downloads; a no-op after a successful move.
            FileUtils.rm_f(temp_path)
          end

          dest
        end

        private

        # Streams +uri+ into +path+, feeding every chunk to +digest+ and printing a dot
        # every 50 chunks. TLS is used only for https URLs.
        #
        # Net::HTTP automatically adds Accept-Encoding for compression of response bodies
        # and automatically decompresses gzip and deflate responses unless a Range header
        # was sent.
        # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
        #
        # @raise [Error] on a non-2xx response, so an error page is never treated as data
        def download(uri, path, digest)
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            i = 0
            File.open(path, "wb") do |f|
              http.request(request) do |response|
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed: HTTP #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                  digest.update(chunk)

                  # print progress
                  putc "." if i % 50 == 0
                  i += 1
                end
              end
              puts # newline
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,57 @@
1
+ require 'set'
2
+
3
module Tensorflow
  # Tracks the stack of active name scopes and every name handed out so far, so that
  # generated names are unique. Uniqueness is checked case-insensitively.
  class NameScope
    attr_reader :stack, :names

    def initialize
      @stack = []
      @names = Set.new
    end

    # Pushes a uniquified +base_name+ while the block runs and yields the joined scope
    # string (nil when +base_name+ is nil). The scope is popped even if the block raises.
    # Returns the block's value, or nil without a block.
    def name_scope(base_name)
      scope = unique_name(base_name)
      stack.push(scope)
      begin
        yield current_scope if block_given?
      ensure
        stack.pop
      end
    end

    # The active scopes joined with "/", or nil when no scope is active or the
    # innermost scope is nil.
    def current_scope
      return nil if stack.last.nil?

      stack.join("/")
    end

    # Prefixes +name+ with the current scope (unless the innermost scope is nil or
    # absent), then uniquifies the result.
    def scoped_name(name)
      prefixed = stack.last.nil? ? name : "#{current_scope}/#{name}"
      unique_name(prefixed)
    end

    # Returns +name+, or the first of name_1, name_2, ... whose downcased form has not
    # been used, and records it as used. Returns nil for a nil name.
    def unique_name(name)
      return nil unless name

      taken = ->(candidate) { names.include?(candidate.downcase) }
      candidate = name
      suffix = 0
      candidate = "#{name}_#{suffix += 1}" while taken.call(candidate)
      names << candidate.downcase
      candidate
    end
  end
end
@@ -0,0 +1,49 @@
1
module Tensorflow
  # Wrapper over the TensorFlow C op-definition builder (TF_NewOpDefinitionBuilder and
  # friends). The native builder is freed by a GC finalizer unless #register hands it
  # over to TensorFlow first. The spec-adding methods return self so calls can be chained.
  class OpDefBuilder
    class << self
      # Looks up the native TF_ShapeInferenceContextSetUnknownShape function in the first
      # FFI library, once, and memoizes it for use with #shape_inference.
      def unknown_shape_inference_func
        @unknown_shape_func ||= FFI.ffi_libraries.first.find_function('TF_ShapeInferenceContextSetUnknownShape')
      end

      # Builds the finalizer that deletes the native builder at +pointer+. A proc (not a
      # lambda) is used so the object id passed by ObjectSpace is ignored.
      def finalize(pointer)
        proc { FFI::TF_DeleteOpDefinitionBuilder(pointer) }
      end
    end

    # @param name [String] name of the op being defined
    def initialize(name)
      @pointer = FFI.TF_NewOpDefinitionBuilder(name)
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # The native builder handle, used when this object is passed to FFI calls.
    def to_ptr
      @pointer
    end

    # Adds an attribute spec string.
    def attr(spec)
      tap { FFI.TF_OpDefinitionBuilderAddAttr(self, spec) }
    end

    # Adds an input spec string.
    def input(spec)
      tap { FFI.TF_OpDefinitionBuilderAddInput(self, spec) }
    end

    # Adds an output spec string.
    def output(spec)
      tap { FFI.TF_OpDefinitionBuilderAddOutput(self, spec) }
    end

    # Sets the native shape-inference function.
    def shape_inference(func)
      tap { FFI.TF_OpDefinitionBuilderSetShapeInferenceFunction(self, func) }
    end

    # Registers the op. TF_RegisterOpDefinition takes ownership of the builder, so the
    # finalizer is removed first to avoid freeing it a second time. Returns the result
    # of Status.check.
    def register
      ObjectSpace.undefine_finalizer(self)
      Status.check { |status| FFI.TF_RegisterOpDefinition(self, status) }
    end
  end
end
@@ -0,0 +1,13 @@
1
module Tensorflow
  # Thin wrappers around TensorFlow's WAV audio raw ops.
  module Audio
    # Decodes WAV-encoded +contents+ with the DecodeWav raw op. The -1 defaults are
    # forwarded unchanged; they match the op's own attribute defaults.
    def self.decode_wav(contents, desired_channels: -1, desired_samples: -1)
      RawOps.decode_wav(
        contents: contents,
        desired_channels: desired_channels,
        desired_samples: desired_samples
      )
    end

    # Encodes +audio+ sampled at +sample_rate+ into WAV with the EncodeWav raw op.
    def self.encode_wav(audio, sample_rate)
      RawOps.encode_wav(audio: audio, sample_rate: sample_rate)
    end
  end
end
@@ -0,0 +1,29 @@
1
module Tensorflow
  # Element-wise bitwise operations, forwarded to the raw op of the same name.
  module Bitwise
    class << self
      # The binary ops pass both operands positionally to the matching raw op.
      %i[bitwise_and bitwise_or bitwise_xor left_shift right_shift].each do |op_name|
        define_method(op_name) { |x, y| RawOps.public_send(op_name, x, y) }
      end

      # Element-wise bitwise NOT. Unlike the binary ops above, the raw op is called with
      # a keyword argument.
      # NOTE(review): the positional vs keyword call styles differ — confirm which one
      # RawOps expects.
      def invert(x)
        RawOps.invert(x: x)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Tensorflow
  module Control
    # Creates a single no-op that carries control dependencies on every entry of
    # +inputs+ and returns the result of the graph's control_dependencies call.
    # In eager mode no op is created and nil is returned.
    #
    # @param attrs [Hash] accepted for API compatibility; currently unused
    def self.group(inputs, attrs = {})
      context = ExecutionContext.current(inputs)
      unless context.is_a?(Eager::Context)
        context.control_dependencies(inputs) { RawOps.no_op }
      end
    end
  end
end
13
+
@@ -0,0 +1,21 @@
1
module Tensorflow
  module Ops
    # Appends a trailing axis to +vector+ (expand_dims(vector, -1)) and multiplies it by
    # +matrix+, so the multiplication broadcasts each vector element across the last axis.
    def self.broadcast_mul(vector, matrix)
      vector = Tensorflow.expand_dims(vector, -1)
      vector * matrix
    end

    # Gradient for SparseSoftmaxCrossEntropyWithLogits. The incoming +gradient+ (presumably
    # the gradient w.r.t. the loss output) is broadcast-multiplied with operation[1]
    # (presumably the op's backprop output). That output is wrapped in prevent_gradient, so
    # asking for a second derivative fails with +message+.
    # NOTE(review): TensorFlow's Python gradient returns [grad, nil] (no gradient for the
    # labels), whereas this returns the Mul op's outputs — confirm against what
    # Graph::Gradients expects.
    Graph::Gradients.register('SparseSoftmaxCrossEntropyWithLogits') do |gradient, outputs, _inputs|
      message = <<~EOS
        Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused
        implementation's interaction with tf.gradients()
      EOS

      operation = outputs[0].operation
      sparse_softmax_grad_without_gradient = Tensorflow.prevent_gradient(operation[1], message: message)
      op = Ops.broadcast_mul(gradient, sparse_softmax_grad_without_gradient)
      op.outputs
    end
  end
end