tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,32 @@
1
module Tensorflow
  # Helpers shared by tensor-like classes. Includers are expected to respond
  # to +num_dims+, +dim(index)+, +dtype+ and +value+.
  module TensorMixin
    # Returns the dimension sizes, one entry per dimension, memoized after the
    # first call.
    #
    # +num_dims.times+ is used instead of Array.new(num_dims) so that a
    # negative dimension count yields [] rather than raising ArgumentError.
    # NOTE(review): a negative count presumably signals an unknown rank — confirm.
    def shape
      @shape ||= num_dims.times.map { |index| dim(index) }
    end

    # Converts the tensor's value into a Numo array of the matching type.
    #
    # @return [nil] when dtype is nil
    # @return [Symbol] the dtype itself (:variant or :string), which have no
    #   Numo counterpart
    # @return [Numo::NArray] +value+ cast via TensorData::DTYPE_TO_NUMO_TYPE_MAP
    # @raise [RuntimeError] when dtype has no entry in that map
    def numo
      type = dtype
      case type
      when nil then nil
      when :variant, :string then type
      else
        klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[type]
        raise "Unknown type: #{type}" unless klass

        klass.cast(value)
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Tensorflow
  # Describes a tensor by its shape and dtype without holding any data.
  class TensorSpec < BatchableTypeSpec
    attr_reader :dtype, :shape

    # @param shape [Object] tensor shape, stored as given
    # @param dtype [Object] tensor element type, stored as given
    def initialize(shape, dtype)
      @shape, @dtype = shape, dtype
    end
  end
