tf-nightly-cpu 2.20.0.dev20250220__cp39-cp39-win_amd64.whl → 2.20.0.dev20250222__cp39-cp39-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (130) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  14. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  15. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  17. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  18. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  19. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  20. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  21. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  22. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  23. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
  24. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  25. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  26. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  27. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  28. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  30. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  31. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  32. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  33. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  34. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  35. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  36. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  37. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  38. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  39. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  40. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  41. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  42. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  43. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  44. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  45. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  46. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  47. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  48. tensorflow/include/tensorflow/core/public/version.h +112 -127
  49. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  50. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  51. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  52. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  53. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
  54. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  55. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  56. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  57. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  58. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  59. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  60. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  61. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  62. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  63. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  64. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  65. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  66. tensorflow/include/xla/service/constant_value.h +1 -0
  67. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  68. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  69. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  70. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  71. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  72. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  73. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  74. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  75. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  76. tensorflow/python/_pywrap_mlir.pyd +0 -0
  77. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  78. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  79. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  80. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  81. tensorflow/python/_pywrap_tfe.pyd +0 -0
  82. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  83. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  84. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  85. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  86. tensorflow/python/compat/compat.py +1 -1
  87. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  88. tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
  89. tensorflow/python/eager/imperative_grad.py +5 -5
  90. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  91. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  92. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  93. tensorflow/python/eager/tape.py +2 -2
  94. tensorflow/python/framework/_dtypes.pyd +0 -0
  95. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  96. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  97. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  98. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  99. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  100. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  101. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  102. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  103. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  104. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  105. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  106. tensorflow/python/ops/summary_ops_v2.py +5 -1
  107. tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
  108. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  109. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  110. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  111. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  112. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  113. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  114. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  115. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  116. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  117. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  118. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  119. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  120. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  121. tensorflow/python/util/_tf_stack.pyd +0 -0
  122. tensorflow/tools/pip_package/setup.py +2 -2
  123. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  124. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  125. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +128 -123
  126. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  127. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  128. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  129. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  130. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -19,16 +19,67 @@ limitations under the License.
19
19
  #include <functional>
20
20
  #include <memory>
21
21
  #include <optional>
22
+ #include <string>
22
23
 
23
- #include "absl/status/status.h"
24
+ #include "absl/log/check.h"
25
+ #include "absl/log/log.h"
24
26
  #include "absl/status/statusor.h"
27
+ #include "absl/strings/string_view.h"
25
28
  #include "absl/types/span.h"
29
+ #include "xla/hlo/ir/hlo_module.h"
30
+ #include "xla/hlo/parser/hlo_parser.h"
26
31
  #include "xla/service/compiler.h"
27
32
  #include "xla/service/hlo_module_config.h"
28
33
  #include "xla/shape.h"
34
+ #include "xla/util.h"
29
35
 
30
36
  namespace xla {
31
37
 
38
+ // Converts an HloModule from the given hlo textual IR string (in
39
+ // HloModule::ToString format).
40
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
41
+ absl::string_view hlo_string,
42
+ const DebugOptions& debug_options = DebugOptions::default_instance());
43
+
44
+ // Creates an HloModule from the given proto.
45
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
46
+ const HloModuleProto& proto,
47
+ const DebugOptions& debug_options = DebugOptions::default_instance());
48
+
49
+ // Create an HLO state from serialized representation. In addition to
50
+ // creating the proto with HloModule::CreateFromProto(...) it also
51
+ // uses HloVerifier to ensure basic invariants are held.
52
+ // The HLO module could be a pre-optimizations (default) or post-optimizations
53
+ // module, which affects how the HLO module is verified, e.g., mixed-precision
54
+ // is allowed in post-optimizations HLOs.
55
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
56
+ const HloModuleProto& proto, const HloModuleConfig& module_config,
57
+ bool is_module_post_optimizations = false);
58
+
59
+ // Reads the proto file in xla.HloProto format, creates and returns the
60
+ // HloModule.
61
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
62
+ absl::string_view filename,
63
+ const DebugOptions& debug_options = DebugOptions::default_instance());
64
+
65
+ // Reads the proto file in xla.HloModule format, creates and returns the
66
+ // HloModule.
67
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
68
+ absl::string_view filename, const DebugOptions& debug_options);
69
+
70
+ // Reads the HLO text dump file in HloModule::ToString format, creates and
71
+ // returns the HloModule.
72
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
73
+ absl::string_view filename,
74
+ const DebugOptions& debug_options = DebugOptions::default_instance(),
75
+ const HloParserOptions& options = HloParserOptions());
76
+
77
+ // Reads the proto file in xla.HloProto format, creates and returns the
78
+ // HloModule.
79
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
80
+ absl::string_view hlo_file,
81
+ const DebugOptions& debug_options = DebugOptions::default_instance());
82
+
32
83
  // Creates an HloModuleConfig for a given program shape and arguments.
