tensorbored-2.21.0rc1769983804-py3-none-any.whl
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in the public registry.
- tensorbored/__init__.py +112 -0
- tensorbored/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/__init__.py +125 -0
- tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
- tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
- tensorbored/_vendor/bleach/callbacks.py +32 -0
- tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
- tensorbored/_vendor/bleach/linkifier.py +633 -0
- tensorbored/_vendor/bleach/parse_shim.py +1 -0
- tensorbored/_vendor/bleach/sanitizer.py +638 -0
- tensorbored/_vendor/bleach/six_shim.py +19 -0
- tensorbored/_vendor/webencodings/__init__.py +342 -0
- tensorbored/_vendor/webencodings/labels.py +231 -0
- tensorbored/_vendor/webencodings/mklabels.py +59 -0
- tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
- tensorbored/assets.py +36 -0
- tensorbored/auth.py +102 -0
- tensorbored/backend/__init__.py +0 -0
- tensorbored/backend/application.py +604 -0
- tensorbored/backend/auth_context_middleware.py +38 -0
- tensorbored/backend/client_feature_flags.py +113 -0
- tensorbored/backend/empty_path_redirect.py +46 -0
- tensorbored/backend/event_processing/__init__.py +0 -0
- tensorbored/backend/event_processing/data_ingester.py +276 -0
- tensorbored/backend/event_processing/data_provider.py +535 -0
- tensorbored/backend/event_processing/directory_loader.py +142 -0
- tensorbored/backend/event_processing/directory_watcher.py +272 -0
- tensorbored/backend/event_processing/event_accumulator.py +950 -0
- tensorbored/backend/event_processing/event_file_inspector.py +463 -0
- tensorbored/backend/event_processing/event_file_loader.py +292 -0
- tensorbored/backend/event_processing/event_multiplexer.py +521 -0
- tensorbored/backend/event_processing/event_util.py +68 -0
- tensorbored/backend/event_processing/io_wrapper.py +223 -0
- tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
- tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
- tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
- tensorbored/backend/event_processing/reservoir.py +266 -0
- tensorbored/backend/event_processing/tag_types.py +29 -0
- tensorbored/backend/experiment_id.py +71 -0
- tensorbored/backend/experimental_plugin.py +51 -0
- tensorbored/backend/http_util.py +263 -0
- tensorbored/backend/json_util.py +70 -0
- tensorbored/backend/path_prefix.py +67 -0
- tensorbored/backend/process_graph.py +74 -0
- tensorbored/backend/security_validator.py +202 -0
- tensorbored/compat/__init__.py +69 -0
- tensorbored/compat/proto/__init__.py +0 -0
- tensorbored/compat/proto/allocation_description_pb2.py +35 -0
- tensorbored/compat/proto/api_def_pb2.py +82 -0
- tensorbored/compat/proto/attr_value_pb2.py +80 -0
- tensorbored/compat/proto/cluster_pb2.py +58 -0
- tensorbored/compat/proto/config_pb2.py +271 -0
- tensorbored/compat/proto/coordination_config_pb2.py +45 -0
- tensorbored/compat/proto/cost_graph_pb2.py +87 -0
- tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
- tensorbored/compat/proto/debug_pb2.py +65 -0
- tensorbored/compat/proto/event_pb2.py +149 -0
- tensorbored/compat/proto/full_type_pb2.py +74 -0
- tensorbored/compat/proto/function_pb2.py +157 -0
- tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
- tensorbored/compat/proto/graph_pb2.py +41 -0
- tensorbored/compat/proto/histogram_pb2.py +39 -0
- tensorbored/compat/proto/meta_graph_pb2.py +254 -0
- tensorbored/compat/proto/node_def_pb2.py +61 -0
- tensorbored/compat/proto/op_def_pb2.py +81 -0
- tensorbored/compat/proto/resource_handle_pb2.py +48 -0
- tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
- tensorbored/compat/proto/rpc_options_pb2.py +35 -0
- tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
- tensorbored/compat/proto/saver_pb2.py +38 -0
- tensorbored/compat/proto/step_stats_pb2.py +116 -0
- tensorbored/compat/proto/struct_pb2.py +144 -0
- tensorbored/compat/proto/summary_pb2.py +111 -0
- tensorbored/compat/proto/tensor_description_pb2.py +38 -0
- tensorbored/compat/proto/tensor_pb2.py +68 -0
- tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
- tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
- tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
- tensorbored/compat/proto/types_pb2.py +105 -0
- tensorbored/compat/proto/variable_pb2.py +62 -0
- tensorbored/compat/proto/verifier_config_pb2.py +38 -0
- tensorbored/compat/proto/versions_pb2.py +35 -0
- tensorbored/compat/tensorflow_stub/__init__.py +38 -0
- tensorbored/compat/tensorflow_stub/app.py +124 -0
- tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
- tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
- tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
- tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
- tensorbored/compat/tensorflow_stub/errors.py +507 -0
- tensorbored/compat/tensorflow_stub/flags.py +124 -0
- tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
- tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
- tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
- tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
- tensorbored/context.py +129 -0
- tensorbored/data/__init__.py +0 -0
- tensorbored/data/grpc_provider.py +365 -0
- tensorbored/data/ingester.py +46 -0
- tensorbored/data/proto/__init__.py +0 -0
- tensorbored/data/proto/data_provider_pb2.py +517 -0
- tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
- tensorbored/data/provider.py +1365 -0
- tensorbored/data/server_ingester.py +301 -0
- tensorbored/data_compat.py +159 -0
- tensorbored/dataclass_compat.py +224 -0
- tensorbored/default.py +124 -0
- tensorbored/errors.py +130 -0
- tensorbored/lazy.py +99 -0
- tensorbored/main.py +48 -0
- tensorbored/main_lib.py +62 -0
- tensorbored/manager.py +487 -0
- tensorbored/notebook.py +441 -0
- tensorbored/plugin_util.py +266 -0
- tensorbored/plugins/__init__.py +0 -0
- tensorbored/plugins/audio/__init__.py +0 -0
- tensorbored/plugins/audio/audio_plugin.py +229 -0
- tensorbored/plugins/audio/metadata.py +69 -0
- tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
- tensorbored/plugins/audio/summary.py +230 -0
- tensorbored/plugins/audio/summary_v2.py +124 -0
- tensorbored/plugins/base_plugin.py +367 -0
- tensorbored/plugins/core/__init__.py +0 -0
- tensorbored/plugins/core/core_plugin.py +981 -0
- tensorbored/plugins/custom_scalar/__init__.py +0 -0
- tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
- tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
- tensorbored/plugins/custom_scalar/metadata.py +35 -0
- tensorbored/plugins/custom_scalar/summary.py +79 -0
- tensorbored/plugins/debugger_v2/__init__.py +0 -0
- tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
- tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
- tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
- tensorbored/plugins/distribution/__init__.py +0 -0
- tensorbored/plugins/distribution/compressor.py +158 -0
- tensorbored/plugins/distribution/distributions_plugin.py +116 -0
- tensorbored/plugins/distribution/metadata.py +19 -0
- tensorbored/plugins/graph/__init__.py +0 -0
- tensorbored/plugins/graph/graph_util.py +129 -0
- tensorbored/plugins/graph/graphs_plugin.py +336 -0
- tensorbored/plugins/graph/keras_util.py +328 -0
- tensorbored/plugins/graph/metadata.py +42 -0
- tensorbored/plugins/histogram/__init__.py +0 -0
- tensorbored/plugins/histogram/histograms_plugin.py +144 -0
- tensorbored/plugins/histogram/metadata.py +63 -0
- tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
- tensorbored/plugins/histogram/summary.py +234 -0
- tensorbored/plugins/histogram/summary_v2.py +292 -0
- tensorbored/plugins/hparams/__init__.py +14 -0
- tensorbored/plugins/hparams/_keras.py +93 -0
- tensorbored/plugins/hparams/api.py +130 -0
- tensorbored/plugins/hparams/api_pb2.py +208 -0
- tensorbored/plugins/hparams/backend_context.py +606 -0
- tensorbored/plugins/hparams/download_data.py +158 -0
- tensorbored/plugins/hparams/error.py +26 -0
- tensorbored/plugins/hparams/get_experiment.py +71 -0
- tensorbored/plugins/hparams/hparams_plugin.py +206 -0
- tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
- tensorbored/plugins/hparams/json_format_compat.py +38 -0
- tensorbored/plugins/hparams/list_metric_evals.py +57 -0
- tensorbored/plugins/hparams/list_session_groups.py +1040 -0
- tensorbored/plugins/hparams/metadata.py +125 -0
- tensorbored/plugins/hparams/metrics.py +41 -0
- tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
- tensorbored/plugins/hparams/summary.py +205 -0
- tensorbored/plugins/hparams/summary_v2.py +597 -0
- tensorbored/plugins/image/__init__.py +0 -0
- tensorbored/plugins/image/images_plugin.py +232 -0
- tensorbored/plugins/image/metadata.py +65 -0
- tensorbored/plugins/image/plugin_data_pb2.py +34 -0
- tensorbored/plugins/image/summary.py +159 -0
- tensorbored/plugins/image/summary_v2.py +130 -0
- tensorbored/plugins/mesh/__init__.py +14 -0
- tensorbored/plugins/mesh/mesh_plugin.py +292 -0
- tensorbored/plugins/mesh/metadata.py +152 -0
- tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
- tensorbored/plugins/mesh/summary.py +251 -0
- tensorbored/plugins/mesh/summary_v2.py +214 -0
- tensorbored/plugins/metrics/__init__.py +0 -0
- tensorbored/plugins/metrics/metadata.py +17 -0
- tensorbored/plugins/metrics/metrics_plugin.py +623 -0
- tensorbored/plugins/pr_curve/__init__.py +0 -0
- tensorbored/plugins/pr_curve/metadata.py +75 -0
- tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
- tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
- tensorbored/plugins/pr_curve/summary.py +574 -0
- tensorbored/plugins/profile_redirect/__init__.py +0 -0
- tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
- tensorbored/plugins/projector/__init__.py +67 -0
- tensorbored/plugins/projector/metadata.py +26 -0
- tensorbored/plugins/projector/projector_config_pb2.py +54 -0
- tensorbored/plugins/projector/projector_plugin.py +795 -0
- tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
- tensorbored/plugins/scalar/__init__.py +0 -0
- tensorbored/plugins/scalar/metadata.py +60 -0
- tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
- tensorbored/plugins/scalar/scalars_plugin.py +181 -0
- tensorbored/plugins/scalar/summary.py +109 -0
- tensorbored/plugins/scalar/summary_v2.py +124 -0
- tensorbored/plugins/text/__init__.py +0 -0
- tensorbored/plugins/text/metadata.py +62 -0
- tensorbored/plugins/text/plugin_data_pb2.py +34 -0
- tensorbored/plugins/text/summary.py +114 -0
- tensorbored/plugins/text/summary_v2.py +124 -0
- tensorbored/plugins/text/text_plugin.py +288 -0
- tensorbored/plugins/wit_redirect/__init__.py +0 -0
- tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
- tensorbored/program.py +910 -0
- tensorbored/summary/__init__.py +35 -0
- tensorbored/summary/_output.py +124 -0
- tensorbored/summary/_tf/__init__.py +14 -0
- tensorbored/summary/_tf/summary/__init__.py +178 -0
- tensorbored/summary/_writer.py +105 -0
- tensorbored/summary/v1.py +51 -0
- tensorbored/summary/v2.py +25 -0
- tensorbored/summary/writer/__init__.py +13 -0
- tensorbored/summary/writer/event_file_writer.py +291 -0
- tensorbored/summary/writer/record_writer.py +50 -0
- tensorbored/util/__init__.py +0 -0
- tensorbored/util/encoder.py +116 -0
- tensorbored/util/grpc_util.py +311 -0
- tensorbored/util/img_mime_type_detector.py +40 -0
- tensorbored/util/io_util.py +20 -0
- tensorbored/util/lazy_tensor_creator.py +110 -0
- tensorbored/util/op_evaluator.py +104 -0
- tensorbored/util/platform_util.py +20 -0
- tensorbored/util/tb_logging.py +24 -0
- tensorbored/util/tensor_util.py +617 -0
- tensorbored/util/timing.py +122 -0
- tensorbored/version.py +21 -0
- tensorbored/webfiles.zip +0 -0
- tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
- tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
- tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
- tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
- tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
- tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
tensorbored/compat/tensorflow_stub/tensor_shape.py
@@ -0,0 +1,1035 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for tensor shape inference."""

# pytype: skip-file

from . import compat, dtypes
from tensorbored.compat.proto import tensor_shape_pb2


# @tf_export("Dimension")
class Dimension:
    """Represents the value of one dimension in a TensorShape."""

    def __init__(self, value):
        """Creates a new Dimension with the given value."""
        if value is None:
            self._value = None
        elif isinstance(value, dtypes.DType):
            raise TypeError("Cannot convert %s to Dimension" % value)
        else:
            self._value = int(value)
            if (
                not isinstance(value, compat.bytes_or_text_types)
                and self._value != value
            ):
                raise ValueError("Ambiguous dimension: %s" % value)
            if self._value < 0:
                raise ValueError("Dimension %d must be >= 0" % self._value)

    def __repr__(self):
        return "Dimension(%s)" % repr(self._value)

    def __str__(self):
        value = self._value
        return "?" if value is None else str(value)

    def __eq__(self, other):
        """Returns true if `other` has the same known value as this
        Dimension."""
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value == other.value

    def __ne__(self, other):
        """Returns true if `other` has a different known value from `self`."""
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value != other.value

    def __int__(self):
        return self._value

    # This is needed for Windows.
    # See https://github.com/tensorflow/tensorflow/pull/9780
    def __long__(self):
        return self._value

    def __index__(self):
        # Allow use in Python 3 range
        return self._value

    @property
    def value(self):
        """The value of this dimension, or None if it is unknown."""
        return self._value

    def is_convertible_with(self, other):
        """Returns true if `other` is convertible with this Dimension.

        Two known Dimensions are convertible if they have the same value.
        An unknown Dimension is convertible with all other Dimensions.

        Args:
          other: Another Dimension.

        Returns:
          True if this Dimension and `other` are convertible.
        """
        other = as_dimension(other)
        return (
            self._value is None
            or other.value is None
            or self._value == other.value
        )

    def assert_is_convertible_with(self, other):
        """Raises an exception if `other` is not convertible with this
        Dimension.

        Args:
          other: Another Dimension.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            is_convertible_with).
        """
        if not self.is_convertible_with(other):
            raise ValueError(
                "Dimensions %s and %s are not convertible" % (self, other)
            )

    def merge_with(self, other):
        """Returns a Dimension that combines the information in `self` and
        `other`.

        Dimensions are combined as follows:

        ```python
        tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
        tf.Dimension(n)   .merge_with(tf.Dimension(m))  # raises ValueError for n != m
        ```

        Args:
          other: Another Dimension.

        Returns:
          A Dimension containing the combined information of `self` and
          `other`.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            is_convertible_with).
        """
        other = as_dimension(other)
        self.assert_is_convertible_with(other)
        if self._value is None:
            return Dimension(other.value)
        else:
            return Dimension(self._value)

    def __add__(self, other):
        """Returns the sum of `self` and `other`.

        Dimensions are summed as follows:

        ```python
        tf.Dimension(m)    + tf.Dimension(n)    == tf.Dimension(m + n)
        tf.Dimension(m)    + tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value + other.value)

    def __radd__(self, other):
        """Returns the sum of `other` and `self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`.
        """
        return self + other

    def __sub__(self, other):
        """Returns the subtraction of `other` from `self`.

        Dimensions are subtracted as follows:

        ```python
        tf.Dimension(m)    - tf.Dimension(n)    == tf.Dimension(m - n)
        tf.Dimension(m)    - tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the subtraction of `other` from `self`.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value - other.value)

    def __rsub__(self, other):
        """Returns the subtraction of `self` from `other`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the subtraction of `self` from `other`.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(other.value - self._value)

    def __mul__(self, other):
        """Returns the product of `self` and `other`.

        Dimensions are summed as follows:

        ```python
        tf.Dimension(m)    * tf.Dimension(n)    == tf.Dimension(m * n)
        tf.Dimension(m)    * tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented

        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value * other.value)

    def __rmul__(self, other):
        """Returns the product of `self` and `other`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`.
        """
        return self * other

    def __floordiv__(self, other):
        """Returns the quotient of `self` and `other` rounded down.

        Dimensions are divided as follows:

        ```python
        tf.Dimension(m)    // tf.Dimension(n)    == tf.Dimension(m // n)
        tf.Dimension(m)    // tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `self` and `other`.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value // other.value)

    def __rfloordiv__(self, other):
        """Returns the quotient of `other` and `self` rounded down.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `self` and `other`.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(other.value // self._value)

    def __div__(self, other):
        """DEPRECATED: Use `__floordiv__` via `x // y` instead.

        This function exists only for backwards convertibility purposes; new code
        should use `__floordiv__` via the syntax `x // y`. Using `x // y`
        communicates clearly that the result rounds down, and is forward convertible
        to Python 3.

        Args:
          other: Another `Dimension`.

        Returns:
          A `Dimension` whose value is the integer quotient of `self` and `other`.
        """
        return self // other

    def __mod__(self, other):
        """Returns `self` modulo `other`.

        Dimension moduli are computed as follows:

        ```python
        tf.Dimension(m)    % tf.Dimension(n)    == tf.Dimension(m % n)
        tf.Dimension(m)    % tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `self` modulo `other`.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return Dimension(None)
        else:
            return Dimension(self._value % other.value)

    def __rmod__(self, other):
        """Returns `other` modulo `self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `other` modulo `self`.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return other % self

    def __lt__(self, other):
        """Returns True if `self` is known to be less than `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
        (tf.Dimension(m)    < tf.Dimension(None)) == None
        (tf.Dimension(None) < tf.Dimension(n))    == None
        (tf.Dimension(None) < tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value < other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value < other.value

    def __le__(self, other):
        """Returns True if `self` is known to be less than or equal to `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
        (tf.Dimension(m)    <= tf.Dimension(None)) == None
        (tf.Dimension(None) <= tf.Dimension(n))    == None
        (tf.Dimension(None) <= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value <= other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value <= other.value

    def __gt__(self, other):
        """Returns True if `self` is known to be greater than `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
        (tf.Dimension(m)    > tf.Dimension(None)) == None
        (tf.Dimension(None) > tf.Dimension(n))    == None
        (tf.Dimension(None) > tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value > other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value > other.value

    def __ge__(self, other):
        """Returns True if `self` is known to be greater than or equal to
        `other`.

        Dimensions are compared as follows:

        ```python
        (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
        (tf.Dimension(m)    >= tf.Dimension(None)) == None
        (tf.Dimension(None) >= tf.Dimension(n))    == None
        (tf.Dimension(None) >= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          The value of `self.value >= other.value` if both are known, otherwise
          None.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        else:
            return self._value >= other.value

    def __reduce__(self):
        return Dimension, (self._value,)


def as_dimension(value):
    """Converts the given value to a Dimension.

    A Dimension input will be returned unmodified.
    An input of `None` will be converted to an unknown Dimension.
    An integer input will be converted to a Dimension with that value.

    Args:
      value: The value to be converted.

    Returns:
      A Dimension corresponding to the given value.
    """
    if isinstance(value, Dimension):
        return value
    else:
        return Dimension(value)


# @tf_export("TensorShape")
class TensorShape:
    """Represents the shape of a `Tensor`.

    A `TensorShape` represents a possibly-partial shape specification for a
    `Tensor`. It may be one of the following:

    * *Fully-known shape:* has a known number of dimensions and a known size
      for each dimension. e.g. `TensorShape([16, 256])`
    * *Partially-known shape:* has a known number of dimensions, and an unknown
      size for one or more dimension. e.g. `TensorShape([None, 256])`
    * *Unknown shape:* has an unknown number of dimensions, and an unknown
      size in all dimensions. e.g. `TensorShape(None)`

    If a tensor is produced by an operation of type `"Foo"`, its shape
    may be inferred if there is a registered shape function for
    `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
    for details of shape functions and how to register them. Alternatively,
    the shape may be set explicitly using @{tf.Tensor.set_shape}.
    """

    def __init__(self, dims):
        """Creates a new TensorShape with the given dimensions.

        Args:
          dims: A list of Dimensions, or None if the shape is unspecified.
            DEPRECATED: A single integer is treated as a singleton list.

        Raises:
          TypeError: If dims cannot be converted to a list of dimensions.
        """
        # TODO(irving): Eliminate the single integer special case.
        if dims is None:
            self._dims = None
        elif isinstance(dims, compat.bytes_or_text_types):
            raise TypeError(
                "A string has ambiguous TensorShape, please wrap in a "
                "list or convert to an int: %s" % dims
            )
        elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
            if dims.unknown_rank:
                self._dims = None
            else:
                self._dims = [
                    # Protos store variable-size dimensions as -1
                    as_dimension(dim.size if dim.size != -1 else None)
                    for dim in dims.dim
                ]
        elif isinstance(dims, TensorShape):
            self._dims = dims.dims
        else:
            try:
                dims_iter = iter(dims)
            except TypeError:
                # Treat as a singleton dimension
                self._dims = [as_dimension(dims)]
            else:
                # Got a list of dimensions
                self._dims = [as_dimension(d) for d in dims_iter]
        self._ndims = None

    def __repr__(self):
        return "TensorShape(%r)" % self._dims

    def __str__(self):
        if self.ndims is None:
            return "<unknown>"
        elif self.ndims == 1:
            return "(%s,)" % self._dims[0]
        else:
            return "(%s)" % ", ".join(str(d) for d in self._dims)

    @property
    def dims(self):
        """Returns a list of Dimensions, or None if the shape is
        unspecified."""
        return self._dims

    @dims.setter
    def dims(self, dims):
        self._dims = dims
        self._ndims = None

    @property
    def ndims(self):
        """Returns the rank of this shape, or None if it is unspecified."""
        if self._dims is None:
            return None
        else:
            if self._ndims is None:
                self._ndims = len(self._dims)
            return self._ndims

    def __len__(self):
        """Returns the rank of this shape, or raises ValueError if
        unspecified."""
        if self._dims is None:
            raise ValueError(
                "Cannot take the length of Shape with unknown rank."
            )
        return self.ndims

    def __bool__(self):
        """Returns True if this shape contains non-zero information."""
        return self._dims is not None

    # Python 3 wants __bool__, Python 2.7 wants __nonzero__
    __nonzero__ = __bool__

    def __iter__(self):
        """Returns `self.dims` if the rank is known, otherwise raises
        ValueError."""
        if self._dims is None:
            raise ValueError("Cannot iterate over a shape with unknown rank.")
        else:
            return iter(self._dims)

    def __getitem__(self, key):
        """Returns the value of a dimension or a shape, depending on the key.

        Args:
          key: If `key` is an integer, returns the dimension at that index;
            otherwise if `key` is a slice, returns a TensorShape whose
            dimensions are those selected by the slice from `self`.

        Returns:
          A dimension if `key` is an integer, or a `TensorShape` if `key` is a
          slice.

        Raises:
          ValueError: If `key` is a slice, and any of its elements are negative, or
            if `self` is completely unknown and the step is set.
        """
        if self._dims is not None:
            if isinstance(key, slice):
                return TensorShape(self._dims[key])
            else:
                return self._dims[key]
        else:
            if isinstance(key, slice):
                start = key.start if key.start is not None else 0
                stop = key.stop

                if key.step is not None:
                    # TODO(mrry): Handle these maybe.
                    raise ValueError("Steps are not yet handled")
                if stop is None:
                    # NOTE(mrry): This implies that TensorShape(None) is convertible with
                    # TensorShape(None)[1:], which is obviously not true. It would be
                    # possible to track the number of dimensions symbolically,
                    # and perhaps we should do that.
                    return unknown_shape()
                elif start < 0 or stop < 0:
                    # TODO(mrry): Handle this better, as it will be useful for handling
                    # suffixes of otherwise unknown shapes.
                    return unknown_shape()
                else:
                    return unknown_shape(ndims=stop - start)
            else:
                return Dimension(None)

    def num_elements(self):
        """Returns the total number of elements, or none for incomplete
        shapes."""
        if self.is_fully_defined():
            size = 1
            for dim in self._dims:
                size *= dim.value
            return size
        else:
            return None

    def merge_with(self, other):
        """Returns a `TensorShape` combining the information in `self` and
        `other`.

        The dimensions in `self` and `other` are merged elementwise,
        according to the rules defined for `Dimension.merge_with()`.

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` containing the combined information of `self` and
          `other`.

        Raises:
          ValueError: If `self` and `other` are not convertible.
        """
        other = as_shape(other)
        if self._dims is None:
            return other
        else:
            try:
                self.assert_same_rank(other)
                new_dims = []
                for i, dim in enumerate(self._dims):
                    new_dims.append(dim.merge_with(other[i]))
                return TensorShape(new_dims)
            except ValueError:
                raise ValueError(
                    "Shapes %s and %s are not convertible" % (self, other)
                )

    def concatenate(self, other):
        """Returns the concatenation of the dimension in `self` and `other`.

        *N.B.* If either `self` or `other` is completely unknown,
        concatenation will discard information about the other shape. In
        future, we might support concatenation that preserves this
        information for use with slicing.

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` whose dimensions are the concatenation of the
          dimensions in `self` and `other`.
        """
        # TODO(mrry): Handle the case where we concatenate a known shape with a
        # completely unknown shape, so that we can use the partial information.
        other = as_shape(other)
        if self._dims is None or other.dims is None:
            return unknown_shape()
        else:
            return TensorShape(self._dims + other.dims)

    def assert_same_rank(self, other):
        """Raises an exception if `self` and `other` do not have convertible
        ranks.

        Args:
          other: Another `TensorShape`.

        Raises:
          ValueError: If `self` and `other` do not represent shapes with the
            same rank.
        """
        other = as_shape(other)
        if self.ndims is not None and other.ndims is not None:
            if self.ndims != other.ndims:
                raise ValueError(
                    "Shapes %s and %s must have the same rank" % (self, other)
                )

    def assert_has_rank(self, rank):
        """Raises an exception if `self` is not convertible with the given
        `rank`.

        Args:
          rank: An integer.

        Raises:
          ValueError: If `self` does not represent a shape with the given `rank`.
        """
        if self.ndims not in (None, rank):
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank(self, rank):
        """Returns a shape based on `self` with the given rank.

        This method promotes a completely unknown shape to one with a
        known rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with the given rank.

        Raises:
          ValueError: If `self` does not represent a shape with the given `rank`.
        """
        try:
            return self.merge_with(unknown_shape(ndims=rank))
        except ValueError:
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank_at_least(self, rank):
        """Returns a shape based on `self` with at least the given rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with at least the given
          rank.

        Raises:
          ValueError: If `self` does not represent a shape with at least the given
            `rank`.
        """
        if self.ndims is not None and self.ndims < rank:
            raise ValueError(
                "Shape %s must have rank at least %d" % (self, rank)
            )
        else:
            return self

    def with_rank_at_most(self, rank):
        """Returns a shape based on `self` with at most the given rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with at most the given
          rank.

        Raises:
          ValueError: If `self` does not represent a shape with at most the given
            `rank`.
        """
        if self.ndims is not None and self.ndims > rank:
            raise ValueError(
                "Shape %s must have rank at most %d" % (self, rank)
            )
        else:
            return self

    def is_convertible_with(self, other):
        """Returns True iff `self` is convertible with `other`.

        Two possibly-partially-defined shapes are convertible if there
        exists a fully-defined shape that both shapes can represent. Thus,
        convertibility allows the shape inference code to reason about
        partially-defined shapes. For example:

        * TensorShape(None) is convertible with all shapes.

        * TensorShape([None, None]) is convertible with all two-dimensional
          shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
          not convertible with, for example, TensorShape([None]) or
          TensorShape([None, None, None]).

        * TensorShape([32, None]) is convertible with all two-dimensional shapes
          with size 32 in the 0th dimension, and also TensorShape([None, None])
          and TensorShape(None). It is not convertible with, for example,
          TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

        * TensorShape([32, 784]) is convertible with itself, and also
          TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
          None]) and TensorShape(None). It is not convertible with, for example,
          TensorShape([32, 1, 784]) or TensorShape([None]).

        The convertibility relation is reflexive and symmetric, but not
        transitive. For example, TensorShape([32, 784]) is convertible with
        TensorShape(None), and TensorShape(None) is convertible with
        TensorShape([4, 4]), but TensorShape([32, 784]) is not convertible with
        TensorShape([4, 4]).

        Args:
          other: Another TensorShape.

        Returns:
          True iff `self` is convertible with `other`.
        """
        other = as_shape(other)
        if self._dims is not None and other.dims is not None:
            if self.ndims != other.ndims:
                return False
            for x_dim, y_dim in zip(self._dims, other.dims):
                if not x_dim.is_convertible_with(y_dim):
                    return False
        return True

    def assert_is_convertible_with(self, other):
        """Raises exception if `self` and `other` do not represent the same
        shape.

        This method can be used to assert that there exists a shape that both
        `self` and `other` represent.

        Args:
          other: Another TensorShape.

        Raises:
          ValueError: If `self` and `other` do not represent the same shape.
        """
        if not self.is_convertible_with(other):
            raise ValueError(
                "Shapes %s and %s are inconvertible" % (self, other)
            )

    def most_specific_convertible_shape(self, other):
        """Returns the most specific TensorShape convertible with `self` and
        `other`.

        * TensorShape([None, 1]) is the most specific TensorShape convertible with
          both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
          TensorShape(None) is also convertible with above mentioned TensorShapes.

        * TensorShape([1, 2, 3]) is the most specific TensorShape convertible with
          both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
          less specific TensorShapes convertible with above mentioned TensorShapes,
          e.g. TensorShape([1, 2, None]), TensorShape(None).

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` which is the most specific convertible shape of `self`
          and `other`.
        """

        other = as_shape(other)
        if (
            self._dims is None
            or other.dims is None
            or self.ndims != other.ndims
        ):
            return unknown_shape()

        dims = [Dimension(None)] * self.ndims
        for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
            if d1 is not None and d2 is not None and d1 == d2:
                dims[i] = d1
        return TensorShape(dims)

    def is_fully_defined(self):
        """Returns True iff `self` is fully defined in every dimension."""
        return self._dims is not None and all(
            dim.value is not None for dim in self._dims
        )

    def assert_is_fully_defined(self):
        """Raises an exception if `self` is not fully defined in every
        dimension.

        Raises:
          ValueError: If `self` does not have a known value for every dimension.
        """
        if not self.is_fully_defined():
            raise ValueError("Shape %s is not fully defined" % self)

    def as_list(self):
        """Returns a list of integers or `None` for each dimension.

        Returns:
          A list of integers or `None` for each dimension.

        Raises:
          ValueError: If `self` is an unknown shape with an unknown rank.
        """
        if self._dims is None:
            raise ValueError(
                "as_list() is not defined on an unknown TensorShape."
            )
        return [dim.value for dim in self._dims]

    def as_proto(self):
        """Returns this shape as a `TensorShapeProto`."""
        if self._dims is None:
            return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
        else:
            return tensor_shape_pb2.TensorShapeProto(
                dim=[
                    tensor_shape_pb2.TensorShapeProto.Dim(
                        size=-1 if d.value is None else d.value
                    )
                    for d in self._dims
                ]
            )

    def __eq__(self, other):
        """Returns True if `self` is equivalent to `other`."""
        try:
            other = as_shape(other)
        except TypeError:
            return NotImplemented
        return self._dims == other.dims

    def __ne__(self, other):
        """Returns True if `self` is known to be different from `other`."""
        try:
            other = as_shape(other)
        except TypeError:
            return NotImplemented
        if self.ndims is None or other.ndims is None:
            raise ValueError(
                "The inequality of unknown TensorShapes is undefined."
            )
        if self.ndims != other.ndims:
            return True
        return self._dims != other.dims

    def __reduce__(self):
        return TensorShape, (self._dims,)


def as_shape(shape):
    """Converts the given object to a TensorShape."""
    if isinstance(shape, TensorShape):
        return shape
    else:
        return TensorShape(shape)


def unknown_shape(ndims=None):
    """Returns an unknown TensorShape, optionally with a known rank.

    Args:
      ndims: (Optional) If specified, the number of dimensions in the shape.

    Returns:
      An unknown TensorShape.
    """
    if ndims is None:
        return TensorShape(None)
    else:
        return TensorShape([Dimension(None)] * ndims)


_SCALAR_SHAPE = TensorShape([])


def scalar():
    """Returns a shape representing a scalar."""
    return _SCALAR_SHAPE


def vector(length):
    """Returns a shape representing a vector.

    Args:
      length: The length of the vector, which may be None if unknown.

    Returns:
      A TensorShape representing a vector of the given length.
    """
    return TensorShape([length])


def matrix(rows, cols):
    """Returns a shape representing a matrix.

    Args:
      rows: The number of rows in the matrix, which may be None if unknown.
      cols: The number of columns in the matrix, which may be None if unknown.

    Returns:
      A TensorShape representing a matrix of the given size.
    """
    return TensorShape([rows, cols])