tf-nightly-cpu 2.20.0.dev20250220__cp310-cp310-win_amd64.whl → 2.20.0.dev20250221__cp310-cp310-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (113) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  14. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  15. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  17. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  18. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
  19. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  20. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  21. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  22. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  23. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  24. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  25. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  26. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  27. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  28. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  30. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  31. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  32. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  33. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  34. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  35. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  36. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  37. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  38. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  39. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  40. tensorflow/include/tensorflow/core/public/version.h +112 -127
  41. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  42. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  43. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  44. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
  45. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  46. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  47. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  48. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  49. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  50. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  51. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  52. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  53. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  54. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  55. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  56. tensorflow/include/xla/service/constant_value.h +1 -0
  57. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  58. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  59. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  60. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  61. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  62. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  63. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  64. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  65. tensorflow/python/_pywrap_mlir.pyd +0 -0
  66. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  67. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  68. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  69. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  70. tensorflow/python/_pywrap_tfe.pyd +0 -0
  71. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  72. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  73. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  74. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  75. tensorflow/python/compat/compat.py +1 -1
  76. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  77. tensorflow/python/eager/imperative_grad.py +5 -5
  78. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  79. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  80. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  81. tensorflow/python/eager/tape.py +2 -2
  82. tensorflow/python/framework/_dtypes.pyd +0 -0
  83. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  84. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  85. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  86. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  87. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  88. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  89. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  90. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  91. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  92. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  93. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  94. tensorflow/python/ops/summary_ops_v2.py +5 -1
  95. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  96. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  97. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  98. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  99. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  100. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  101. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  102. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  103. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  104. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  105. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  106. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  107. tensorflow/python/util/_tf_stack.pyd +0 -0
  108. tensorflow/tools/pip_package/setup.py +2 -2
  109. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
  110. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +113 -106
  111. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
  112. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
  113. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0
@@ -52,9 +52,9 @@ limitations under the License.
52
52
  #include "xla/pjrt/pjrt_client.h"
53
53
  #include "xla/pjrt/pjrt_common.h"
54
54
  #include "xla/pjrt/pjrt_compiler.h"
55
- #include "xla/pjrt/pjrt_device_description.h"
56
55
  #include "xla/pjrt/pjrt_executable.h"
57
56
  #include "xla/pjrt/pjrt_future.h"
57
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
58
58
  #include "xla/pjrt/tracked_device_buffer.h"
59
59
  #include "xla/pjrt/transpose.h"
60
60
  #include "xla/pjrt/utils.h"
@@ -77,54 +77,6 @@ limitations under the License.
77
77
 