33
84
  // If execution_options does not set num_replicas, default_num_replicas is used.
34
85
  // num_threads is optional; if not given, intra_op_parallelism_threads not set.
@@ -18,8 +18,6 @@ limitations under the License.
18
18
  #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
19
19
  #define XLA_SERVICE_HLO_PROTO_UTIL_H_
20
20
 
21
- #include <string>
22
-
23
21
  #include "absl/status/status.h"
24
22
  #include "xla/hlo/ir/hlo_module.h"
25
23
  #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
35
33
  // will not be included in the output.
36
34
  HloProto MakeHloProto(const HloModule& module);
37
35
 
38
- // Create an HLO state from serialized representation. In addition to
39
- // creating the proto with HloModule::CreateFromProto(...) it also
40
- // uses HloVerifier to ensure basic invariants are held.
41
- // The HLO module could be a pre-optimizations (default) or post-optimizations
42
- // module, which affects how the HLO module is verified, e.g., mixed-precision
43
- // is allowed in post-optimizations HLOs.
44
- absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
45
- const HloModuleProto& proto, const HloModuleConfig& module_config,
46
- bool is_module_post_optimizations = false);
47
-
48
36
  // Returns the shapes of the parameters of the entry computation. Shape pointers
49
37
  // refer to shapes inside of the given HloProto.
50
38
  absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
@@ -35,9 +35,6 @@ limitations under the License.
35
35
  #include "xla/tsl/platform/logging.h"
36
36
 
