tensorbored 2.21.0rc1769983804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (271)
  1. tensorbored/__init__.py +112 -0
  2. tensorbored/_vendor/__init__.py +0 -0
  3. tensorbored/_vendor/bleach/__init__.py +125 -0
  4. tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
  5. tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
  6. tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
  7. tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
  8. tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
  9. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
  10. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
  11. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
  12. tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
  13. tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
  14. tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
  15. tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
  16. tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
  17. tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
  18. tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
  19. tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
  20. tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
  21. tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
  22. tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
  23. tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
  24. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
  25. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
  26. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
  27. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
  28. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
  29. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
  30. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
  31. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
  32. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
  33. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
  34. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
  35. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
  36. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
  37. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
  38. tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
  39. tensorbored/_vendor/bleach/callbacks.py +32 -0
  40. tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
  41. tensorbored/_vendor/bleach/linkifier.py +633 -0
  42. tensorbored/_vendor/bleach/parse_shim.py +1 -0
  43. tensorbored/_vendor/bleach/sanitizer.py +638 -0
  44. tensorbored/_vendor/bleach/six_shim.py +19 -0
  45. tensorbored/_vendor/webencodings/__init__.py +342 -0
  46. tensorbored/_vendor/webencodings/labels.py +231 -0
  47. tensorbored/_vendor/webencodings/mklabels.py +59 -0
  48. tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
  49. tensorbored/assets.py +36 -0
  50. tensorbored/auth.py +102 -0
  51. tensorbored/backend/__init__.py +0 -0
  52. tensorbored/backend/application.py +604 -0
  53. tensorbored/backend/auth_context_middleware.py +38 -0
  54. tensorbored/backend/client_feature_flags.py +113 -0
  55. tensorbored/backend/empty_path_redirect.py +46 -0
  56. tensorbored/backend/event_processing/__init__.py +0 -0
  57. tensorbored/backend/event_processing/data_ingester.py +276 -0
  58. tensorbored/backend/event_processing/data_provider.py +535 -0
  59. tensorbored/backend/event_processing/directory_loader.py +142 -0
  60. tensorbored/backend/event_processing/directory_watcher.py +272 -0
  61. tensorbored/backend/event_processing/event_accumulator.py +950 -0
  62. tensorbored/backend/event_processing/event_file_inspector.py +463 -0
  63. tensorbored/backend/event_processing/event_file_loader.py +292 -0
  64. tensorbored/backend/event_processing/event_multiplexer.py +521 -0
  65. tensorbored/backend/event_processing/event_util.py +68 -0
  66. tensorbored/backend/event_processing/io_wrapper.py +223 -0
  67. tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
  68. tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
  69. tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
  70. tensorbored/backend/event_processing/reservoir.py +266 -0
  71. tensorbored/backend/event_processing/tag_types.py +29 -0
  72. tensorbored/backend/experiment_id.py +71 -0
  73. tensorbored/backend/experimental_plugin.py +51 -0
  74. tensorbored/backend/http_util.py +263 -0
  75. tensorbored/backend/json_util.py +70 -0
  76. tensorbored/backend/path_prefix.py +67 -0
  77. tensorbored/backend/process_graph.py +74 -0
  78. tensorbored/backend/security_validator.py +202 -0
  79. tensorbored/compat/__init__.py +69 -0
  80. tensorbored/compat/proto/__init__.py +0 -0
  81. tensorbored/compat/proto/allocation_description_pb2.py +35 -0
  82. tensorbored/compat/proto/api_def_pb2.py +82 -0
  83. tensorbored/compat/proto/attr_value_pb2.py +80 -0
  84. tensorbored/compat/proto/cluster_pb2.py +58 -0
  85. tensorbored/compat/proto/config_pb2.py +271 -0
  86. tensorbored/compat/proto/coordination_config_pb2.py +45 -0
  87. tensorbored/compat/proto/cost_graph_pb2.py +87 -0
  88. tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
  89. tensorbored/compat/proto/debug_pb2.py +65 -0
  90. tensorbored/compat/proto/event_pb2.py +149 -0
  91. tensorbored/compat/proto/full_type_pb2.py +74 -0
  92. tensorbored/compat/proto/function_pb2.py +157 -0
  93. tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
  94. tensorbored/compat/proto/graph_pb2.py +41 -0
  95. tensorbored/compat/proto/histogram_pb2.py +39 -0
  96. tensorbored/compat/proto/meta_graph_pb2.py +254 -0
  97. tensorbored/compat/proto/node_def_pb2.py +61 -0
  98. tensorbored/compat/proto/op_def_pb2.py +81 -0
  99. tensorbored/compat/proto/resource_handle_pb2.py +48 -0
  100. tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
  101. tensorbored/compat/proto/rpc_options_pb2.py +35 -0
  102. tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
  103. tensorbored/compat/proto/saver_pb2.py +38 -0
  104. tensorbored/compat/proto/step_stats_pb2.py +116 -0
  105. tensorbored/compat/proto/struct_pb2.py +144 -0
  106. tensorbored/compat/proto/summary_pb2.py +111 -0
  107. tensorbored/compat/proto/tensor_description_pb2.py +38 -0
  108. tensorbored/compat/proto/tensor_pb2.py +68 -0
  109. tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
  110. tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
  111. tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
  112. tensorbored/compat/proto/types_pb2.py +105 -0
  113. tensorbored/compat/proto/variable_pb2.py +62 -0
  114. tensorbored/compat/proto/verifier_config_pb2.py +38 -0
  115. tensorbored/compat/proto/versions_pb2.py +35 -0
  116. tensorbored/compat/tensorflow_stub/__init__.py +38 -0
  117. tensorbored/compat/tensorflow_stub/app.py +124 -0
  118. tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
  119. tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
  120. tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
  121. tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
  122. tensorbored/compat/tensorflow_stub/errors.py +507 -0
  123. tensorbored/compat/tensorflow_stub/flags.py +124 -0
  124. tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
  125. tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
  126. tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
  127. tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
  128. tensorbored/context.py +129 -0
  129. tensorbored/data/__init__.py +0 -0
  130. tensorbored/data/grpc_provider.py +365 -0
  131. tensorbored/data/ingester.py +46 -0
  132. tensorbored/data/proto/__init__.py +0 -0
  133. tensorbored/data/proto/data_provider_pb2.py +517 -0
  134. tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
  135. tensorbored/data/provider.py +1365 -0
  136. tensorbored/data/server_ingester.py +301 -0
  137. tensorbored/data_compat.py +159 -0
  138. tensorbored/dataclass_compat.py +224 -0
  139. tensorbored/default.py +124 -0
  140. tensorbored/errors.py +130 -0
  141. tensorbored/lazy.py +99 -0
  142. tensorbored/main.py +48 -0
  143. tensorbored/main_lib.py +62 -0
  144. tensorbored/manager.py +487 -0
  145. tensorbored/notebook.py +441 -0
  146. tensorbored/plugin_util.py +266 -0
  147. tensorbored/plugins/__init__.py +0 -0
  148. tensorbored/plugins/audio/__init__.py +0 -0
  149. tensorbored/plugins/audio/audio_plugin.py +229 -0
  150. tensorbored/plugins/audio/metadata.py +69 -0
  151. tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
  152. tensorbored/plugins/audio/summary.py +230 -0
  153. tensorbored/plugins/audio/summary_v2.py +124 -0
  154. tensorbored/plugins/base_plugin.py +367 -0
  155. tensorbored/plugins/core/__init__.py +0 -0
  156. tensorbored/plugins/core/core_plugin.py +981 -0
  157. tensorbored/plugins/custom_scalar/__init__.py +0 -0
  158. tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
  159. tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
  160. tensorbored/plugins/custom_scalar/metadata.py +35 -0
  161. tensorbored/plugins/custom_scalar/summary.py +79 -0
  162. tensorbored/plugins/debugger_v2/__init__.py +0 -0
  163. tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
  164. tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
  165. tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
  166. tensorbored/plugins/distribution/__init__.py +0 -0
  167. tensorbored/plugins/distribution/compressor.py +158 -0
  168. tensorbored/plugins/distribution/distributions_plugin.py +116 -0
  169. tensorbored/plugins/distribution/metadata.py +19 -0
  170. tensorbored/plugins/graph/__init__.py +0 -0
  171. tensorbored/plugins/graph/graph_util.py +129 -0
  172. tensorbored/plugins/graph/graphs_plugin.py +336 -0
  173. tensorbored/plugins/graph/keras_util.py +328 -0
  174. tensorbored/plugins/graph/metadata.py +42 -0
  175. tensorbored/plugins/histogram/__init__.py +0 -0
  176. tensorbored/plugins/histogram/histograms_plugin.py +144 -0
  177. tensorbored/plugins/histogram/metadata.py +63 -0
  178. tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
  179. tensorbored/plugins/histogram/summary.py +234 -0
  180. tensorbored/plugins/histogram/summary_v2.py +292 -0
  181. tensorbored/plugins/hparams/__init__.py +14 -0
  182. tensorbored/plugins/hparams/_keras.py +93 -0
  183. tensorbored/plugins/hparams/api.py +130 -0
  184. tensorbored/plugins/hparams/api_pb2.py +208 -0
  185. tensorbored/plugins/hparams/backend_context.py +606 -0
  186. tensorbored/plugins/hparams/download_data.py +158 -0
  187. tensorbored/plugins/hparams/error.py +26 -0
  188. tensorbored/plugins/hparams/get_experiment.py +71 -0
  189. tensorbored/plugins/hparams/hparams_plugin.py +206 -0
  190. tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
  191. tensorbored/plugins/hparams/json_format_compat.py +38 -0
  192. tensorbored/plugins/hparams/list_metric_evals.py +57 -0
  193. tensorbored/plugins/hparams/list_session_groups.py +1040 -0
  194. tensorbored/plugins/hparams/metadata.py +125 -0
  195. tensorbored/plugins/hparams/metrics.py +41 -0
  196. tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
  197. tensorbored/plugins/hparams/summary.py +205 -0
  198. tensorbored/plugins/hparams/summary_v2.py +597 -0
  199. tensorbored/plugins/image/__init__.py +0 -0
  200. tensorbored/plugins/image/images_plugin.py +232 -0
  201. tensorbored/plugins/image/metadata.py +65 -0
  202. tensorbored/plugins/image/plugin_data_pb2.py +34 -0
  203. tensorbored/plugins/image/summary.py +159 -0
  204. tensorbored/plugins/image/summary_v2.py +130 -0
  205. tensorbored/plugins/mesh/__init__.py +14 -0
  206. tensorbored/plugins/mesh/mesh_plugin.py +292 -0
  207. tensorbored/plugins/mesh/metadata.py +152 -0
  208. tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
  209. tensorbored/plugins/mesh/summary.py +251 -0
  210. tensorbored/plugins/mesh/summary_v2.py +214 -0
  211. tensorbored/plugins/metrics/__init__.py +0 -0
  212. tensorbored/plugins/metrics/metadata.py +17 -0
  213. tensorbored/plugins/metrics/metrics_plugin.py +623 -0
  214. tensorbored/plugins/pr_curve/__init__.py +0 -0
  215. tensorbored/plugins/pr_curve/metadata.py +75 -0
  216. tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
  217. tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
  218. tensorbored/plugins/pr_curve/summary.py +574 -0
  219. tensorbored/plugins/profile_redirect/__init__.py +0 -0
  220. tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
  221. tensorbored/plugins/projector/__init__.py +67 -0
  222. tensorbored/plugins/projector/metadata.py +26 -0
  223. tensorbored/plugins/projector/projector_config_pb2.py +54 -0
  224. tensorbored/plugins/projector/projector_plugin.py +795 -0
  225. tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
  226. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
  227. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
  228. tensorbored/plugins/scalar/__init__.py +0 -0
  229. tensorbored/plugins/scalar/metadata.py +60 -0
  230. tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
  231. tensorbored/plugins/scalar/scalars_plugin.py +181 -0
  232. tensorbored/plugins/scalar/summary.py +109 -0
  233. tensorbored/plugins/scalar/summary_v2.py +124 -0
  234. tensorbored/plugins/text/__init__.py +0 -0
  235. tensorbored/plugins/text/metadata.py +62 -0
  236. tensorbored/plugins/text/plugin_data_pb2.py +34 -0
  237. tensorbored/plugins/text/summary.py +114 -0
  238. tensorbored/plugins/text/summary_v2.py +124 -0
  239. tensorbored/plugins/text/text_plugin.py +288 -0
  240. tensorbored/plugins/wit_redirect/__init__.py +0 -0
  241. tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
  242. tensorbored/program.py +910 -0
  243. tensorbored/summary/__init__.py +35 -0
  244. tensorbored/summary/_output.py +124 -0
  245. tensorbored/summary/_tf/__init__.py +14 -0
  246. tensorbored/summary/_tf/summary/__init__.py +178 -0
  247. tensorbored/summary/_writer.py +105 -0
  248. tensorbored/summary/v1.py +51 -0
  249. tensorbored/summary/v2.py +25 -0
  250. tensorbored/summary/writer/__init__.py +13 -0
  251. tensorbored/summary/writer/event_file_writer.py +291 -0
  252. tensorbored/summary/writer/record_writer.py +50 -0
  253. tensorbored/util/__init__.py +0 -0
  254. tensorbored/util/encoder.py +116 -0
  255. tensorbored/util/grpc_util.py +311 -0
  256. tensorbored/util/img_mime_type_detector.py +40 -0
  257. tensorbored/util/io_util.py +20 -0
  258. tensorbored/util/lazy_tensor_creator.py +110 -0
  259. tensorbored/util/op_evaluator.py +104 -0
  260. tensorbored/util/platform_util.py +20 -0
  261. tensorbored/util/tb_logging.py +24 -0
  262. tensorbored/util/tensor_util.py +617 -0
  263. tensorbored/util/timing.py +122 -0
  264. tensorbored/version.py +21 -0
  265. tensorbored/webfiles.zip +0 -0
  266. tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
  267. tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
  268. tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
  269. tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
  270. tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
  271. tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1035 @@