78
78
  namespace xla {
79
79
 
80
- class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
81
- public:
82
- explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
83
- int process_index = 0)
84
- : id_(id),
85
- process_index_(process_index),
86
- device_kind_(std::move(device_kind)) {}
87
-
88
- int id() const override { return id_; }
89
-
90
- int process_index() const override { return process_index_; }
91
-
92
- absl::string_view device_kind() const override { return device_kind_; }
93
-
94
- absl::string_view ToString() const override { return to_string_; }
95
-
96
- absl::string_view DebugString() const override { return debug_string_; }
97
-
98
- absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
99
-
100
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
101
- const override {
102
- return attributes_;
103
- }
104
-
105
- void SetAttributes(
106
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
107
- attributes_ = std::move(attributes);
108
- }
109
-
110
- void SetDebugString(std::string debug_string) {
111
- debug_string_ = std::move(debug_string);
112
- }
113
-
114
- void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
115
-
116
- void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
117
-
118
- private:
119
- const int id_;
120
- const int process_index_;
121
- const std::string device_kind_;
122
- std::string debug_string_ = "<unknown SE device>";
123
- std::string to_string_ = "<unknown SE device>";
124
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
125
- std::array<int, 1> coords_;
126
- };
127
-
128
80
  class PjRtStreamExecutorDevice : public PjRtDevice {
129
81
  public:
130
82
  explicit PjRtStreamExecutorDevice(
@@ -0,0 +1,75 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+ Unless required by applicable law or agreed to in writing, software
7
+ distributed under the License is distributed on an "AS IS" BASIS,
8
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ See the License for the specific language governing permissions and
10
+ limitations under the License.
11
+ ==============================================================================*/
12
+ #ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
13
+ #define XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
14
+
15
+ #include <array>
16
+ #include <string>
17
+ #include <utility>
18
+
19
+ #include "absl/container/flat_hash_map.h"
20
+ #include "absl/strings/string_view.h"
21
+ #include "absl/types/span.h"
22
+ #include "xla/pjrt/pjrt_device_description.h"
23
+
24
+ namespace xla {
25
+
26
+ class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
27
+ public:
28
+ explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
29
+ int process_index = 0)
30
+ : id_(id),
31
+ process_index_(process_index),
32
+ device_kind_(std::move(device_kind)) {}
33
+
34
+ int id() const override { return id_; }
35
+
36
+ int process_index() const override { return process_index_; }
37
+
38
+ absl::string_view device_kind() const override { return device_kind_; }
39
+
40
+ absl::string_view ToString() const override { return to_string_; }
41
+
42
+ absl::string_view DebugString() const override { return debug_string_; }
43
+
44
+ absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
45
+
46
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
47
+ const override {
48
+ return attributes_;
49
+ }
50
+
51
+ void SetAttributes(
52
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
53
+ attributes_ = std::move(attributes);
54
+ }
55
+
56
+ void SetDebugString(std::string debug_string) {
57
+ debug_string_ = std::move(debug_string);
58
+ }
59
+
60
+ void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
61
+
62
+ void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
63
+
64
+ private:
65
+ const int id_;
66
+ const int process_index_;
67
+ const std::string device_kind_;
68
+ std::string debug_string_ = "<unknown SE device>";
69
+ std::string to_string_ = "<unknown SE device>";
70
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
71
+ std::array<int, 1> coords_;
72
+ };
73
+ } // namespace xla
74
+
75
+ #endif // XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
@@ -0,0 +1,57 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
17
+ #define XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
18
+
19
+ #include <optional>
20
+
21
+ #include "xla/backends/cpu/collectives/cpu_collectives.h"
22
+ #include "xla/pjrt/pjrt_executable.h"
23
+
24
+ namespace xla {
25
+
26
+ // ExecuteContext for XLA:CPU PjRtLoadedExecutable::Execute calls.
27
+ class CpuExecuteContext : public ExecuteContext {
28
+ public:
29
+ ~CpuExecuteContext() override = default;
30
+
31
+ // If specified, override the process ID specified in
32
+ // `CpuClientOptions::process_id` for a particular call of
33
+ // PjRtLoadedExecutable::Execute.
34
+ //
35
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
36
+ // option with `ExecuteOptions::multi_slice_config`.
37
+ std::optional<int>& process_index() { return process_index_; }
38
+ std::optional<int> process_index() const { return process_index_; }
39
+
40
+ // If specified, override CPU collectives specified in
41
+ // `CpuClientOptions::collectives` for a particular call of
42
+ // PjRtLoadedExecutable::Execute. Must remain valid until the execution
43
+ // finishes.
44
+ //
45
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
46
+ // option with `ExecuteOptions::multi_slice_config`.
47
+ cpu::CpuCollectives*& collectives() { return collectives_; }
48
+ cpu::CpuCollectives* collectives() const { return collectives_; }
49
+
50
+ private:
51
+ std::optional<int> process_index_;
52
+ cpu::CpuCollectives* collectives_ = nullptr;
53
+ };
54
+
55
+ } // namespace xla
56
+
57
+ #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
@@ -69,6 +69,10 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) {
69
69
  return global_device_id.value() / kMaxCpuDevicesPerProcess;
70
70
  }
71
71
 
72
+ inline int UnpackCpuLocalDeviceId(PjRtGlobalDeviceId global_device_id) {
73
+ return global_device_id.value() % kMaxCpuDevicesPerProcess;
74
+ }
75
+
72
76
  } // namespace xla
73
77
 
74
78
  #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_