37
37
  namespace tsl {
38
-
39
- class NotifierListNode;
40
-
41
38
  namespace internal {
42
39
 
43
40
  template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
277
274
  protected:
278
275
  friend class IndirectAsyncValue;
279
276
 
277
+ struct WaiterListNode;
278
+
280
279
  static constexpr uint16_t kUnknownTypeId = 0;
281
280
 
282
281
  // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {
311
310
 
312
311
  void NotifyAvailable(State available_state);
313
312
  void Destroy();
314
- void RunWaiters(NotifierListNode* list);
313
+ void RunWaiters(WaiterListNode* list);
315
314
 
316
315
  // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
317
316
  // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
369
368
  // This is a 16-bit value that identifies the type.
370
369
  uint16_t type_id_ = 0;
371
370
 
371
+ // This is a singly linked list of nodes waiting for notification, hanging off
372
+ // of AsyncValue. When the value becomes available or if an error occurs, the
373
+ // callbacks are informed.
374
+ struct WaiterListNode {
375
+ virtual ~WaiterListNode() = default;
376
+ virtual void operator()() = 0;
377
+
378
+ WaiterListNode* next = nullptr;
379
+ };
380
+
372
381
  // The waiter list and the state are compacted into one single atomic word as
373
382
  // accesses to them are tightly related. To change the state from unavailable
374
383
  // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
379
388
  // Invariant: If the state is not available, then the waiter list must be
380
389
  // nullptr.
381
390
  struct WaitersAndState {
382
- // We rely on the fact that all `NotifierListNode` values are aligned at
391
+ // We rely on the fact that all `WaiterListNode` values are aligned at
383
392
  // least to 4 bytes and we can encode state in the lowest 2 bits. We use
384
393
  // the conservative estimation of the minimal alignment of pointers returned
385
394
  // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
390
399
  static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
391
400
  static constexpr uintptr_t kPointerMask = ~kStateMask;
392
401
 
393
- WaitersAndState(NotifierListNode* ptr, State state) {
402
+ WaitersAndState(WaiterListNode* ptr, State state) {
394
403
  value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
395
404
  (state & kStateMask);
396
405
  }
@@ -399,8 +408,8 @@ class AsyncValue {
399
408
  return State(static_cast<State::StateEnum>(value & kStateMask));
400
409
  }
401
410
 
402
- NotifierListNode* waiter() const {
403
- return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
411
+ WaiterListNode* waiter() const {
412
+ return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
404
413
  }
405
414
 
406
415
  uintptr_t value;
@@ -466,8 +475,26 @@ class AsyncValue {
466
475
  return (*type_info_table)[type_id_ - 1];
467
476
  }
468
477
 
469
- void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
470
- WaitersAndState old_value);
478
+ // Adds a waiter list node to the waiter linked list. If the value is
479
+ // available or becomes available, this calls the waiter immediately.
480
+ // Otherwise, we add waiter to the list where it will be called when the value
481
+ // becomes available.
482
+ void EnqueueWaiterListNode(WaiterListNode* waiter,
483
+ WaitersAndState waiters_and_state);
484
+
485
+ template <typename Waiter>
486
+ void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
487
+ static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
488
+
489
+ struct Node final : public WaiterListNode {
490
+ explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
491
+ void operator()() final { waiter(); }
492
+ Waiter waiter;
493
+ };
494
+
495
+ EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
496
+ waiters_and_state);
497
+ }
471
498
 
472
499
  // This is a global counter of the number of AsyncValue instances currently
473
500
  // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
983
1010
  // Clients generally want to use AndThen without them each having to check
984
1011
  // to see if the value is present. Check for them, and immediately run the
985
1012
  // waiter if it is already here.
986
- auto old_value = waiters_and_state_.load(std::memory_order_acquire);
987
- if (old_value.state() == State::kConcrete ||
988
- old_value.state() == State::kError) {
989
- DCHECK_EQ(old_value.waiter(), nullptr);
1013
+ auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
1014
+ if (waiters_and_state.state() == State::kConcrete ||
1015
+ waiters_and_state.state() == State::kError) {
1016
+ DCHECK_EQ(waiters_and_state.waiter(), nullptr);
990
1017
  waiter();
991
1018
  return;
992
1019
  }
993
- EnqueueWaiter(std::forward<Waiter>(waiter), old_value);
1020
+
1021
+ EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
994
1022
  }
995
1023
 
996
1024
  template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
998
1026
  // Clients generally want to use AndThen without them each having to check
999
1027
  // to see if the value is present. Check for them, and immediately run the
1000
1028
  // waiter if it is already here.
1001
- auto old_value = waiters_and_state_.load(std::memory_order_acquire);
1002
- if (old_value.state() == State::kConcrete ||
1003
- old_value.state() == State::kError) {
1004
- DCHECK_EQ(old_value.waiter(), nullptr);
1029
+ auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
1030
+ if (waiters_and_state.state() == State::kConcrete ||
1031
+ waiters_and_state.state() == State::kError) {
1032
+ DCHECK_EQ(waiters_and_state.waiter(), nullptr);
1005
1033
  executor.Execute(std::forward<Waiter>(waiter));
1006
1034
  return;
1007
1035
  }
1036
+
1008
1037
  EnqueueWaiter(
1009
- [&executor, waiter = std::forward<Waiter>(waiter)]() mutable {
1038
+ [&executor, waiter = std::forward<Waiter>(waiter)] {
1010
1039
  executor.Execute(std::move(waiter));
1011
1040
  },
1012
- old_value);
1041
+ waiters_and_state);
1013
1042
  }
1014
1043
 
1015
1044
  inline void AsyncValue::Destroy() {
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
1604
1604
  Index padding_left = 0, Index padding_right = 0) {
1605
1605
  typedef typename internal::traits<Input>::Index TensorIndex;
1606
1606
  typedef typename internal::traits<Input>::Scalar InputScalar;
1607
- TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1608
- internal::traits<Input>::Layout, TensorIndex> >
1607
+ TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1608
+ internal::traits<Input>::Layout, TensorIndex> >
1609
1609
  in(input);
1610
- TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1611
- internal::traits<Kernel>::NumDimensions,
1612
- internal::traits<Kernel>::Layout, TensorIndex> >
1610
+ TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
1611
+ internal::traits<Kernel>::NumDimensions,
1612
+ internal::traits<Kernel>::Layout, TensorIndex> >
1613
1613
  kern(kernel);
1614
1614
 
1615
1615
  EIGEN_STATIC_ASSERT(
Binary file
Binary file
Binary file
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
29
29
  # This value changes every day with an automatic CL. It can be modified in code
30
30
  # via `forward_compatibility_horizon()` or with the environment variable
31
31
  # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
32
- _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 19)
32
+ _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 21)
33
33
  _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
34
34
  _FORWARD_COMPATIBILITY_DATE_NUMBER = None
35
35
 
@@ -43,13 +43,13 @@ def imperative_grad(tape,
43
43
  target: either a Tensor or list of Tensors to be differentiated.
44
44
  sources: list of Tensors for which we want gradients
45
45
  output_gradients: if not None, a list of gradient provided for each Target,
46
- or None if we are to use the target's computed downstream gradient.
46
+ or None if we are to use the target's computed downstream gradient.
47
47
  sources_raw: if not None, a list of the source python objects from which the
48
- sources were generated. Should have the same length as sources. Only needs
49
- to be populated if unconnected_gradients is 'zero'.
48
+ sources were generated. Should have the same length as sources. Only needs
49
+ to be populated if unconnected_gradients is 'zero'.
50
50
  unconnected_gradients: determines the value returned if the target and
51
- sources are unconnected. When 'none' the value returned is None wheras when
52
- 'zero' a zero tensor in the same shape as the sources is returned.
51
+ sources are unconnected. When 'none' the value returned is None whereas
52
+ when 'zero' a zero tensor in the same shape as the sources is returned.
53
53
 
54
54
  Returns:
55
55
  the gradient wrt each of the sources.
@@ -55,7 +55,7 @@ class CallOptions:
55
55
  # Used by ACD to list Ops/Tensors/Callables that must be called in advance.
56
56
  control_captures: List[Any] = dataclasses.field(default_factory=list)
57
57
 
58
- # Determines what kind of partitoned call is used for this function.
58
+ # Determines what kind of partitioned call is used for this function.
59
59
  is_stateful: bool = False
60
60
 
61
61
 
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Implmentation for defining get_compiler_ir."""
15
+ """Implementation for defining get_compiler_ir."""
16
16
  from typing import List, Optional
17
17
  import warnings
18
18
 
@@ -966,7 +966,7 @@ class Function(core.PolymorphicFunction, trackable.Trackable):
966
966
 
967
967
  def _check_inputs(args, kwargs):
968
968
  all_inputs = list(args) + list(kwargs.values())
969
- # Emtpy input is okay.
969
+ # Empty input is okay.
970
970
  if not all_inputs:
971
971
  return
972
972
  if any(map(is_tensor_spec, all_inputs)) and any(
@@ -1423,7 +1423,8 @@ def function(
1423
1423
  thought of as compile-time constants), and builds a separate `tf.Graph` for
1424
1424
  each set of Python arguments that it encounters.
1425
1425
  For more information, see the
1426
- [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1426
+ [tf.function
1427
+ guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1427
1428
 
1428
1429
  Executing a `PolymorphicFunction` will select and execute the appropriate
1429
1430
  `ConcreteFunction` based on the argument types and values.
@@ -1440,14 +1441,17 @@ def function(
1440
1441
  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1441
1442
  True
1442
1443
 
1443
- `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but their
1444
- input is resticted to the types to which they're specialized.
1444
+ `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
1445
+ their
1446
+ input is restricted to the types to which they're specialized.
1445
1447
 
1446
1448
  ## Retracing
1447
1449
 
1448
- `ConcreteFunctions` are built (traced) on the fly, as the `PolymorphicFunction` is
1450
+ `ConcreteFunctions` are built (traced) on the fly, as the
1451
+ `PolymorphicFunction` is
1449
1452
  called with new TensorFlow types or shapes, or with new Python values as
1450
- arguments. When `PolymorphicFunction` builds a new trace, it is said that `func`
1453
+ arguments. When `PolymorphicFunction` builds a new trace, it is said that
1454
+ `func`
1451
1455
  is retraced. Retracing is a frequent performance concern for `tf.function` as
1452
1456
  it can be considerably slower than executing a graph that's already been
1453
1457
  traced. It is ideal to minimize the amount of retracing in your code.
@@ -1473,7 +1477,8 @@ def function(
1473
1477
 
1474
1478
  ## Input signatures
1475
1479
 
1476
- For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction` for
1480
+ For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
1481
+ for
1477
1482
  every unique set of input shapes and datatypes. The example below creates two
1478
1483
  separate `ConcreteFunction`s, each specialized to a different shape:
1479
1484
 
@@ -1580,59 +1585,58 @@ def function(
1580
1585
  `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1581
1586
  autograph: Whether autograph should be applied on `func` before tracing a
1582
1587
  graph. Data-dependent Python control flow statements require
1583
- `autograph=True`. For more information, see the
1584
- [tf.function and AutoGraph guide](
1588
+ `autograph=True`. For more information, see the [tf.function and AutoGraph
1589
+ guide](
1585
1590
  https://www.tensorflow.org/guide/function#autograph_transformations).
1586
1591
  jit_compile: If `True`, compiles the function using
1587
1592
  [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
1588
1593
  such as fusion, and attempts to emit more efficient code. This may
1589
- drastically improve the performance. If set to `True`,
1590
- the whole function needs to be compilable by XLA, or an
1591
- `errors.InvalidArgumentError` is thrown.
1592
- If `None` (default), compiles the function with XLA when running on TPU
1593
- and goes through the regular function execution path when running on
1594
- other devices.
1595
- If `False`, executes the function without XLA compilation. Set this value
1596
- to `False` when directly running a multi-device function on TPUs (e.g. two
1597
- TPU cores, one TPU core and its host CPU).
1598
- Not all functions are compilable, see a list of
1599
- [sharp corners](https://tensorflow.org/xla/known_issues).
1600
- reduce_retracing: When True, `tf.function` attempts to reduce the
1601
- amount of retracing, for example by using more generic shapes. This
1602
- can be controlled for user objects by customizing their associated
1594
+ drastically improve the performance. If set to `True`, the whole function
1595
+ needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
1596
+ thrown. If `None` (default), compiles the function with XLA when running
1597
+ on TPU and goes through the regular function execution path when running
1598
+ on other devices. If `False`, executes the function without XLA
1599
+ compilation. Set this value to `False` when directly running a
1600
+ multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
1601
+ host CPU). Not all functions are compilable, see a list of [sharp
1602
+ corners](https://tensorflow.org/xla/known_issues).
1603
+ reduce_retracing: When True, `tf.function` attempts to reduce the amount of
1604
+ retracing, for example by using more generic shapes. This can be
1605
+ controlled for user objects by customizing their associated
1603
1606
  `tf.types.experimental.TraceType`.
1604
1607
  experimental_implements: If provided, contains a name of a "known" function
1605
- this implements. For example "mycompany.my_recurrent_cell".
1606
- This is stored as an attribute in inference function,
1607
- which can then be detected when processing serialized function.
1608
- See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md) # pylint: disable=line-too-long
1609
- for details. For an example of utilizing this attribute see this
1608
+ this implements. For example "mycompany.my_recurrent_cell". This is stored
1609
+ as an attribute in inference function, which can then be detected when
1610
+ processing serialized function. See [standardizing composite
1611
+ ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)
1612
+ # pylint: disable=line-too-long for details. For an example of utilizing
1613
+ this attribute see this
1610
1614
  [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
1611
1615
  The code above automatically detects and substitutes function that
1612
1616
  implements "embedded_matmul" and allows TFLite to substitute its own
1613
- implementations. For instance, a tensorflow user can use this
1614
- attribute to mark that their function also implements
1615
- `embedded_matmul` (perhaps more efficiently!)
1616
- by specifying it using this parameter:
1617
- `@tf.function(experimental_implements="embedded_matmul")`
1618
- This can either be specified as just the string name of the function or
1619
- a NameAttrList corresponding to a list of key-value attributes associated
1620
- with the function name. The name of the function will be in the 'name'
1621
- field of the NameAttrList. To define a formal TF op for this function
1622
- implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1617
+ implementations. For instance, a tensorflow user can use this attribute to
1618
+ mark that their function also implements `embedded_matmul` (perhaps more
1619
+ efficiently!) by specifying it using this parameter:
1620
+ `@tf.function(experimental_implements="embedded_matmul")` This can either
1621
+ be specified as just the string name of the function or a NameAttrList
1622
+ corresponding to a list of key-value attributes associated with the
1623
+ function name. The name of the function will be in the 'name' field of the
1624
+ NameAttrList. To define a formal TF op for this function implements, try
1625
+ the experimental [composite
1626
+ TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1623
1627
  project.
1624
1628
  experimental_autograph_options: Optional tuple of
1625
1629
  `tf.autograph.experimental.Feature` values.
1626
1630
  experimental_attributes: Optional dictionary of attributes to include in the
1627
1631
  generated FunctionDefs.
1628
- experimental_relax_shapes: Deprecated. Use `reduce_retracing`
1629
- instead.
1632
+ experimental_relax_shapes: Deprecated. Use `reduce_retracing` instead.
1630
1633
  experimental_compile: Deprecated alias to 'jit_compile'.
1631
1634
  experimental_follow_type_hints: Deprecated. Please use input_signature or
1632
1635
  reduce_retracing instead.
1633
1636
 
1634
1637
  Returns:
1635
- If `func` is not None, returns a `tf.types.experimental.PolymorphicFunction`.
1638
+ If `func` is not None, returns a
1639
+ `tf.types.experimental.PolymorphicFunction`.
1636
1640
  If `func` is None, returns a decorator that, when invoked with a single
1637
1641
  `func` argument, returns a `tf.types.experimental.PolymorphicFunction`.
1638
1642
 
@@ -48,9 +48,9 @@ def watch(tape, tensor):
48
48
  def default_get_variables(variable):
49
49
  return [variable]
50
50
 
51
- # Gets a list of changed variables. Can be overriden using
51
+ # Gets a list of changed variables. Can be overridden using
52
52
  # register_variables_override. An example of overriding is for getting the
53
- # varibles within a distributed context.
53
+ # variables within a distributed context.
54
54
  _variables_override = default_get_variables
55
55
 
56
56
 
Binary file
@@ -151,7 +151,11 @@ def _legacy_contrib_should_record_summaries():
151
151
 
152
152
  def is_recording_summaries():
153
153
  """Returns non-Tensor boolean indicating if summaries are being recorded."""
154
- return _summary_state.is_recording is not None and _summary_state.is_recording
154
+ if _summary_state.writer is None:
155
+ return False
156
+ if _summary_state.is_recording is None:
157
+ return False
158
+ return _summary_state.is_recording
155
159
 
156
160
 
157
161
  @tf_export("summary.record_if", v1=[])
Binary file
@@ -67,6 +67,7 @@ from tensorflow.python.util.tf_export import tf_export
67
67
  _PIPELINE_ATTRIBUTE = "_embedding_pipelining"
68
68
  _PIPELINE_MODE_FORWARD = "forward"
69
69
  _PIPELINE_MODE_BACKWARD = "backward"
70
+ _PIPELINE_MODEL_SEQUENTIAL = "_sequential"
70
71
 
71
72
 
72
73
  TableConfig = tpu_embedding_v2_utils.TableConfig
@@ -95,15 +96,21 @@ class EmbeddingPipeliningContext(control_flow_ops.ControlFlowContext):
95
96
  super().__init__()
96
97
  self._name = "EmbeddingPipelinigContext"
97
98
  self._mode = attr_value_pb2.AttrValue(s=compat.as_bytes(mode))
99
+ self._enable = enable
98
100
  recording_summaries = summary_ops_v2.is_recording_summaries()
101
+ if not isinstance(recording_summaries, bool):
102
+ # We can't handle predicate functions at this point. So, we'll ignore the
103
+ # special casing of summary recording because, presumably, this is not
104
+ # a single step loop so pipelining is still valid.
105
+ recording_summaries = False
99
106
  if enable and recording_summaries:
100
- logging.info(
101
- "Embedding pipelining requested but summaries are being recorded:"
102
- " Disabling embedding pipelining."
107
+ # We'll still flag these ops for the SC forward/backward pass, but we'll
108
+ # run them sequentially. This has to be handled in the MLIR passes
109
+ # embedding_pipelining.cc and embedding_sequencing.cc.
110
+ logging.info("Summary recording detected, disabling pipelining.")
111
+ self._mode = attr_value_pb2.AttrValue(
112
+ s=compat.as_bytes(mode + _PIPELINE_MODEL_SEQUENTIAL)
103
113
  )
104
- self._enable = False
105
- else:
106
- self._enable = enable
107
114
 
108
115
  def to_control_flow_context_def(
109
116
  self, context_def: Any, export_scope: Any = None
@@ -1637,7 +1644,7 @@ class TPUEmbeddingV2(tpu_embedding_base.TPUEmbeddingBase):
1637
1644
  row_offset: int,
1638
1645
  col_offset: int,
1639
1646
  col_shift: int,
1640
- vocab_size: int,
1647
+ unused_vocab_size: int,
1641
1648
  num_sc_per_chip: int,
1642
1649
  num_sc_shards: int,
1643
1650
  stacked_table_sample_count: int,