tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,25 @@
module Tensorflow
  module Keras
    module Layers
      # Layer that collapses every non-batch dimension of its input into a
      # single dimension. It has no trainable weights.
      class Flatten
        # input_shape - per-example dimensions (the batch dimension is not
        #               included); only required by #output_shape.
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # Returns [-1, n] where n is the product of the configured
        # input_shape dimensions (1 when input_shape is empty).
        #
        # Raises ArgumentError when the layer was created without an
        # input_shape, instead of failing with a NoMethodError on nil.
        def output_shape
          raise ArgumentError, "input_shape is required to compute output_shape" if @input_shape.nil?

          [-1, flattened_size(@input_shape)]
        end

        # Flatten holds no weights.
        def count_params
          0
        end

        # Reshapes `inputs` to [-1, n], where n is the product of every
        # dimension after the first one reported by `inputs.shape`.
        def call(inputs)
          # Array() covers a rank-0 shape, for which the slice is nil.
          trailing_dims = Array(inputs.shape[1..-1])
          Tensorflow.reshape(inputs, [-1, flattened_size(trailing_dims)])
        end

        private

        # Product of `dims`. Seeded with 1 so that an empty list yields 1
        # rather than the nil that `inject(&:*)` returns for empty input.
        def flattened_size(dims)
          dims.inject(1, :*)
        end
      end
    end
  end
end
@@ -0,0 +1,14 @@
module Tensorflow
  module Keras
    module Losses
      # Mean sparse categorical cross-entropy between integer labels and
      # model outputs.
      class SparseCategoricalCrossentropy
        # target - labels; cast to :int64 before use.
        # output - model outputs; their logarithm is fed to the fused
        #          sparse-softmax op as its `features`.
        #          NOTE(review): presumably probabilities rather than raw
        #          logits, since the log is taken here - confirm with callers.
        #
        # Returns the sum of the per-element costs divided by their count.
        def call(target, output)
          # `Math` presumably resolves to Tensorflow::Math here (it must
          # provide reduce_sum) - confirm.
          log_output = Math.log(output)
          labels = Tensorflow.cast(target, :int64)

          # Only the first value returned by the op (the cost) is used.
          per_element_cost, _unused = RawOps.sparse_softmax_cross_entropy_with_logits(features: log_output, labels: labels)

          summed_cost = Math.reduce_sum(per_element_cost)
          element_count = Tensorflow.cast(RawOps.size(input: per_element_cost), :float)
          summed_cost / element_count
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
module Tensorflow
  module Keras
    module Metrics
      # Running mean of every value passed to #update_state, kept as two
      # weights: the sum of the values seen and the number of values seen.
      class Mean
        # name  - accepted for interface compatibility; currently unused.
        # dtype - type used for both accumulators and for the result.
        def initialize(name: nil, dtype: :float)
          @dtype = dtype
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end

        # Shorthand for #update_state.
        def call(*args)
          update_state(*args)
        end

        # Adds `values` to the running statistics: their sum goes into the
        # total and their element count into the count.
        def update_state(values)
          # The argument itself must be cast; the previous code cast an
          # unrelated (nil) local and never looked at `values`.
          values = Tensorflow.cast(values, @dtype)
          @total.assign_add(Math.reduce_sum(values))
          @count.assign_add(Tensorflow.cast(RawOps.size(values), @dtype))
        end

        # Returns total / count via div_no_nan. The count is cast to the
        # metric's own dtype so both operands agree for any configured dtype.
        def result
          RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
        end

        # Currently a no-op: the accumulators are not cleared.
        def reset_states
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
module Tensorflow
  module Keras
    module Metrics
      # Mean metric fed with the element-wise equality between the labels
      # and the arg-max of the predictions.
      class SparseCategoricalAccuracy < Mean
        # y_true - labels.
        # y_pred - predictions; reduced with arg_max over the last axis (-1).
        #
        # NOTE: no dtype alignment between the two is performed; a cast of
        # the predictions to the labels' dtype was sketched in an earlier
        # version of this method but left disabled.
        def update_state(y_true, y_pred)
          predicted_classes = RawOps.arg_max(y_pred, -1)
          matches = Math.equal(y_true, predicted_classes)
          super(matches)
        end
      end
    end
  end
end
@@ -0,0 +1,6 @@
module Tensorflow
  module Keras
    # Placeholder: declares the Keras::Model constant. No behaviour is
    # defined on it yet.
    class Model
    end
  end