@@ -18,6 +18,7 @@ limitations under the License.
18
18
 
19
19
  #include <string>
20
20
 
21
+ #include "absl/base/casts.h"
21
22
  #include "absl/status/statusor.h"
22
23
  #include "xla/literal.h"
23
24
  #include "xla/util.h"
@@ -19,16 +19,67 @@ limitations under the License.
19
19
  #include <functional>
20
20
  #include <memory>
21
21
  #include <optional>
22
+ #include <string>
22
23
 
23
- #include "absl/status/status.h"
24
+ #include "absl/log/check.h"
25
+ #include "absl/log/log.h"
24
26
  #include "absl/status/statusor.h"
27
+ #include "absl/strings/string_view.h"
25
28
  #include "absl/types/span.h"
29
+ #include "xla/hlo/ir/hlo_module.h"
30
+ #include "xla/hlo/parser/hlo_parser.h"
26
31
  #include "xla/service/compiler.h"
27
32
  #include "xla/service/hlo_module_config.h"
28
33
  #include "xla/shape.h"
34
+ #include "xla/util.h"
29
35
 
30
36
  namespace xla {
31
37
 
38
+ // Converts an HloModule from the given hlo textual IR string (in
39
+ // HloModule::ToString format).
40
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
41
+ absl::string_view hlo_string,
42
+ const DebugOptions& debug_options = DebugOptions::default_instance());
43
+
44
+ // Creates an HloModule from the given proto.
45
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
46
+ const HloModuleProto& proto,
47
+ const DebugOptions& debug_options = DebugOptions::default_instance());
48
+
49
+ // Create an HLO state from serialized representation. In addition to
50
+ // creating the proto with HloModule::CreateFromProto(...) it also
51
+ // uses HloVerifier to ensure basic invariants are held.
52
+ // The HLO module could be a pre-optimizations (default) or post-optimizations
53
+ // module, which affects how the HLO module is verified, e.g., mixed-precision
54
+ // is allowed in post-optimizations HLOs.
55
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
56
+ const HloModuleProto& proto, const HloModuleConfig& module_config,
57
+ bool is_module_post_optimizations = false);
58
+
59
+ // Reads the proto file in xla.HloProto format, creates and returns the
60
+ // HloModule.
61
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
62
+ absl::string_view filename,
63
+ const DebugOptions& debug_options = DebugOptions::default_instance());
64
+
65
+ // Reads the proto file in xla.HloModule format, creates and returns the
66
+ // HloModule.
67
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
68
+ absl::string_view filename, const DebugOptions& debug_options);
69
+
70
+ // Reads the HLO text dump file in HloModule::ToString format, creates and
71
+ // returns the HloModule.
72
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
73
+ absl::string_view filename,
74
+ const DebugOptions& debug_options = DebugOptions::default_instance(),
75
+ const HloParserOptions& options = HloParserOptions());
76
+
77
+ // Reads the proto file in xla.HloProto format, creates and returns the
78
+ // HloModule.
79
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
80
+ absl::string_view hlo_file,
81
+ const DebugOptions& debug_options = DebugOptions::default_instance());
82
+
32
83
  // Creates an HloModuleConfig for a given program shape and arguments.
33
84
  // If execution_options does not set num_replicas, default_num_replicas is used.
34
85
  // num_threads is optional; if not given, intra_op_parallelism_threads not set.
@@ -18,8 +18,6 @@ limitations under the License.
18
18
  #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
19
19
  #define XLA_SERVICE_HLO_PROTO_UTIL_H_
20
20
 
21
- #include <string>
22
-
23
21
  #include "absl/status/status.h"
24
22
  #include "xla/hlo/ir/hlo_module.h"
25
23
  #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
35
33
  // will not be included in the output.
36
34
  HloProto MakeHloProto(const HloModule& module);
37
35
 
