tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,32 @@
1
module Tensorflow
  # Helpers shared by tensor-like objects.
  #
  # The including class must supply +num_dims+, +dim(index)+, +dtype+ and
  # +value+; none of them are defined in this module.
  module TensorMixin
    # Returns the size of every dimension as an Array of the values reported by
    # +dim+, in index order. The result is memoized in @shape, so +dim+ is only
    # consulted on the first call.
    def shape
      @shape ||= begin
        # NOTE(review): this Status is never handed to +dim+ in this file, so
        # +check+ can only observe a freshly created status. The calls are kept
        # to preserve the original behavior - confirm whether +dim+ should
        # receive it.
        status = Status.new
        # Integer#times (rather than Array.new) so that a non-positive
        # num_dims still yields an empty shape instead of raising.
        num_dims.times.map do |index|
          extent = dim(index)
          status.check
          extent
        end
      end
    end

    # Converts the tensor's value to a Numo array.
    #
    # Returns nil when dtype is nil, and the dtype symbol itself for :variant
    # and :string (no Numo class is looked up for those). Otherwise the value
    # is cast with the class registered in TensorData::DTYPE_TO_NUMO_TYPE_MAP.
    #
    # Raises RuntimeError when the dtype has no entry in the map.
    def numo
      case dtype
      when nil
        nil
      when :variant
        :variant
      when :string
        :string
      else
        klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[dtype]
        raise "Unknown type: #{dtype}" unless klass

        klass.cast(value)
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Tensorflow
  # Value holder that pairs a shape with a dtype. Both are stored exactly as
  # given; no validation or conversion happens here.
  class TensorSpec < BatchableTypeSpec
    # The shape passed to the constructor.
    attr_reader :shape

    # The dtype passed to the constructor.
    attr_reader :dtype

    def initialize(shape, dtype)
      @shape, @dtype = shape, dtype
    end
  end