end
@@ -0,0 +1,56 @@
module Tensorflow
  module Keras
    module Models
      # Ordered stack of layers. Only layer bookkeeping and #summary are
      # implemented; #compile, #fit and #evaluate raise.
      class Sequential
        # layers - optional initial list of layers, added in order.
        def initialize(layers = [])
          @layers = []

          layers.each do |layer|
            add(layer)
          end
        end

        # Appends `layer` to the stack.
        def add(layer)
          @layers << layer
        end

        def compile(optimizer: nil, loss: nil, metrics: nil)
          raise "Not implemented"
        end

        def fit(x, y, epochs: nil)
          raise "Not implemented"
        end

        def evaluate(x, y)
          raise "Not implemented"
        end

        # Prints a table with one row per layer (class name, output shape,
        # parameter count) followed by the parameter totals.
        #
        # Layers that respond to #build are built first, each receiving the
        # output shape of the layer before it (nil for the first layer).
        def summary
          # One template for the header and the layer rows keeps the three
          # columns aligned; the rules span the same 28 + 1 + 25 + 1 + 10
          # characters.
          row_format = "%-28s %-25s %-10s\n"
          thin_rule = "#{"_" * 65}\n"
          thick_rule = "#{"=" * 65}\n"

          output_shape = nil
          @layers.each do |layer|
            layer.build(output_shape) if layer.respond_to?(:build)
            output_shape = layer.output_shape
          end

          total_params = @layers.sum(&:count_params)

          rows = @layers.map do |layer|
            # The leading (batch) dimension is displayed as nil.
            shape_text = ([nil] + layer.output_shape[1..-1]).inspect
            row_format % [layer.class.name.split("::").last, shape_text, layer.count_params]
          end

          summary = String.new
          summary << "Model: \"sequential\"\n"
          summary << thin_rule
          summary << row_format % ["Layer (type)", "Output Shape", "Param #"]
          summary << thick_rule
          summary << rows.join(thin_rule)
          summary << thick_rule
          summary << "Total params: #{total_params}\n"
          summary << "Trainable params: #{total_params}\n"
          summary << "Non-trainable params: 0\n"
          summary << thin_rule
          puts summary
        end
      end
    end
  end
end
@@ -0,0 +1,8 @@
module Tensorflow
  module Keras
    module Optimizers
      # Placeholder: declares the Optimizers::Adam constant. No behaviour
      # is defined on it yet.
      class Adam
      end
    end
  end
end
@@ -0,0 +1,22 @@
module Tensorflow
  module Keras
    module Preprocessing
      # Image loading helpers built on MiniMagick and Numo.
      module Image
        # Opens the image at `path` with MiniMagick.
        #
        # target_size - optional pair of dimensions. When given, the image
        #               is resized to exactly "<first>x<second>" (the "!"
        #               geometry flag) using the point filter.
        #
        # Returns the MiniMagick image.
        def self.load_img(path, target_size: nil)
          image = MiniMagick::Image.open(path)
          return image unless target_size

          # TODO make resize consistent with Python
          geometry = target_size.map(&:to_i).join("x")
          image.resize "#{geometry}!", "-filter", "point"
          image
        end

        # Converts the pixels of `img` into a Numo::SFloat array.
        def self.img_to_array(img)
          pixels = img.get_pixels
          Numo::SFloat.cast(pixels)
        end
      end
    end
  end