end
@@ -0,0 +1,93 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: tensorflow/core/util/event.proto
3
+
4
+ require 'google/protobuf'
5
+
6
+ require 'tensorflow/core/framework/summary_pb'
7
+ Google::Protobuf::DescriptorPool.generated_pool.build do  # protoc-generated: registers the messages/enums of tensorflow/core/util/event.proto; regenerate rather than hand-edit
8
+ add_file("tensorflow/core/util/event.proto", :syntax => :proto3) do  # NOTE(review): this package also ships lib/tensorflow/core/util/event_pb.rb (same 93-line size); if both files are required, the pool may reject the duplicate file registration — confirm only one is loaded
9
+ add_message "tensorflow.Event" do
10
+ optional :wall_time, :double, 1
11
+ optional :step, :int64, 2
12
+ oneof :what do  # oneof: at most one of the fields below is set on a given Event
13
+ optional :file_version, :string, 3
14
+ optional :graph_def, :bytes, 4  # raw bytes, presumably a serialized GraphDef — confirm
15
+ optional :summary, :message, 5, "tensorflow.Summary"
16
+ optional :log_message, :message, 6, "tensorflow.LogMessage"
17
+ optional :session_log, :message, 7, "tensorflow.SessionLog"
18
+ optional :tagged_run_metadata, :message, 8, "tensorflow.TaggedRunMetadata"
19
+ optional :meta_graph_def, :bytes, 9  # raw bytes, presumably a serialized MetaGraphDef — confirm
20
+ end
21
+ end
22
+ add_message "tensorflow.LogMessage" do
23
+ optional :level, :enum, 1, "tensorflow.LogMessage.Level"  # enum declared right below
24
+ optional :message, :string, 2
25
+ end
26
+ add_enum "tensorflow.LogMessage.Level" do
27
+ value :UNKNOWN, 0
28
+ value :DEBUGGING, 10
29
+ value :INFO, 20
30
+ value :WARN, 30
31
+ value :ERROR, 40
32
+ value :FATAL, 50
33
+ end
34
+ add_message "tensorflow.SessionLog" do
35
+ optional :status, :enum, 1, "tensorflow.SessionLog.SessionStatus"
36
+ optional :checkpoint_path, :string, 2
37
+ optional :msg, :string, 3
38
+ end
39
+ add_enum "tensorflow.SessionLog.SessionStatus" do
40
+ value :STATUS_UNSPECIFIED, 0
41
+ value :START, 1
42
+ value :STOP, 2
43
+ value :CHECKPOINT, 3
44
+ end
45
+ add_message "tensorflow.TaggedRunMetadata" do
46
+ optional :tag, :string, 1
47
+ optional :run_metadata, :bytes, 2  # raw bytes, presumably a serialized RunMetadata — confirm
48
+ end
49
+ add_message "tensorflow.WatchdogConfig" do
50
+ optional :timeout_ms, :int64, 1
51
+ end
52
+ add_message "tensorflow.RequestedExitCode" do
53
+ optional :exit_code, :int32, 1
54
+ end
55
+ add_message "tensorflow.WorkerHeartbeatRequest" do
56
+ optional :shutdown_mode, :enum, 1, "tensorflow.WorkerShutdownMode"  # enum declared later in this same add_file block
57
+ optional :watchdog_config, :message, 2, "tensorflow.WatchdogConfig"
58
+ optional :exit_code, :message, 3, "tensorflow.RequestedExitCode"
59
+ end
60
+ add_message "tensorflow.WorkerHeartbeatResponse" do
61
+ optional :health_status, :enum, 1, "tensorflow.WorkerHealth"
62
+ repeated :worker_log, :message, 2, "tensorflow.Event"  # list of tensorflow.Event messages
63
+ optional :hostname, :string, 3
64
+ end
65
+ add_enum "tensorflow.WorkerHealth" do
66
+ value :OK, 0
67
+ value :RECEIVED_SHUTDOWN_SIGNAL, 1
68
+ value :INTERNAL_ERROR, 2
69
+ value :SHUTTING_DOWN, 3
70
+ end
71
+ add_enum "tensorflow.WorkerShutdownMode" do
72
+ value :DEFAULT, 0
73
+ value :NOT_CONFIGURED, 1
74
+ value :WAIT_FOR_COORDINATOR, 2
75
+ value :SHUTDOWN_AFTER_TIMEOUT, 3
76
+ end
77
+ end
78
+ end
79
+
80
+ module Tensorflow  # protoc-generated: binds Ruby constants to the event.proto descriptors registered earlier in this file (msgclass = message class, enummodule = enum module)
81
+ Event = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.Event").msgclass
82
+ LogMessage = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.LogMessage").msgclass
83
+ LogMessage::Level = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.LogMessage.Level").enummodule  # nested enum exposed under its parent message class
84
+ SessionLog = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.SessionLog").msgclass
85
+ SessionLog::SessionStatus = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.SessionLog.SessionStatus").enummodule  # nested enum exposed under its parent message class
86
+ TaggedRunMetadata = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.TaggedRunMetadata").msgclass
87
+ WatchdogConfig = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WatchdogConfig").msgclass
88
+ RequestedExitCode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.RequestedExitCode").msgclass
89
+ WorkerHeartbeatRequest = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHeartbeatRequest").msgclass
90
+ WorkerHeartbeatResponse = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHeartbeatResponse").msgclass
91
+ WorkerHealth = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHealth").enummodule
92
+ WorkerShutdownMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerShutdownMode").enummodule
93
+ end
@@ -0,0 +1,33 @@
1
+ # Based on code from https://github.com/jedld/tensor_stream
2
+
3
module Tensorflow
  module Train
    # Plain gradient-descent optimizer. Each variable update is delegated to the
    # ResourceApplyGradientDescent raw op with the learning rate as the step size.
    class GradientDescentOptimizer < Optimizer
      attr_accessor :learning_rate

      # @param learning_rate [Numeric, Proc] step size; a Proc is evaluated lazily
      #   each time #prepare runs
      # @param use_locking [Boolean] passed through to Optimizer
      # @param name [String] optimizer name
      def initialize(learning_rate, use_locking: false, name: "GradientDescent")
        @learning_rate = learning_rate
        @learning_rate_tensor = nil
        super(name: name, use_locking: use_locking)
      end

      protected

      # Materializes the (possibly callable) learning rate as a constant tensor.
      def prepare
        rate = call_if_callable(@learning_rate)
        @learning_rate_tensor = Tensorflow.constant(rate, name: "learning_rate")
      end

      # Emits the update op for a single variable. The learning rate is cast to
      # the gradient's dtype whenever the two dtypes differ.
      # NOTE(review): @use_locking (stored by Optimizer) is not forwarded to the
      # raw op — confirm whether RawOps.resource_apply_gradient_descent accepts it.
      def apply_dense(grad, var)
        grad_dtype = grad.output_types.first
        alpha = @learning_rate_tensor
        alpha = Tensorflow.cast(alpha, grad_dtype) unless alpha.output_types.first == grad_dtype
        RawOps.resource_apply_gradient_descent(var, alpha, grad)
      end
    end
  end