end
@@ -0,0 +1,93 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: tensorflow/core/util/event.proto
3
+
4
+ require 'google/protobuf'
5
+
6
+ require 'tensorflow/core/framework/summary_pb'
7
# Registers the messages and enums of tensorflow/core/util/event.proto with the
# shared generated descriptor pool. This is protoc output: regenerate it from
# the .proto file rather than editing the field definitions by hand.
+ Google::Protobuf::DescriptorPool.generated_pool.build do
8
+ add_file("tensorflow/core/util/event.proto", :syntax => :proto3) do
9
+ add_message "tensorflow.Event" do
10
+ optional :wall_time, :double, 1
11
+ optional :step, :int64, 2
12
# The fields below are members of the `what` oneof.
+ oneof :what do
13
+ optional :file_version, :string, 3
14
+ optional :graph_def, :bytes, 4
15
+ optional :summary, :message, 5, "tensorflow.Summary"
16
+ optional :log_message, :message, 6, "tensorflow.LogMessage"
17
+ optional :session_log, :message, 7, "tensorflow.SessionLog"
18
+ optional :tagged_run_metadata, :message, 8, "tensorflow.TaggedRunMetadata"
19
+ optional :meta_graph_def, :bytes, 9
20
+ end
21
+ end
22
+ add_message "tensorflow.LogMessage" do
23
+ optional :level, :enum, 1, "tensorflow.LogMessage.Level"
24
+ optional :message, :string, 2
25
+ end
26
+ add_enum "tensorflow.LogMessage.Level" do
27
+ value :UNKNOWN, 0
28
+ value :DEBUGGING, 10
29
+ value :INFO, 20
30
+ value :WARN, 30
31
+ value :ERROR, 40
32
+ value :FATAL, 50
33
+ end
34
+ add_message "tensorflow.SessionLog" do
35
+ optional :status, :enum, 1, "tensorflow.SessionLog.SessionStatus"
36
+ optional :checkpoint_path, :string, 2
37
+ optional :msg, :string, 3
38
+ end
39
+ add_enum "tensorflow.SessionLog.SessionStatus" do
40
+ value :STATUS_UNSPECIFIED, 0
41
+ value :START, 1
42
+ value :STOP, 2
43
+ value :CHECKPOINT, 3
44
+ end
45
+ add_message "tensorflow.TaggedRunMetadata" do
46
+ optional :tag, :string, 1
47
+ optional :run_metadata, :bytes, 2
48
+ end
49
+ add_message "tensorflow.WatchdogConfig" do
50
+ optional :timeout_ms, :int64, 1
51
+ end
52
+ add_message "tensorflow.RequestedExitCode" do
53
+ optional :exit_code, :int32, 1
54
+ end
55
+ add_message "tensorflow.WorkerHeartbeatRequest" do
56
+ optional :shutdown_mode, :enum, 1, "tensorflow.WorkerShutdownMode"
57
+ optional :watchdog_config, :message, 2, "tensorflow.WatchdogConfig"
58
+ optional :exit_code, :message, 3, "tensorflow.RequestedExitCode"
59
+ end
60
+ add_message "tensorflow.WorkerHeartbeatResponse" do
61
+ optional :health_status, :enum, 1, "tensorflow.WorkerHealth"
62
+ repeated :worker_log, :message, 2, "tensorflow.Event"
63
+ optional :hostname, :string, 3
64
+ end
65
+ add_enum "tensorflow.WorkerHealth" do
66
+ value :OK, 0
67
+ value :RECEIVED_SHUTDOWN_SIGNAL, 1
68
+ value :INTERNAL_ERROR, 2
69
+ value :SHUTTING_DOWN, 3
70
+ end
71
+ add_enum "tensorflow.WorkerShutdownMode" do
72
+ value :DEFAULT, 0
73
+ value :NOT_CONFIGURED, 1
74
+ value :WAIT_FOR_COORDINATOR, 2
75
+ value :SHUTDOWN_AFTER_TIMEOUT, 3
76
+ end
77
+ end
78
+ end
79
+
80
# Binds Ruby constants under Tensorflow to the descriptors looked up by their
# proto names in the generated pool: `msgclass` for messages, `enummodule` for
# enums. Generated code - regenerate instead of editing.
+ module Tensorflow
81
+ Event = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.Event").msgclass
82
+ LogMessage = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.LogMessage").msgclass
83
+ LogMessage::Level = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.LogMessage.Level").enummodule
84
+ SessionLog = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.SessionLog").msgclass
85
+ SessionLog::SessionStatus = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.SessionLog.SessionStatus").enummodule
86
+ TaggedRunMetadata = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.TaggedRunMetadata").msgclass
87
+ WatchdogConfig = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WatchdogConfig").msgclass
88
+ RequestedExitCode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.RequestedExitCode").msgclass
89
+ WorkerHeartbeatRequest = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHeartbeatRequest").msgclass
90
+ WorkerHeartbeatResponse = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHeartbeatResponse").msgclass
91
+ WorkerHealth = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerHealth").enummodule
92
+ WorkerShutdownMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("tensorflow.WorkerShutdownMode").enummodule
93
+ end
@@ -0,0 +1,33 @@
1
+ # Based on code from https://github.com/jedld/tensor_stream
2
+
3
module Tensorflow
  module Train
    # Optimizer whose per-variable update is the ResourceApplyGradientDescent
    # raw op, driven by a single learning rate.
    class GradientDescentOptimizer < Optimizer
      attr_accessor :learning_rate

      # learning_rate - a value, or a Proc that is called in #prepare to
      #                 obtain the value.
      # use_locking   - forwarded to Optimizer#initialize.
      # name          - optimizer name, forwarded to Optimizer#initialize.
      def initialize(learning_rate, use_locking: false, name: "GradientDescent")
        @learning_rate = learning_rate
        @learning_rate_tensor = nil
        super(name: name, use_locking: use_locking)
      end

      protected

      # Resolves the configured learning rate and stores it as a constant
      # named "learning_rate" for use by #apply_dense.
      def prepare
        rate = call_if_callable(@learning_rate)
        @learning_rate_tensor = Tensorflow.constant(rate, name: "learning_rate")
      end

      # Builds the update for one (gradient, variable) pair. The learning rate
      # is cast only when its first output type differs from the gradient's.
      def apply_dense(grad, var)
        grad_dtype = grad.output_types.first
        rate = @learning_rate_tensor
        rate = Tensorflow.cast(rate, grad_dtype) unless rate.output_types.first == grad_dtype

        RawOps.resource_apply_gradient_descent(var, rate, grad)
      end
    end
  end