end
@@ -0,0 +1,83 @@
module Tensorflow
  module Keras
    # Helpers shared by the Keras-style layers, metrics and dataset loaders.
    module Utils
      class << self
        # Creates a Variable and assigns it an initial value.
        #
        # name, shape, dtype - forwarded to Variable.new.
        # initializer        - "zeros" or "glorot_uniform".
        #
        # Returns the Variable.
        # Raises Error for any other initializer.
        def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
          variable = Variable.new(shape: shape, name: name, dtype: dtype)
          initial_value =
            case initializer
            when "zeros"
              Tensorflow.fill(shape, 0.0)
            when "glorot_uniform"
              # TODO compute fans
              # Only the first two dimensions are considered, so `shape`
              # needs at least two entries for this initializer.
              fan_in = shape[0]
              fan_out = shape[1]
              scale = 1.0
              scale /= [1.0, (fan_in + fan_out) / 2.0].max
              limit = ::Math.sqrt(3.0 * scale)

              minval = -limit
              maxval = limit

              # Stretch the random_uniform output to span [-limit, limit).
              rnd = RawOps.random_uniform(shape: shape, dtype: :float)
              Math.add(rnd * (maxval - minval), minval)
            else
              raise Error, "Unknown initializer: #{initializer}"
            end

          variable.value = initial_value
          variable
        end

        # Downloads `origin` into $HOME/.keras/<cache_subdir>/<fname> unless
        # that file already exists.
        #
        # file_hash - optional expected hex digest of the download; a
        #             32-character value is checked as MD5, anything else
        #             as SHA2.
        #
        # Returns the path of the cached file.
        # Raises RuntimeError when HOME is not set, and Error when the server
        # does not answer with a 2xx status or the digest does not match.
        def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
          # destination
          # TODO handle this better
          raise "No HOME" unless ENV["HOME"]
          dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"
          FileUtils.mkdir_p(File.dirname(dest))

          return dest if File.exist?(dest)

          # Dir.tmpdir comes from the tmpdir library, which tempfile itself
          # requires; asking for it directly avoids creating a throw-away
          # Tempfile just to learn its directory.
          temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name

          digest = file_hash&.size == 32 ? Digest::MD5.new : Digest::SHA2.new

          uri = URI(origin)

          begin
            # Net::HTTP automatically adds Accept-Encoding for compression
            # of response bodies and automatically decompresses gzip
            # and deflate responses unless a Range header was sent.
            # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
            #
            # TLS is only enabled for https origins so that plain http
            # origins can be fetched as well.
            Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
              request = Net::HTTP::Get.new(uri)

              puts "Downloading data from #{origin}"
              i = 0
              File.open(temp_path, "wb") do |f|
                http.request(request) do |response|
                  # Never cache an error page or redirect body as the data.
                  unless response.is_a?(Net::HTTPSuccess)
                    raise Error, "Bad response: #{response.code} #{response.message}"
                  end

                  response.read_body do |chunk|
                    f.write(chunk)
                    digest.update(chunk)

                    # print progress
                    putc "." if i % 50 == 0
                    i += 1
                  end
                end
                puts # newline
              end
            end

            if file_hash && digest.hexdigest != file_hash
              raise Error, "Bad hash"
            end

            FileUtils.mv(temp_path, dest)
          ensure
            # No-op after a successful move; removes the partial download
            # when anything above failed.
            FileUtils.rm_f(temp_path)
          end

          dest
        end
      end
    end
  end
end
@@ -0,0 +1,57 @@
1
+ require 'set'
2
+
module Tensorflow
  # Keeps a stack of nested name scopes and hands out names that are unique
  # (compared case-insensitively) within this NameScope instance.
  class NameScope
    attr_reader :stack, :names

    def initialize
      @stack = []
      @names = Set.new
    end

    # Pushes a unique variant of `base_name` onto the scope stack, yields the
    # resulting scope path to the block (when one is given) and pops the
    # entry again afterwards, even if the block raises.
    #
    # Returns the block's value, or nil without a block.
    def name_scope(base_name)
      stack.push(unique_name(base_name))

      begin
        yield current_scope if block_given?
      ensure
        stack.pop
      end
    end

    # Returns the stack entries joined with "/", or nil when the stack is
    # empty or its top entry is nil.
    def current_scope
      stack.join("/") unless stack.last.nil?
    end

    # Returns a unique name for `name`, prefixed with the current scope
    # path when there is one.
    def scoped_name(name)
      scope = current_scope
      unique_name(scope ? "#{scope}/#{name}" : name)
    end

    # Returns `name` itself when it has not been handed out yet, otherwise
    # the first free "<name>_<i>" for i = 1, 2, ...; the result is recorded
    # (downcased) as taken. Returns nil for a nil/false name.
    def unique_name(name)
      return nil unless name

      candidate = name
      suffix = 0
      while names.include?(candidate.downcase)
        suffix += 1
        candidate = "#{name}_#{suffix}"
      end
      names << candidate.downcase
      candidate
    end
  end