38
- // Create an HLO state from serialized representation. In addition to
39
- // creating the proto with HloModule::CreateFromProto(...) it also
40
- // uses HloVerifier to ensure basic invariants are held.
41
- // The HLO module could be a pre-optimizations (default) or post-optimizations
42
- // module, which affects how the HLO module is verified, e.g., mixed-precision
43
- // is allowed in post-optimizations HLOs.
44
- absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
45
- const HloModuleProto& proto, const HloModuleConfig& module_config,
46
- bool is_module_post_optimizations = false);
47
-
48
36
  // Returns the shapes of the parameters of the entry computation. Shape pointers
49
37
  // refer to shapes inside of the given HloProto.
50
38
  absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
1604
1604
  Index padding_left = 0, Index padding_right = 0) {
1605
1605
  typedef typename internal::traits<Input>::Index TensorIndex;
1606
1606
  typedef typename internal::traits<Input>::Scalar InputScalar;
1607
- TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1608
- internal::traits<Input>::Layout, TensorIndex> >
1607
+ TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1608
+ internal::traits<Input>::Layout, TensorIndex> >
1609
1609
  in(input);
1610
- TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1611
- internal::traits<Kernel>::NumDimensions,
1612
- internal::traits<Kernel>::Layout, TensorIndex> >
1610
+ TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
1611
+ internal::traits<Kernel>::NumDimensions,
1612
+ internal::traits<Kernel>::Layout, TensorIndex> >
1613
1613
  kern(kernel);
1614
1614
 
1615
1615
  EIGEN_STATIC_ASSERT(
Binary file
Binary file
Binary file
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
29
29
  # This value changes every day with an automatic CL. It can be modified in code
30
30
  # via `forward_compatibility_horizon()` or with the environment variable
31
31
  # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
32
- _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 19)
32
+ _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 20)
33
33
  _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
34
34
  _FORWARD_COMPATIBILITY_DATE_NUMBER = None
35
35
 