1
+ # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Helper classes for tensor shape inference."""
16
+
17
+ # pytype: skip-file
18
+
19
+ from . import compat, dtypes
20
+ from tensorbored.compat.proto import tensor_shape_pb2
21
+
22
+
23
+ # @tf_export("Dimension")
24
class Dimension:
    """Represents the value of one dimension in a TensorShape.

    A Dimension is either known (a non-negative integer) or unknown
    (``None``). Arithmetic involving an unknown Dimension produces an
    unknown Dimension, and comparisons involving an unknown Dimension
    return ``None`` instead of a bool.
    """

    def __init__(self, value):
        """Creates a new Dimension with the given value.

        Args:
          value: ``None`` for an unknown dimension, or a value accepted by
            ``int()``. Non-string values must convert to ``int`` without
            loss (``2.0`` is accepted, ``2.5`` is not).

        Raises:
          TypeError: If `value` is a `dtypes.DType`, or ``int(value)``
            raises TypeError.
          ValueError: If ``int(value)`` raises ValueError, if the conversion
            is lossy, or if the resulting integer is negative.
        """
        if value is None:
            self._value = None
        elif isinstance(value, dtypes.DType):
            raise TypeError("Cannot convert %s to Dimension" % value)
        else:
            self._value = int(value)
            # Reject lossy conversions such as 2.5 -> 2. Strings and bytes
            # are exempt because "3" never compares equal to 3.
            if (
                not isinstance(value, compat.bytes_or_text_types)
                and self._value != value
            ):
                raise ValueError("Ambiguous dimension: %s" % value)
            if self._value < 0:
                raise ValueError("Dimension %d must be >= 0" % self._value)

    def __repr__(self):
        return "Dimension(%s)" % repr(self._value)

    def __str__(self):
        # Unknown dimensions print as "?".
        value = self._value
        return "?" if value is None else str(value)

    def __eq__(self, other):
        """Returns true if `other` has the same known value as this
        Dimension.

        Returns:
          ``None`` if either Dimension is unknown, ``NotImplemented`` if
          `other` cannot be converted to a Dimension, otherwise a bool.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value == other.value

    def __ne__(self, other):
        """Returns true if `other` has a different known value from `self`.

        Defined explicitly because Python's default ``__ne__`` would negate
        the ``None`` that ``__eq__`` returns for unknown dimensions.

        Returns:
          ``None`` if either Dimension is unknown, ``NotImplemented`` if
          `other` cannot be converted to a Dimension, otherwise a bool.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value != other.value

    def __int__(self):
        # Returns None for an unknown dimension, which makes int() raise
        # TypeError.
        return self._value

    # This is needed for Windows.
    # See https://github.com/tensorflow/tensorflow/pull/9780
    def __long__(self):
        return self._value

    def __index__(self):
        # Allow use in Python 3 range
        return self._value

    @property
    def value(self):
        """The value of this dimension, or None if it is unknown."""
        return self._value

    def is_convertible_with(self, other):
        """Returns true if `other` is convertible with this Dimension.

        Two known Dimensions are convertible if they have the same value.
        An unknown Dimension is convertible with all other Dimensions.

        Args:
          other: Another Dimension.

        Returns:
          True if this Dimension and `other` are convertible.
        """
        other = as_dimension(other)
        return (
            self._value is None
            or other.value is None
            or self._value == other.value
        )

    def assert_is_convertible_with(self, other):
        """Raises an exception if `other` is not convertible with this
        Dimension.

        Args:
          other: Another Dimension.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            is_convertible_with).
        """
        if not self.is_convertible_with(other):
            raise ValueError(
                "Dimensions %s and %s are not convertible" % (self, other)
            )

    def merge_with(self, other):
        """Returns a Dimension that combines the information in `self` and
        `other`.

        Dimensions are combined as follows:

        ```python
        tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
        tf.Dimension(n)   .merge_with(tf.Dimension(m))  # raises ValueError for n != m
        ```

        Args:
          other: Another Dimension.

        Returns:
          A Dimension containing the combined information of `self` and
          `other`.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            is_convertible_with).
        """
        other = as_dimension(other)
        self.assert_is_convertible_with(other)
        if self._value is None:
            return Dimension(other.value)
        else:
            return Dimension(self._value)

    def __add__(self, other):
        """Returns the sum of `self` and `other`.

        Dimensions are summed as follows:

        ```python
        tf.Dimension(m)    + tf.Dimension(n)    == tf.Dimension(m + n)
        tf.Dimension(m)    + tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            # Give `other.__radd__` a chance, as __mul__ and friends do.
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value + other.value)

    def __radd__(self, other):
        """Returns the sum of `other` and `self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        # Call __add__ directly (rather than `self + other`) so that
        # NotImplemented propagates back to the interpreter instead of
        # re-entering operator dispatch.
        return self.__add__(other)

    def __sub__(self, other):
        """Returns the subtraction of `other` from `self`.

        Dimensions are subtracted as follows:

        ```python
        tf.Dimension(m)    - tf.Dimension(n)    == tf.Dimension(m - n)
        tf.Dimension(m)    - tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the subtraction of `other` from `self`,
          or ``NotImplemented`` if `other` cannot be converted to a
          Dimension.

        Raises:
          ValueError: If the difference of two known values is negative.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value - other.value)

    def __rsub__(self, other):
        """Returns the subtraction of `self` from `other`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the subtraction of `self` from `other`,
          or ``NotImplemented`` if `other` cannot be converted to a
          Dimension.

        Raises:
          ValueError: If the difference of two known values is negative.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(other.value - self._value)

    def __mul__(self, other):
        """Returns the product of `self` and `other`.

        Dimensions are multiplied as follows:

        ```python
        tf.Dimension(m)    * tf.Dimension(n)    == tf.Dimension(m * n)
        tf.Dimension(m)    * tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented

        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value * other.value)

    def __rmul__(self, other):
        """Returns the product of `self` and `other`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        # Call __mul__ directly so NotImplemented reaches the interpreter.
        return self.__mul__(other)

    def __floordiv__(self, other):
        """Returns the quotient of `self` and `other` rounded down.

        Dimensions are divided as follows:

        ```python
        tf.Dimension(m)    // tf.Dimension(n)    == tf.Dimension(m // n)
        tf.Dimension(m)    // tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `self` and
          `other`, or ``NotImplemented`` if `other` cannot be converted to a
          Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value // other.value)

    def __rfloordiv__(self, other):
        """Returns the quotient of `other` and `self` rounded down.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `other` and
          `self`, or ``NotImplemented`` if `other` cannot be converted to a
          Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(other.value // self._value)

    def __div__(self, other):
        """DEPRECATED: Use `__floordiv__` via `x // y` instead.

        This function exists only for backwards compatibility purposes; new
        code should use `__floordiv__` via the syntax `x // y`. Using `x // y`
        communicates clearly that the result rounds down, and is forward
        compatible to Python 3.

        Note that Python 3 never calls `__div__`: the `/` operator dispatches
        to `__truediv__`, which this class does not define.

        Args:
          other: Another `Dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `self` and
          `other`.
        """
        return self // other

    def __mod__(self, other):
        """Returns `self` modulo `other`.

        Dimension moduli are computed as follows:

        ```python
        tf.Dimension(m)    % tf.Dimension(n)    == tf.Dimension(m % n)
        tf.Dimension(m)    % tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `self` modulo `other`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value % other.value)

    def __rmod__(self, other):
        """Returns `other` modulo `self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `other` modulo `self`, or
          ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        # `other` is now a Dimension, so this dispatches to __mod__.
        return other % self

    def __lt__(self, other):
        """Returns True if `self` is known to be less than `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
        (tf.Dimension(m)    < tf.Dimension(None)) == None
        (tf.Dimension(None) < tf.Dimension(n))    == None
        (tf.Dimension(None) < tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value < other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value < other.value

    def __le__(self, other):
        """Returns True if `self` is known to be less than or equal to `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
        (tf.Dimension(m)    <= tf.Dimension(None)) == None
        (tf.Dimension(None) <= tf.Dimension(n))    == None
        (tf.Dimension(None) <= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value <= other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value <= other.value

    def __gt__(self, other):
        """Returns True if `self` is known to be greater than `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
        (tf.Dimension(m)    > tf.Dimension(None)) == None
        (tf.Dimension(None) > tf.Dimension(n))    == None
        (tf.Dimension(None) > tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value > other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value > other.value

    def __ge__(self, other):
        """Returns True if `self` is known to be greater than or equal to
        `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
        (tf.Dimension(m)    >= tf.Dimension(None)) == None
        (tf.Dimension(None) >= tf.Dimension(n))    == None
        (tf.Dimension(None) >= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value >= other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value >= other.value

    def __reduce__(self):
        # Pickle/copy support: rebuild from the raw value via the constructor.
        return Dimension, (self._value,)
+
472
+
473
def as_dimension(value):
    """Coerces `value` into a Dimension.

    An existing Dimension is passed through as the very same object. Any
    other input, including ``None`` (an unknown dimension) and integers, is
    handed to the Dimension constructor, which performs the validation.

    Args:
      value: A Dimension, ``None``, or any value accepted by ``Dimension()``.

    Returns:
      A Dimension corresponding to `value`.
    """
    return value if isinstance(value, Dimension) else Dimension(value)
+
491
+
492
+ # @tf_export("TensorShape")
493
+ class TensorShape:
494
+ """Represents the shape of a `Tensor`.
495
+
496
+ A `TensorShape` represents a possibly-partial shape specification for a
497
+ `Tensor`. It may be one of the following:
498
+
499
+ * *Fully-known shape:* has a known number of dimensions and a known size
500
+ for each dimension. e.g. `TensorShape([16, 256])`
501
+ * *Partially-known shape:* has a known number of dimensions, and an unknown
502
+ size for one or more dimension. e.g. `TensorShape([None, 256])`
503
+ * *Unknown shape:* has an unknown number of dimensions, and an unknown
504
+ size in all dimensions. e.g. `TensorShape(None)`
505
+
506
+ If a tensor is produced by an operation of type `"Foo"`, its shape
507
+ may be inferred if there is a registered shape function for
508
+ `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
509
+ for details of shape functions and how to register them. Alternatively,
510
+ the shape may be set explicitly using @{tf.Tensor.set_shape}.
511
+ """
512
+
513
+ def __init__(self, dims):
514
+ """Creates a new TensorShape with the given dimensions.
515
+
516
+ Args:
517
+ dims: A list of Dimensions, or None if the shape is unspecified.
518
+ DEPRECATED: A single integer is treated as a singleton list.
519
+
520
+ Raises:
521
+ TypeError: If dims cannot be converted to a list of dimensions.
522
+ """
523
+ # TODO(irving): Eliminate the single integer special case.
524
+ if dims is None:
525
+ self._dims = None
526
+ elif isinstance(dims, compat.bytes_or_text_types):
527
+ raise TypeError(
528
+ "A string has ambiguous TensorShape, please wrap in a "
529
+ "list or convert to an int: %s" % dims
530
+ )
531
+ elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
532
+ if dims.unknown_rank:
533
+ self._dims = None
534
+ else:
535
+ self._dims = [
536
+ # Protos store variable-size dimensions as -1
537
+ as_dimension(dim.size if dim.size != -1 else None)
538
+ for dim in dims.dim
539
+ ]
540
+ elif isinstance(dims, TensorShape):
541
+ self._dims = dims.dims
542
+ else:
543
+ try:
544
+ dims_iter = iter(dims)
545
+ except TypeError:
546
+ # Treat as a singleton dimension
547
+ self._dims = [as_dimension(dims)]
548
+ else:
549
+ # Got a list of dimensions
550
+ self._dims = [as_dimension(d) for d in dims_iter]
551
+ self._ndims = None
552
+
553
+ def __repr__(self):
554
+ return "TensorShape(%r)" % self._dims
555
+
556
+ def __str__(self):
557
+ if self.ndims is None:
558
+ return "<unknown>"
559
+ elif self.ndims == 1:
560
+ return "(%s,)" % self._dims[0]
561
+ else:
562
+ return "(%s)" % ", ".join(str(d) for d in self._dims)
563
+
564
+ @property
565
+ def dims(self):
566
+ """Returns a list of Dimensions, or None if the shape is
567
+ unspecified."""
568
+ return self._dims
569
+
570
+ @dims.setter
571
+ def dims(self, dims):
572
+ self._dims = dims
573
+ self._ndims = None
574
+
575
+ @property
576
+ def ndims(self):
577
+ """Returns the rank of this shape, or None if it is unspecified."""
578
+ if self._dims is None:
579
+ return None
580
+ else:
581
+ if self._ndims is None:
582
+ self._ndims = len(self._dims)
583
+ return self._ndims
584
+
585
+ def __len__(self):
586
+ """Returns the rank of this shape, or raises ValueError if
587
+ unspecified."""
588
+ if self._dims is None:
589
+ raise ValueError(
590
+ "Cannot take the length of Shape with unknown rank."
591
+ )
592
+ return self.ndims
593
+
594
+ def __bool__(self):
595
+ """Returns True if this shape contains non-zero information."""
596
+ return self._dims is not None
597
+
598
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
599
+ __nonzero__ = __bool__
600
+
601
+ def __iter__(self):
602
+ """Returns `self.dims` if the rank is known, otherwise raises
603
+ ValueError."""
604
+ if self._dims is None:
605
+ raise ValueError("Cannot iterate over a shape with unknown rank.")
606
+ else:
607
+ return iter(self._dims)
608
+
609
+ def __getitem__(self, key):
610
+ """Returns the value of a dimension or a shape, depending on the key.
611
+
612
+ Args:
613
+ key: If `key` is an integer, returns the dimension at that index;
614
+ otherwise if `key` is a slice, returns a TensorShape whose
615
+ dimensions are those selected by the slice from `self`.
616
+
617
+ Returns:
618
+ A dimension if `key` is an integer, or a `TensorShape` if `key` is a
619
+ slice.
620
+
621
+ Raises:
622
+ ValueError: If `key` is a slice, and any of its elements are negative, or
623
+ if `self` is completely unknown and the step is set.
624
+ """
625
+ if self._dims is not None:
626
+ if isinstance(key, slice):
627
+ return TensorShape(self._dims[key])
628
+ else:
629
+ return self._dims[key]
630
+ else:
631
+ if isinstance(key, slice):
632
+ start = key.start if key.start is not None else 0
633
+ stop = key.stop
634
+
635
+ if key.step is not None:
636
+ # TODO(mrry): Handle these maybe.
637
+ raise ValueError("Steps are not yet handled")
638
+ if stop is None:
639
+ # NOTE(mrry): This implies that TensorShape(None) is convertible with
640
+ # TensorShape(None)[1:], which is obviously not true. It would be
641
+ # possible to track the number of dimensions symbolically,
642
+ # and perhaps we should do that.
643
+ return unknown_shape()
644
+ elif start < 0 or stop < 0:
645
+ # TODO(mrry): Handle this better, as it will be useful for handling
646
+ # suffixes of otherwise unknown shapes.
647
+ return unknown_shape()
648
+ else:
649
+ return unknown_shape(ndims=stop - start)
650
+ else:
651
+ return Dimension(None)
652
+
653
+ def num_elements(self):
654
+ """Returns the total number of elements, or none for incomplete
655
+ shapes."""
656
+ if self.is_fully_defined():
657
+ size = 1
658
+ for dim in self._dims:
659
+ size *= dim.value
660
+ return size
661
+ else:
662
+ return None
663
+
664
+ def merge_with(self, other):
665
+ """Returns a `TensorShape` combining the information in `self` and
666
+ `other`.
667
+
668
+ The dimensions in `self` and `other` are merged elementwise,
669
+ according to the rules defined for `Dimension.merge_with()`.
670
+
671
+ Args:
672
+ other: Another `TensorShape`.
673
+
674
+ Returns:
675
+ A `TensorShape` containing the combined information of `self` and
676
+ `other`.
677
+
678
+ Raises:
679
+ ValueError: If `self` and `other` are not convertible.
680
+ """
681
+ other = as_shape(other)
682
+ if self._dims is None:
683
+ return other
684
+ else:
685
+ try:
686
+ self.assert_same_rank(other)
687
+ new_dims = []
688
+ for i, dim in enumerate(self._dims):
689
+ new_dims.append(dim.merge_with(other[i]))
690
+ return TensorShape(new_dims)
691
+ except ValueError:
692
+ raise ValueError(
693
+ "Shapes %s and %s are not convertible" % (self, other)
694
+ )
695
+
696
+ def concatenate(self, other):
697
+ """Returns the concatenation of the dimension in `self` and `other`.
698
+
699
+ *N.B.* If either `self` or `other` is completely unknown,
700
+ concatenation will discard information about the other shape. In
701
+ future, we might support concatenation that preserves this
702
+ information for use with slicing.
703
+
704
+ Args:
705
+ other: Another `TensorShape`.
706
+
707
+ Returns:
708
+ A `TensorShape` whose dimensions are the concatenation of the
709
+ dimensions in `self` and `other`.
710
+ """
711
+ # TODO(mrry): Handle the case where we concatenate a known shape with a
712
+ # completely unknown shape, so that we can use the partial information.
713
+ other = as_shape(other)
714
+ if self._dims is None or other.dims is None:
715
+ return unknown_shape()
716
+ else:
717
+ return TensorShape(self._dims + other.dims)
718
+
719
+ def assert_same_rank(self, other):
720
+ """Raises an exception if `self` and `other` do not have convertible
721
+ ranks.
722
+
723
+ Args:
724
+ other: Another `TensorShape`.
725
+
726
+ Raises:
727
+ ValueError: If `self` and `other` do not represent shapes with the
728
+ same rank.
729
+ """
730
+ other = as_shape(other)
731
+ if self.ndims is not None and other.ndims is not None:
732
+ if self.ndims != other.ndims:
733
+ raise ValueError(
734
+ "Shapes %s and %s must have the same rank" % (self, other)
735
+ )
736
+
737
+ def assert_has_rank(self, rank):
738
+ """Raises an exception if `self` is not convertible with the given
739
+ `rank`.
740
+
741
+ Args:
742
+ rank: An integer.
743
+
744
+ Raises:
745
+ ValueError: If `self` does not represent a shape with the given `rank`.
746
+ """
747
+ if self.ndims not in (None, rank):
748
+ raise ValueError("Shape %s must have rank %d" % (self, rank))
749
+
750
+ def with_rank(self, rank):
751
+ """Returns a shape based on `self` with the given rank.
752
+
753
+ This method promotes a completely unknown shape to one with a
754
+ known rank.
755
+
756
+ Args:
757
+ rank: An integer.
758
+
759
+ Returns:
760
+ A shape that is at least as specific as `self` with the given rank.
761
+
762
+ Raises:
763
+ ValueError: If `self` does not represent a shape with the given `rank`.
764
+ """
765
+ try:
766
+ return self.merge_with(unknown_shape(ndims=rank))
767
+ except ValueError:
768
+ raise ValueError("Shape %s must have rank %d" % (self, rank))
769
+
770
+ def with_rank_at_least(self, rank):
771
+ """Returns a shape based on `self` with at least the given rank.
772
+
773
+ Args:
774
+ rank: An integer.
775
+
776
+ Returns:
777
+ A shape that is at least as specific as `self` with at least the given
778
+ rank.
779
+
780
+ Raises:
781
+ ValueError: If `self` does not represent a shape with at least the given
782
+ `rank`.
783
+ """
784
+ if self.ndims is not None and self.ndims < rank:
785
+ raise ValueError(
786
+ "Shape %s must have rank at least %d" % (self, rank)
787
+ )
788
+ else:
789
+ return self
790
+
791
+ def with_rank_at_most(self, rank):
792
+ """Returns a shape based on `self` with at most the given rank.
793
+
794
+ Args:
795
+ rank: An integer.
796
+
797
+ Returns:
798
+ A shape that is at least as specific as `self` with at most the given
799
+ rank.
800
+
801
+ Raises:
802
+ ValueError: If `self` does not represent a shape with at most the given
803
+ `rank`.
804
+ """
805
+ if self.ndims is not None and self.ndims > rank:
806
+ raise ValueError(
807
+ "Shape %s must have rank at most %d" % (self, rank)
808
+ )
809
+ else:
810
+ return self
811
+
812
+ def is_convertible_with(self, other):
813
+ """Returns True iff `self` is convertible with `other`.
814
+
815
+ Two possibly-partially-defined shapes are convertible if there
816
+ exists a fully-defined shape that both shapes can represent. Thus,
817
+ convertibility allows the shape inference code to reason about
818
+ partially-defined shapes. For example:
819
+
820
+ * TensorShape(None) is convertible with all shapes.
821
+
822
+ * TensorShape([None, None]) is convertible with all two-dimensional
823
+ shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
824
+ not convertible with, for example, TensorShape([None]) or
825
+ TensorShape([None, None, None]).
826
+
827
+ * TensorShape([32, None]) is convertible with all two-dimensional shapes
828
+ with size 32 in the 0th dimension, and also TensorShape([None, None])
829
+ and TensorShape(None). It is not convertible with, for example,
830
+ TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
831
+
832
+ * TensorShape([32, 784]) is convertible with itself, and also
833
+ TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
834
+ None]) and TensorShape(None). It is not convertible with, for example,
835
+ TensorShape([32, 1, 784]) or TensorShape([None]).
836
+
837
+ The convertibility relation is reflexive and symmetric, but not
838
+ transitive. For example, TensorShape([32, 784]) is convertible with
839
+ TensorShape(None), and TensorShape(None) is convertible with
840
+ TensorShape([4, 4]), but TensorShape([32, 784]) is not convertible with
841
+ TensorShape([4, 4]).
842
+
843
+ Args:
844
+ other: Another TensorShape.
845
+
846
+ Returns:
847
+ True iff `self` is convertible with `other`.
848
+ """
849
+ other = as_shape(other)
850
+ if self._dims is not None and other.dims is not None:
851
+ if self.ndims != other.ndims:
852
+ return False
853
+ for x_dim, y_dim in zip(self._dims, other.dims):
854
+ if not x_dim.is_convertible_with(y_dim):
855
+ return False
856
+ return True
857
+
858
+ def assert_is_convertible_with(self, other):
859
+ """Raises exception if `self` and `other` do not represent the same
860
+ shape.
861
+
862
+ This method can be used to assert that there exists a shape that both
863
+ `self` and `other` represent.
864
+
865
+ Args:
866
+ other: Another TensorShape.
867
+
868
+ Raises:
869
+ ValueError: If `self` and `other` do not represent the same shape.
870
+ """
871
+ if not self.is_convertible_with(other):
872
+ raise ValueError(
873
+ "Shapes %s and %s are inconvertible" % (self, other)
874
+ )
875
+
876
+ def most_specific_convertible_shape(self, other):
877
+ """Returns the most specific TensorShape convertible with `self` and
878
+ `other`.
879
+
880
+ * TensorShape([None, 1]) is the most specific TensorShape convertible with
881
+ both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
882
+ TensorShape(None) is also convertible with above mentioned TensorShapes.
883
+
884
+ * TensorShape([1, 2, 3]) is the most specific TensorShape convertible with
885
+ both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
886
+ less specific TensorShapes convertible with above mentioned TensorShapes,
887
+ e.g. TensorShape([1, 2, None]), TensorShape(None).
888
+
889
+ Args:
890
+ other: Another `TensorShape`.
891
+
892
+ Returns:
893
+ A `TensorShape` which is the most specific convertible shape of `self`
894
+ and `other`.
895
+ """
896
+
897
+ other = as_shape(other)
898
+ if (
899
+ self._dims is None
900
+ or other.dims is None
901
+ or self.ndims != other.ndims
902
+ ):
903
+ return unknown_shape()
904
+
905
+ dims = [Dimension(None)] * self.ndims
906
+ for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
907
+ if d1 is not None and d2 is not None and d1 == d2:
908
+ dims[i] = d1
909
+ return TensorShape(dims)
910
+
911
+ def is_fully_defined(self):
912
+ """Returns True iff `self` is fully defined in every dimension."""
913
+ return self._dims is not None and all(
914
+ dim.value is not None for dim in self._dims
915
+ )
916
+
917
+ def assert_is_fully_defined(self):
918
+ """Raises an exception if `self` is not fully defined in every
919
+ dimension.
920
+
921
+ Raises:
922
+ ValueError: If `self` does not have a known value for every dimension.
923
+ """
924
+ if not self.is_fully_defined():
925
+ raise ValueError("Shape %s is not fully defined" % self)
926
+
927
+ def as_list(self):
928
+ """Returns a list of integers or `None` for each dimension.
929
+
930
+ Returns:
931
+ A list of integers or `None` for each dimension.
932
+
933
+ Raises:
934
+ ValueError: If `self` is an unknown shape with an unknown rank.
935
+ """
936
+ if self._dims is None:
937
+ raise ValueError(
938
+ "as_list() is not defined on an unknown TensorShape."
939
+ )
940
+ return [dim.value for dim in self._dims]
941
+
942
+ def as_proto(self):
943
+ """Returns this shape as a `TensorShapeProto`."""
944
+ if self._dims is None:
945
+ return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
946
+ else:
947
+ return tensor_shape_pb2.TensorShapeProto(
948
+ dim=[
949
+ tensor_shape_pb2.TensorShapeProto.Dim(
950
+ size=-1 if d.value is None else d.value
951
+ )
952
+ for d in self._dims
953
+ ]
954
+ )
955
+
956
+ def __eq__(self, other):
957
+ """Returns True if `self` is equivalent to `other`."""
958
+ try:
959
+ other = as_shape(other)
960
+ except TypeError:
961
+ return NotImplemented
962
+ return self._dims == other.dims
963
+
964
+ def __ne__(self, other):
965
+ """Returns True if `self` is known to be different from `other`."""
966
+ try:
967
+ other = as_shape(other)
968
+ except TypeError:
969
+ return NotImplemented
970
+ if self.ndims is None or other.ndims is None:
971
+ raise ValueError(
972
+ "The inequality of unknown TensorShapes is undefined."
973
+ )
974
+ if self.ndims != other.ndims:
975
+ return True
976
+ return self._dims != other.dims
977
+
978
+ def __reduce__(self):
979
+ return TensorShape, (self._dims,)
980
+
981
+
982
+ def as_shape(shape):
983
+ """Converts the given object to a TensorShape."""
984
+ if isinstance(shape, TensorShape):
985
+ return shape
986
+ else:
987
+ return TensorShape(shape)
988
+
989
+
990
+ def unknown_shape(ndims=None):
991
+ """Returns an unknown TensorShape, optionally with a known rank.
992
+
993
+ Args:
994
+ ndims: (Optional) If specified, the number of dimensions in the shape.
995
+
996
+ Returns:
997
+ An unknown TensorShape.
998
+ """
999
+ if ndims is None:
1000
+ return TensorShape(None)
1001
+ else:
1002
+ return TensorShape([Dimension(None)] * ndims)
1003
+
1004
+
1005
+ _SCALAR_SHAPE = TensorShape([])
1006
+
1007
+
1008
+ def scalar():
1009
+ """Returns a shape representing a scalar."""
1010
+ return _SCALAR_SHAPE
1011
+
1012
+
1013
+ def vector(length):
1014
+ """Returns a shape representing a vector.
1015
+
1016
+ Args:
1017
+ length: The length of the vector, which may be None if unknown.
1018
+
1019
+ Returns:
1020
+ A TensorShape representing a vector of the given length.
1021
+ """
1022
+ return TensorShape([length])
1023
+
1024
+
1025
+ def matrix(rows, cols):
1026
+ """Returns a shape representing a matrix.
1027
+
1028
+ Args:
1029
+ rows: The number of rows in the matrix, which may be None if unknown.
1030
+ cols: The number of columns in the matrix, which may be None if unknown.
1031
+
1032
+ Returns:
1033
+ A TensorShape representing a matrix of the given size.
1034
+ """
1035
+ return TensorShape([rows, cols])