end
@@ -0,0 +1,158 @@
1
+ # Based on code from https://github.com/jedld/tensor_stream
2
module Tensorflow
  module Train
    # Base optimizer (ported from tensor_stream). Subclasses implement #prepare
    # and #apply_dense; #minimize computes gradients and hands them to
    # #apply_gradients. @slots / @non_slots cache auxiliary variables for
    # optimizers that need them.
    class Optimizer
      attr_reader :name

      # @param name [String] optimizer name; required
      # @param use_locking [Boolean] stored in @use_locking for subclasses
      # @raise [Error::InvalidArgumentError] when +name+ is nil
      def initialize(name: nil, use_locking: false)
        @name = name
        @use_locking = use_locking
        raise(Error::InvalidArgumentError, "Must specify the optimizer name") unless name

        @slots = {}
        @non_slots = {}
      end

      # The current execution context, used for collections and gradients.
      def graph
        ExecutionContext.current
      end

      # Computes gradients of +loss+ and applies them.
      #
      # @param loss value to minimize
      # @param var_list [Array, nil] variables to update; defaults to the
      #   TRAINABLE_VARIABLES collection
      # @param grad_loss optional initial gradient, passed as grad_ys
      # @param global_step [Variable, nil] incremented by one after the update
      # @param name [String, nil] name of the grouped update op
      # @raise [Error::InvalidArgumentError] when no variable received a gradient
      def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
        grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)
        # A nil gradient marks a variable the loss does not depend on. Only fail
        # when that holds for every variable (an empty list fails as well).
        if grads_and_vars.none? { |grad, _var| grad }
          raise(Error::InvalidArgumentError, "No gradients provided for any variable, check your graph for ops that do not support gradients")
        end

        apply_gradients(grads_and_vars, global_step: global_step, name: name)
      end

      # Applies [gradient, variable] pairs through #apply_dense. Pairs whose
      # gradient is nil are skipped because there is nothing to apply.
      #
      # @return the grouped update op, or the global_step increment op when
      #   +global_step+ is given (run after the updates via control dependencies)
      def apply_gradients(grads_and_vars, global_step: nil, name: nil)
        # TODO: slot creation (create_slots) and name scoping from tensor_stream
        # are not wired up yet.
        prepare
        apply_ops = grads_and_vars
                    .select { |grad, _var| grad }
                    .map { |grad, var| apply_dense(grad, var) }

        if global_step.nil?
          finish(apply_ops, name)
        else
          global_step.handle.graph.control_dependencies([finish(apply_ops, "update")]) do
            global_step.assign_add(Tensorflow.constant(1, dtype: global_step.dtype))
          end
        end
      end

      # @return [Array<Array>] [gradient, variable] pairs in variable order.
      #   NOTE(review): a gradient is presumably nil for variables unrelated to
      #   +loss+ — confirm against Graph::Gradients#gradients.
      # @raise [Error::InvalidArgumentError] when there are no variables to train
      def compute_gradients(loss, var_list: nil, grad_loss: nil)
        trainable_vars = var_list || graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)

        if trainable_vars.nil? || trainable_vars.empty?
          raise(Error::InvalidArgumentError, 'There are no variables to train for the loss function')
        end

        gradients = Graph::Gradients.new(graph)
        grads = gradients.gradients(loss, trainable_vars, grad_ys: grad_loss)

        grads.zip(trainable_vars)
      end

      # @return the slot named +name+ for +var+, or nil when none exists
      def get_slot(var, name)
        named_slots = @slots[name]
        return nil if named_slots.nil?

        named_slots[var_key(var)]
      end

      # @return [Array] slot names created so far, sorted
      def get_slot_names
        @slots.keys.sort
      end

      protected

      # Groups the per-variable update ops into a single op.
      def finish(update_ops, name_scope)
        Control.group(update_ops, name: name_scope)
      end

      # Hook for subclasses that need slots; the base implementation does nothing.
      def create_slots(var_list)
        # no implementation
      end

      # Hook run once before the updates are built; no-op by default.
      def prepare
        # no implementation
      end

      # Builds the update op for one variable; subclasses must override.
      # @raise [Error::UnimplementedError] always, in the base class
      def apply_dense(_grad, _var)
        raise(Error::UnimplementedError, "Not implemented")
      end

      ##
      # Find or create a slot initialized with 0.0.
      #
      # Args:
      #   var: Variable - A Variable object
      #   slot_name: string - Name for the slot
      #   op_name: string - Name to use when scoping the Variable that needs to be created
      #
      # NOTE(review): create_zeros_slot is not defined in this class; a subclass
      # must provide it or this raises NoMethodError.
      def zeros_slot(var, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        named_slots[key] = create_zeros_slot(var, op_name) unless named_slots.key?(key)
        named_slots[key]
      end

      ##
      # Returns the Hash caching slots created under +slot_name+, creating it on
      # first use. Keys are the [graph, op name] pairs produced by #var_key.
      def slot_dict(slot_name)
        @slots[slot_name] ||= {}
      end

      # Cache key identifying a variable: its op's graph and name.
      def var_key(var)
        [var.op.graph, var.op.name]
      end

      # @return the non-slot variable cached under [name, graph], or nil
      def get_non_slot_variable(name, graph: nil)
        @non_slots[[name, graph]]
      end

      # Evaluates +param+ when it is a Proc, otherwise returns it unchanged.
      def call_if_callable(param)
        param.is_a?(Proc) ? param.call : param
      end

      # Returns the non-trainable variable cached under [name, colocate_with.graph],
      # creating it from +initial_value+ the first time.
      def create_non_slot_variable(initial_value, name, colocate_with)
        key = [name, colocate_with.graph]
        @non_slots[key] ||= Variable.new(initial_value, name: name, trainable: false)
      end

      ##
      # Find or create a slot for a variable, using an Initializer.
      #
      # NOTE(review): create_slot_with_initializer is not defined in this class;
      # a subclass must provide it or this raises NoMethodError.
      def get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        unless named_slots.key?(key)
          named_slots[key] = create_slot_with_initializer(var, initializer, shape, dtype, op_name)
        end
        named_slots[key]
      end
    end
  end
end
@@ -0,0 +1,4 @@
1
module Tensorflow
  # Base class for type specifications. It defines no methods of its own.
  class TypeSpec; end
end
@@ -0,0 +1,127 @@
1
module Tensorflow
  # Resource-backed variable. Construction adds the variable to the
  # GLOBAL_VARIABLES collection (and TRAINABLE_VARIABLES when +trainable+),
  # creates a VarHandleOp and, when an initial value is given, stores the
  # resulting assign op as #initializer.
  class Variable
    include Operators

    attr_reader :handle, :dtype, :name, :initializer

    # @param initial_value [nil, Graph::Operation, Tensor, Object] starting value;
    #   other Ruby values are converted with Tensor.from_value
    # @param dtype [Symbol, nil] element type; taken from initial_value when possible
    # @param shape [Array, nil] static shape; taken from initial_value when possible,
    #   otherwise [] is used
    # @param shared_name [String, nil] resource name; defaults to the unique name
    # @param name [String] base name, made unique in the current execution context
    # @param trainable [Boolean] whether to add the variable to TRAINABLE_VARIABLES
    def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
      initial_value = case initial_value
                      when NilClass
                        @dtype = dtype
                        # Keep a caller-supplied shape; fall back to [] only when none was given.
                        shape ||= []
                        initial_value
                      when Graph::Operation
                        @dtype = dtype || initial_value.dtype
                        shape ||= initial_value.output_shapes.first
                        initial_value
                      when Tensor
                        @dtype = initial_value.dtype
                        shape ||= initial_value.shape
                        initial_value
                      else
                        tensor = Tensor.from_value(initial_value, dtype: dtype)
                        @dtype = tensor.dtype
                        shape = tensor.shape
                        tensor
                      end

      name = name&.to_s
      shared_name = shared_name&.to_s
      unique_name = ExecutionContext.current.unique_name(name || shared_name)
      shared_name ||= unique_name
      @name = unique_name

      collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
      collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
      ExecutionContext.current.add_to_collections(collections, self)

      @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
      self.value = initial_value if initial_value
    end

    # Memoized ReadVariableOp on the handle; reset by #assign_add / #assign_sub.
    def value_handle
      @value_handle ||= RawOps.read_variable_op(handle, dtype: @dtype)
    end

    # The current value: the materialized value in eager mode, the read op in
    # graph mode, nil for any other handle type.
    def value
      case value_handle
      when Eager::TensorHandle
        value_handle.value
      when Graph::Operation
        value_handle
      end
    end

    # Assigns +value+ and remembers the assign op as #initializer.
    def value=(value)
      @initializer = RawOps.assign_variable_op(handle, value, dtype: @dtype)
    end

    def initialized?
      RawOps.var_is_initialized_op(handle)
    end

    # These methods match the operation api to enable gradients and sessions
    def consumers
      handle.consumers
    end

    # This enables executing variables to get the values in a session
    def outputs
      [Graph::OperationOutput.from_index(value_handle, 0)]
    end

    def to_ptr
      handle.to_ptr
    end

    def shape
      value_handle.shape
    end

    # @raise [Error::UnavailableError] in graph execution mode
    def tensor
      raise(Error::UnavailableError, "Only supported in eager execution mode") if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE

      value_handle.tensor
    end

    def rank
      shape.size
    end

    def reshape(shape)
      RawOps.reshape(self, shape)
    end

    # Adds +value+ to the variable. The value is converted with
    # Tensor.from_value(value, dtype: dtype), cast to the variable's dtype, and
    # that cast tensor (not the raw +value+) is what gets added.
    def assign_add(value, dtype: nil)
      @value_handle = nil
      tensor = Tensorflow.cast(Tensor.from_value(value, dtype: dtype), self.dtype)
      RawOps.assign_add_variable_op(handle, tensor, dtype: tensor.dtype)
    end

    # Subtracts +value+ from the variable, after converting it and casting it to
    # the variable's dtype; the cast tensor is what gets subtracted.
    def assign_sub(value)
      @value_handle = nil
      tensor = Tensorflow.cast(Tensor.from_value(value, dtype: dtype), dtype)
      RawOps.assign_sub_variable_op(handle, tensor, dtype: tensor.dtype)
    end

    def to_s
      inspect
    end

    # NOTE(review): builds the read op via #value_handle, so inspecting a
    # variable in graph mode adds a ReadVariableOp if none exists yet.
    def inspect
      inspection = []
      inspection << "name: #{handle.name}" if handle.respond_to?(:name)
      inspection << "shape: #{value_handle.shape}"
      inspection << "dtype: #{value_handle.dtype}"
      "#<#{self.class} #{inspection.join(', ')}>"
    end
  end
end