@@ -43,13 +43,13 @@ def imperative_grad(tape,
43
43
  target: either a Tensor or list of Tensors to be differentiated.
44
44
  sources: list of Tensors for which we want gradients
45
45
  output_gradients: if not None, a list of gradient provided for each Target,
46
- or None if we are to use the target's computed downstream gradient.
46
+ or None if we are to use the target's computed downstream gradient.
47
47
  sources_raw: if not None, a list of the source python objects from which the
48
- sources were generated. Should have the same length as sources. Only needs
49
- to be populated if unconnected_gradients is 'zero'.
48
+ sources were generated. Should have the same length as sources. Only needs
49
+ to be populated if unconnected_gradients is 'zero'.
50
50
  unconnected_gradients: determines the value returned if the target and
51
- sources are unconnected. When 'none' the value returned is None wheras when
52
- 'zero' a zero tensor in the same shape as the sources is returned.
51
+ sources are unconnected. When 'none' the value returned is None whereas
52
+ when 'zero' a zero tensor in the same shape as the sources is returned.
53
53
 
54
54
  Returns:
55
55
  the gradient wrt each of the sources.
@@ -55,7 +55,7 @@ class CallOptions:
55
55
  # Used by ACD to list Ops/Tensors/Callables that must be called in advance.
56
56
  control_captures: List[Any] = dataclasses.field(default_factory=list)
57
57
 
58
- # Determines what kind of partitoned call is used for this function.
58
+ # Determines what kind of partitioned call is used for this function.
59
59
  is_stateful: bool = False
60
60
 
61
61
 
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Implmentation for defining get_compiler_ir."""
15
+ """Implementation for defining get_compiler_ir."""
16
16
  from typing import List, Optional
17
17
  import warnings
18
18
 
@@ -966,7 +966,7 @@ class Function(core.PolymorphicFunction, trackable.Trackable):
966
966
 
967
967
  def _check_inputs(args, kwargs):
968
968
  all_inputs = list(args) + list(kwargs.values())
969
- # Emtpy input is okay.
969
+ # Empty input is okay.
970
970
  if not all_inputs:
971
971
  return
972
972
  if any(map(is_tensor_spec, all_inputs)) and any(
@@ -1423,7 +1423,8 @@ def function(
1423
1423
  thought of as compile-time constants), and builds a separate `tf.Graph` for
1424
1424
  each set of Python arguments that it encounters.
1425
1425
  For more information, see the
1426
- [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1426
+ [tf.function
1427
+ guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1427
1428
 
1428
1429
  Executing a `PolymorphicFunction` will select and execute the appropriate
1429
1430
  `ConcreteFunction` based on the argument types and values.
@@ -1440,14 +1441,17 @@ def function(
1440
1441
  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1441
1442
  True
1442
1443
 
1443
- `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but their
1444
- input is resticted to the types to which they're specialized.
1444
+ `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
1445
+ their
1446
+ input is restricted to the types to which they're specialized.
1445
1447
 
1446
1448
  ## Retracing
1447
1449
 
1448
- `ConcreteFunctions` are built (traced) on the fly, as the `PolymorphicFunction` is
1450
+ `ConcreteFunctions` are built (traced) on the fly, as the
1451
+ `PolymorphicFunction` is
1449
1452
  called with new TensorFlow types or shapes, or with new Python values as
1450
- arguments. When `PolymorphicFunction` builds a new trace, it is said that `func`
1453
+ arguments. When `PolymorphicFunction` builds a new trace, it is said that
1454
+ `func`
1451
1455
  is retraced. Retracing is a frequent performance concern for `tf.function` as
1452
1456
  it can be considerably slower than executing a graph that's already been
1453
1457
  traced. It is ideal to minimize the amount of retracing in your code.
@@ -1473,7 +1477,8 @@ def function(
1473
1477
 
1474
1478
  ## Input signatures
1475
1479
 
1476
- For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction` for
1480
+ For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
1481
+ for
1477
1482
  every unique set of input shapes and datatypes. The example below creates two
1478
1483
  separate `ConcreteFunction`s, each specialized to a different shape:
1479
1484
 
@@ -1580,59 +1585,58 @@ def function(
1580
1585
  `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1581
1586
  autograph: Whether autograph should be applied on `func` before tracing a
1582
1587
  graph. Data-dependent Python control flow statements require
1583
- `autograph=True`. For more information, see the
1584
- [tf.function and AutoGraph guide](
1588
+ `autograph=True`. For more information, see the [tf.function and AutoGraph
1589
+ guide](
1585
1590
  https://www.tensorflow.org/guide/function#autograph_transformations).
1586
1591
  jit_compile: If `True`, compiles the function using
1587
1592
  [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
1588
1593
  such as fusion, and attempts to emit more efficient code. This may
1589
- drastically improve the performance. If set to `True`,
1590
- the whole function needs to be compilable by XLA, or an
1591
- `errors.InvalidArgumentError` is thrown.
1592
- If `None` (default), compiles the function with XLA when running on TPU
1593
- and goes through the regular function execution path when running on
1594
- other devices.
1595
- If `False`, executes the function without XLA compilation. Set this value
1596
- to `False` when directly running a multi-device function on TPUs (e.g. two
1597
- TPU cores, one TPU core and its host CPU).
1598
- Not all functions are compilable, see a list of
1599
- [sharp corners](https://tensorflow.org/xla/known_issues).
1600
- reduce_retracing: When True, `tf.function` attempts to reduce the
1601
- amount of retracing, for example by using more generic shapes. This
1602
- can be controlled for user objects by customizing their associated
1594
+ drastically improve the performance. If set to `True`, the whole function
1595
+ needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
1596
+ thrown. If `None` (default), compiles the function with XLA when running
1597
+ on TPU and goes through the regular function execution path when running
1598
+ on other devices. If `False`, executes the function without XLA
1599
+ compilation. Set this value to `False` when directly running a
1600
+ multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
1601
+ host CPU). Not all functions are compilable, see a list of [sharp
1602
+ corners](https://tensorflow.org/xla/known_issues).
1603
+ reduce_retracing: When True, `tf.function` attempts to reduce the amount of
1604
+ retracing, for example by using more generic shapes. This can be
1605
+ controlled for user objects by customizing their associated
1603
1606
  `tf.types.experimental.TraceType`.
1604
1607
  experimental_implements: If provided, contains a name of a "known" function
1605
- this implements. For example "mycompany.my_recurrent_cell".
1606
- This is stored as an attribute in inference function,
1607
- which can then be detected when processing serialized function.
1608
- See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md) # pylint: disable=line-too-long
1609
- for details. For an example of utilizing this attribute see this
1608
+ this implements. For example "mycompany.my_recurrent_cell". This is stored
1609
+ as an attribute in inference function, which can then be detected when
1610
+ processing serialized function. See [standardizing composite
1611
+ ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)
1612
+ # pylint: disable=line-too-long for details. For an example of utilizing
1613
+ this attribute see this
1610
1614
  [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
1611
1615
  The code above automatically detects and substitutes function that
1612
1616
  implements "embedded_matmul" and allows TFLite to substitute its own
1613
- implementations. For instance, a tensorflow user can use this
1614
- attribute to mark that their function also implements
1615
- `embedded_matmul` (perhaps more efficiently!)
1616
- by specifying it using this parameter:
1617
- `@tf.function(experimental_implements="embedded_matmul")`
1618
- This can either be specified as just the string name of the function or
1619
- a NameAttrList corresponding to a list of key-value attributes associated
1620
- with the function name. The name of the function will be in the 'name'
1621
- field of the NameAttrList. To define a formal TF op for this function
1622
- implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1617
+ implementations. For instance, a tensorflow user can use this attribute to
1618
+ mark that their function also implements `embedded_matmul` (perhaps more
1619
+ efficiently!) by specifying it using this parameter:
1620
+ `@tf.function(experimental_implements="embedded_matmul")` This can either
1621
+ be specified as just the string name of the function or a NameAttrList
1622
+ corresponding to a list of key-value attributes associated with the
1623
+ function name. The name of the function will be in the 'name' field of the
1624
+ NameAttrList. To define a formal TF op for this function implements, try
1625
+ the experimental [composite
1626
+ TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1623
1627
  project.
1624
1628
  experimental_autograph_options: Optional tuple of
1625
1629
  `tf.autograph.experimental.Feature` values.
1626
1630
  experimental_attributes: Optional dictionary of attributes to include in the
1627
1631
  generated FunctionDefs.
1628
- experimental_relax_shapes: Deprecated. Use `reduce_retracing`
1629
- instead.
1632
+ experimental_relax_shapes: Deprecated. Use `reduce_retracing` instead.
1630
1633
  experimental_compile: Deprecated alias to 'jit_compile'.
1631
1634
  experimental_follow_type_hints: Deprecated. Please use input_signature or
1632
1635
  reduce_retracing instead.
1633
1636
 
1634
1637
  Returns:
1635
- If `func` is not None, returns a `tf.types.experimental.PolymorphicFunction`.
1638
+ If `func` is not None, returns a
1639
+ `tf.types.experimental.PolymorphicFunction`.
1636
1640
  If `func` is None, returns a decorator that, when invoked with a single
1637
1641
  `func` argument, returns a `tf.types.experimental.PolymorphicFunction`.
1638
1642
 
@@ -48,9 +48,9 @@ def watch(tape, tensor):
48
48
  def default_get_variables(variable):
49
49
  return [variable]
50
50
 
51
- # Gets a list of changed variables. Can be overriden using
51
+ # Gets a list of changed variables. Can be overridden using
52
52
  # register_variables_override. An example of overriding is for getting the
53
- # varibles within a distributed context.
53
+ # variables within a distributed context.
54
54
  _variables_override = default_get_variables
55
55
 
56
56
 
Binary file
@@ -151,7 +151,11 @@ def _legacy_contrib_should_record_summaries():
151
151
 
152
152
  def is_recording_summaries():
153
153
  """Returns non-Tensor boolean indicating if summaries are being recorded."""
154
- return _summary_state.is_recording is not None and _summary_state.is_recording
154
+ if _summary_state.writer is None:
155
+ return False
156
+ if _summary_state.is_recording is None:
157
+ return False
158
+ return _summary_state.is_recording
155
159
 
156
160
 
157
161
  @tf_export("summary.record_if", v1=[])