end
@@ -0,0 +1,158 @@
1
+ # Based on code from https://github.com/jedld/tensor_stream
2
module Tensorflow
  module Train
    # Base class for optimizers. Subclasses override the protected hooks
    # #prepare and #apply_dense (and optionally #create_slots / #finish).
    class Optimizer
      attr_reader :name

      # name        - required optimizer name.
      # use_locking - stored in @use_locking; not read by this class.
      #
      # Raises Error::InvalidArgumentError when name is nil or false.
      def initialize(name: nil, use_locking: false)
        raise(Error::InvalidArgumentError, "Must specify the optimizer name") unless name

        @name = name
        @use_locking = use_locking
        @slots = {}
        @non_slots = {}
      end

      # The current execution context, used as the graph for collections and
      # gradient computation.
      def graph
        ExecutionContext.current
      end

      # Computes the gradients of +loss+ and applies them.
      #
      # Raises Error::InvalidArgumentError when no gradient/variable pairs
      # were produced.
      def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
        grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)
        if grads_and_vars.empty?
          raise(Error::InvalidArgumentError, "No gradients provided for any variable, check your graph for ops that do not support gradients")
        end

        apply_gradients(grads_and_vars, global_step: global_step, name: name)
      end

      # Applies each [gradient, variable] pair via #apply_dense and groups the
      # resulting ops with #finish.
      #
      # When +global_step+ is given, the grouped update becomes a control
      # dependency of an increment-by-one of that variable, and the result of
      # the control_dependencies call is returned instead.
      def apply_gradients(grads_and_vars, global_step: nil, name: nil)
        # TODO: slot creation (#create_slots) and the name scopes used by the
        # tensor_stream original are not wired in here yet.
        prepare
        apply_ops = grads_and_vars.map { |grad, var| apply_dense(grad, var) }

        return finish(apply_ops, name) if global_step.nil?

        global_step.handle.graph.control_dependencies([finish(apply_ops, "update")]) do
          global_step.assign_add(Tensorflow.constant(1, dtype: global_step.dtype))
        end
      end

      # Returns an Array of [gradient, variable] pairs for +loss+.
      #
      # var_list  - variables to differentiate against; defaults to the
      #             graph's TRAINABLE_VARIABLES collection.
      # grad_loss - forwarded to Gradients#gradients as +grad_ys+.
      #
      # Raises Error::InvalidArgumentError when there are no variables.
      def compute_gradients(loss, var_list: nil, grad_loss: nil)
        trainable_vars = var_list || graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)

        if trainable_vars.nil? || trainable_vars.empty?
          raise(Error::InvalidArgumentError, 'There are no variables to train for the loss function')
        end

        grads = Graph::Gradients.new(graph).gradients(loss, trainable_vars, grad_ys: grad_loss)
        grads.zip(trainable_vars)
      end

      # Returns the slot stored for +var+ under slot name +name+, or nil when
      # either the slot name or the variable is unknown.
      def get_slot(var, name)
        named_slots = @slots[name]
        named_slots&.fetch(var_key(var), nil)
      end

      # Sorted list of the slot names created so far.
      def get_slot_names
        @slots.keys.sort
      end

      protected

      # Groups the per-variable update ops into one op named +name_scope+.
      def finish(update_ops, name_scope)
        Control.group(update_ops, name: name_scope)
      end

      # Hook for subclasses that need per-variable slots.
      def create_slots(var_list)
        # no implementation
      end

      # Hook run once per #apply_gradients call, before any #apply_dense.
      def prepare
        # no implementation
      end

      # Hook that builds the update op for one gradient/variable pair.
      def apply_dense(_grad, _var)
        raise(Error::UnimplementedError, "Not implemented")
      end

      ##
      # Find or create a slot initialized with 0.0.
      #
      # Args:
      #   var: Variable - A Variable object
      #   slot_name: string - Name for the slot
      #   op_name: string - Name to use when scoping the Variable that needs to be created
      #
      # NOTE(review): create_zeros_slot is not defined in this class - confirm
      # that it is provided elsewhere before relying on this method.
      def zeros_slot(var, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        named_slots.fetch(key) { named_slots[key] = create_zeros_slot(var, op_name) }
      end

      ##
      # Returns a dict for caching slots created under the given name.
      #
      # Args:
      #   slot_name string Name for the slot
      #
      # Returns: A dict that maps primary 'Variable' objects to the slot created
      def slot_dict(slot_name)
        @slots[slot_name] ||= {}
      end

      # Key identifying a variable in the slot caches: its op's graph and name.
      def var_key(var)
        [var.op.graph, var.op.name]
      end

      # Returns the non-slot variable cached under [name, graph], or nil.
      def get_non_slot_variable(name, graph: nil)
        @non_slots.fetch([name, graph], nil)
      end

      # Calls +param+ when it is a Proc; otherwise returns it unchanged.
      def call_if_callable(param)
        param.is_a?(Proc) ? param.call : param
      end

      # Finds or creates a non-trainable variable cached per name and per the
      # graph of +colocate_with+.
      #
      # The tensor_stream original called TensorStream.variable here; that
      # constant is not referenced anywhere else in this code, so this gem's
      # own Variable class is used instead.
      def create_non_slot_variable(initial_value, name, colocate_with)
        key = [name, colocate_with.graph]
        existing = @non_slots.fetch(key, nil)
        return existing unless existing.nil?

        @non_slots[key] = Tensorflow::Variable.new(initial_value, name: name, trainable: false)
      end

      ##
      # Find or create a slot for a variable, using an Initializer.
      #
      # NOTE(review): create_slot_with_initializer is not defined in this
      # class - confirm that it is provided elsewhere.
      def get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        named_slots.fetch(key) do
          named_slots[key] = create_slot_with_initializer(var, initializer, shape, dtype, op_name)
        end
      end
    end
  end
end
@@ -0,0 +1,4 @@
1
# Empty class: it defines no methods or state of its own here.
+ module Tensorflow
2
+ class TypeSpec
3
+ end
4
+ end
@@ -0,0 +1,127 @@
1
module Tensorflow
  # A resource variable backed by a VarHandleOp. Construction registers the
  # variable in the current execution context's collections and, when an
  # initial value is given, records the assignment op as #initializer.
  class Variable
    include Operators

    # handle      - result of RawOps.var_handle_op for this variable.
    # dtype       - element type resolved at construction.
    # name        - unique name obtained from the current execution context.
    # initializer - op produced by the most recent #value= (nil if never assigned).
    attr_reader :handle, :dtype, :name, :initializer

    # initial_value - nil, a Graph::Operation, a Tensor, or any value accepted
    #                 by Tensor.from_value.
    # dtype         - used when there is no initial value, as an override for
    #                 a Graph::Operation, and as a hint to Tensor.from_value;
    #                 ignored when initial_value is already a Tensor.
    # shape         - used when initial_value is nil, a Graph::Operation or a
    #                 Tensor; otherwise the converted tensor's shape wins.
    # shared_name   - defaults to the generated unique name.
    # name          - base for the unique name.
    # trainable     - also register in the TRAINABLE_VARIABLES collection.
    def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
      initial_value = case initial_value
                      when nil
                        @dtype = dtype
                        # Keep a caller-supplied shape; previously it was
                        # always overwritten with [] in this branch.
                        shape ||= []
                        nil
                      when Graph::Operation
                        @dtype = dtype || initial_value.dtype
                        shape ||= initial_value.output_shapes.first
                        initial_value
                      when Tensor
                        @dtype = initial_value.dtype
                        shape ||= initial_value.shape
                        initial_value
                      else
                        tensor = Tensor.from_value(initial_value, dtype: dtype)
                        @dtype = tensor.dtype
                        shape = tensor.shape
                        tensor
                      end

      name = name&.to_s
      shared_name = shared_name&.to_s
      unique_name = ExecutionContext.current.unique_name(name || shared_name)
      shared_name ||= unique_name
      @name = unique_name

      collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
      collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
      ExecutionContext.current.add_to_collections(collections, self)

      @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
      self.value = initial_value if initial_value
    end

    # Memoized ReadVariableOp result for this variable. Reset by #value=,
    # #assign_add and #assign_sub.
    def value_handle
      @value_handle ||= RawOps.read_variable_op(handle, dtype: @dtype)
    end

    # The handle's value when the read is an Eager::TensorHandle, or the read
    # operation itself when it is a Graph::Operation.
    def value
      case value_handle
      when Eager::TensorHandle
        value_handle.value
      when Graph::Operation
        value_handle
      end
    end

    # Assigns +value+ and remembers the assignment op as #initializer.
    def value=(value)
      # Drop the memoized read so a later #value does not reuse a read taken
      # before this assignment (same as #assign_add / #assign_sub).
      @value_handle = nil
      @initializer = RawOps.assign_variable_op(handle, value, dtype: @dtype)
    end

    def initialized?
      RawOps.var_is_initialized_op(handle)
    end

    # These methods match the operation api to enable gradients and sessions
    def consumers
      handle.consumers
    end

    # This enables executing variables to get the values in a session
    def outputs
      [Graph::OperationOutput.from_index(value_handle, 0)]
    end

    def to_ptr
      handle.to_ptr
    end

    def shape
      value_handle.shape
    end

    # Raises Error::UnavailableError in graph mode.
    def tensor
      raise(Error::UnavailableError, "Only supported in eager execution mode") if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE

      value_handle.tensor
    end

    def rank
      shape.size
    end

    def reshape(shape)
      RawOps.reshape(self, shape)
    end

    # Adds +value+ to the variable. The value is converted with
    # Tensor.from_value (using +dtype+ as a hint) and cast to the variable's
    # dtype before being handed to the raw op.
    def assign_add(value, dtype: nil)
      @value_handle = nil
      delta = cast_to_variable_dtype(value, dtype)
      RawOps.assign_add_variable_op(handle, delta, dtype: delta.dtype)
    end

    # Subtracts +value+ from the variable, converting and casting it to the
    # variable's dtype first.
    def assign_sub(value)
      @value_handle = nil
      delta = cast_to_variable_dtype(value, self.dtype)
      RawOps.assign_sub_variable_op(handle, delta, dtype: delta.dtype)
    end

    def to_s
      inspect
    end

    def inspect
      inspection = []
      inspection << "name: #{handle.name}" if handle.respond_to?(:name)
      inspection << "shape: #{value_handle.shape}"
      inspection << "dtype: #{value_handle.dtype}"
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    private

    # Converts +value+ with Tensor.from_value and casts the result to this
    # variable's dtype. Previously the converted tensor was computed and then
    # discarded, so the raw, uncast value reached the raw op.
    def cast_to_variable_dtype(value, dtype)
      converted = Tensor.from_value(value, dtype: dtype)
      Tensorflow.cast(converted, self.dtype)
    end
  end
end