end
@@ -0,0 +1,49 @@
module Tensorflow
  # Fluent wrapper around the native op-definition builder functions exposed
  # through FFI. The attr/input/output/shape_inference methods return self
  # so calls can be chained; #register hands the definition to the library.
  class OpDefBuilder
    # Looks up (once) and returns the native
    # TF_ShapeInferenceContextSetUnknownShape function from the first loaded
    # FFI library.
    def self.unknown_shape_inference_func
      @unknown_shape_func ||= FFI.ffi_libraries.first.find_function('TF_ShapeInferenceContextSetUnknownShape')
    end

    # Builds the finalizer proc for a builder. It closes over the raw
    # pointer only, never over the Ruby object being finalized.
    def self.finalize(pointer)
      proc { FFI.TF_DeleteOpDefinitionBuilder(pointer) }
    end

    # name - name passed to TF_NewOpDefinitionBuilder.
    def initialize(name)
      @pointer = FFI.TF_NewOpDefinitionBuilder(name)
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Raw builder pointer returned by TF_NewOpDefinitionBuilder.
    def to_ptr
      @pointer
    end

    # Adds an attribute described by `spec`. Returns self.
    def attr(spec)
      FFI.TF_OpDefinitionBuilderAddAttr(self, spec)
      self
    end

    # Adds an input described by `spec`. Returns self.
    def input(spec)
      FFI.TF_OpDefinitionBuilderAddInput(self, spec)
      self
    end

    # Adds an output described by `spec`. Returns self.
    def output(spec)
      FFI.TF_OpDefinitionBuilderAddOutput(self, spec)
      self
    end

    # Sets the shape inference function to `func`. Returns self.
    def shape_inference(func)
      FFI.TF_OpDefinitionBuilderSetShapeInferenceFunction(self, func)
      self
    end

    # Registers the definition, reporting failures through Status.check.
    # The delete finalizer is removed first, so this object will no longer
    # free the builder itself.
    # NOTE(review): presumably because the native register call takes
    # ownership of the builder - confirm against the C API.
    def register
      ObjectSpace.undefine_finalizer(self)
      Status.check do |status|
        FFI.TF_RegisterOpDefinition(self, status)
      end
    end
  end
end
@@ -0,0 +1,13 @@
module Tensorflow
  # WAV helpers; each method forwards to the RawOps op of the same name.
  module Audio
    # Forwards to RawOps.decode_wav. Both limits default to -1.
    def self.decode_wav(contents, desired_channels: -1, desired_samples: -1)
      RawOps.decode_wav(
        contents: contents,
        desired_channels: desired_channels,
        desired_samples: desired_samples
      )
    end

    # Forwards to RawOps.encode_wav.
    def self.encode_wav(audio, sample_rate)
      RawOps.encode_wav(
        audio: audio,
        sample_rate: sample_rate
      )
    end
  end
end
@@ -0,0 +1,29 @@
module Tensorflow
  # Bit manipulation helpers; each method forwards its arguments to the
  # RawOps op of the same name.
  module Bitwise
    # Forwards to RawOps.bitwise_and.
    def self.bitwise_and(x, y)
      RawOps.bitwise_and(x, y)
    end

    # Forwards to RawOps.bitwise_or.
    def self.bitwise_or(x, y)
      RawOps.bitwise_or(x, y)
    end

    # Forwards to RawOps.bitwise_xor.
    def self.bitwise_xor(x, y)
      RawOps.bitwise_xor(x, y)
    end

    # Forwards to RawOps.invert. Unlike its siblings, the operand is passed
    # as the keyword argument `x:`.
    def self.invert(x)
      RawOps.invert(x: x)
    end

    # Forwards to RawOps.left_shift.
    def self.left_shift(x, y)
      RawOps.left_shift(x, y)
    end

    # Forwards to RawOps.right_shift.
    def self.right_shift(x, y)
      RawOps.right_shift(x, y)
    end
  end
end
@@ -0,0 +1,13 @@
module Tensorflow
  # Control-flow helpers.
  module Control
    class << self
      # Creates a no_op inside a control_dependencies block on `inputs`,
      # using the context returned by ExecutionContext.current(inputs).
      #
      # inputs - values the no_op is made to depend on.
      # attrs  - accepted but currently unused.
      #
      # Returns nil when the current context is an Eager::Context, otherwise
      # whatever control_dependencies returns.
      def group(inputs, attrs={})
        context = ExecutionContext.current(inputs)

        unless context.is_a?(Eager::Context)
          context.control_dependencies(inputs) { RawOps.no_op }
        end
      end
    end
  end
end
13
+
@@ -0,0 +1,21 @@
module Tensorflow
  module Ops
    # Multiplies `vector` by `matrix` after appending a new last axis to
    # `vector` with expand_dims.
    def self.broadcast_mul(vector, matrix)
      Tensorflow.expand_dims(vector, -1) * matrix
    end

    # Gradient for SparseSoftmaxCrossEntropyWithLogits: the incoming gradient
    # is broadcast-multiplied by index 1 of the forward operation, wrapped in
    # prevent_gradient so that differentiating through it reports `message`.
    # NOTE(review): index 1 is presumably the op's precomputed backprop
    # output - confirm against the op definition.
    Graph::Gradients.register('SparseSoftmaxCrossEntropyWithLogits') do |gradient, outputs, inputs|
      message = <<~EOS
        Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused
        implementation's interaction with tf.gradients()
      EOS

      operation = outputs[0].operation
      sparse_softmax_grad_without_gradient = Tensorflow.prevent_gradient(operation[1], message: message)
      Ops.broadcast_mul(gradient, sparse_softmax_grad_without_gradient).outputs
    end
  end
end