tf-nightly-cpu 2.20.0.dev20250221-cp310-cp310-win_amd64.whl → 2.20.0.dev20250222-cp310-cp310-win_amd64.whl

Files changed (69)
  1. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  2. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  3. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  4. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  5. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  6. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  7. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  8. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  9. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +0 -1
  10. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  11. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  12. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  13. tensorflow/include/tensorflow/core/public/release_version.h +1 -1
  14. tensorflow/include/tensorflow/core/public/version.h +1 -1
  15. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  16. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +0 -1
  17. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  18. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  19. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  20. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  21. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  22. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  23. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  24. tensorflow/python/_pywrap_mlir.pyd +0 -0
  25. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  26. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  27. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  28. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  29. tensorflow/python/_pywrap_tfe.pyd +0 -0
  30. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  31. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  32. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  33. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  34. tensorflow/python/compat/compat.py +1 -1
  35. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  36. tensorflow/python/framework/_dtypes.pyd +0 -0
  37. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  38. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  39. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  40. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  41. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  42. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  43. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  44. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  45. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  46. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  47. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  48. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  49. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  50. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  51. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  52. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  53. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  54. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  55. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  56. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  57. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  58. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  59. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  60. tensorflow/python/util/_tf_stack.pyd +0 -0
  61. tensorflow/tools/pip_package/setup.py +1 -1
  62. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  63. {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  64. {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +67 -69
  65. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  66. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  67. {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  68. {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  69. {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
Binary files (entries 1-3; no textual diff).

tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc
@@ -5711,6 +5711,7 @@ public:
   static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
   ::llvm::LogicalResult verifyInvariantsImpl();
   ::llvm::LogicalResult verifyInvariants();
+  ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
   static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
   void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -5925,6 +5926,7 @@ public:
   static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
   ::llvm::LogicalResult verifyInvariantsImpl();
   ::llvm::LogicalResult verifyInvariants();
+  ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
   static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
   void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -6139,6 +6141,7 @@ public:
   static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
   ::llvm::LogicalResult verifyInvariantsImpl();
   ::llvm::LogicalResult verifyInvariants();
+  ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
   static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
   void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -6353,6 +6356,7 @@ public:
   static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
   ::llvm::LogicalResult verifyInvariantsImpl();
   ::llvm::LogicalResult verifyInvariants();
+  ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
   static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
   void print(::mlir::OpAsmPrinter &_odsPrinter);
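
Note: all four hunks add a fold hook to generated math-dialect op declarations, which lets MLIR's folder constant-fold these ops. As a hedged illustration only (ExampleAbsOp and its operand accessor are hypothetical; the real fold bodies live in MLIR's MathOps.cpp, not in this generated header), a unary fold hook typically looks like:

    // Sketch: fold |x| when the operand is a compile-time float constant.
    ::mlir::OpFoldResult ExampleAbsOp::fold(FoldAdaptor adaptor) {
      auto operand =
          llvm::dyn_cast_if_present<::mlir::FloatAttr>(adaptor.getOperand());
      if (!operand) return {};  // not a constant: decline to fold
      llvm::APFloat value = operand.getValue();
      value.clearSign();  // apply the op's semantics to the constant
      return ::mlir::FloatAttr::get(operand.getType(), value);
    }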

tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h
@@ -81,6 +81,15 @@ class AggressiveFactorPropagation : public BasicFactorPropagation {
       PropagationDirectionAlongFactor directionAlongFactor,
       ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
       bool conservativePropagation) const override;
+
+ private:
+  // Returns the axes to propagate to an individual factor in the given
+  // `tensorFactorShardings` of a tensor.
+  SmallVector<AxisRefAttr> getPropagatedFactorSharding(
+      int64_t factorIndex, const TensorFactorShardings& tensorFactorShardings,
+      const FactorIndexToSharding& factorIndexToSharding,
+      AxesPerFactorRef axesPerFactor, MeshAttr mesh,
+      bool conservativePropagation, ArrayRef<int64_t> factorSizes) const;
 };
 
 }  // namespace sdy

tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
   static FailureOr<Version> fromString(llvm::StringRef versionRef);
 
   /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(1, 9, 2); }
+  static Version getCurrentVersion() { return Version(1, 9, 3); }
 
   /// Return a Version representing the minimum supported VHLO dialect version.
   static Version getMinimumVersion() { return Version(0, 9, 0); }

tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
   static FailureOr<Version> fromString(llvm::StringRef versionRef);
 
   /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(1, 9, 2); }
+  static Version getCurrentVersion() { return Version(1, 9, 3); }
 
   /// Return a Version representing the minimum supported VHLO dialect version.
   static Version getMinimumVersion() { return Version(0, 9, 0); }
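
Note: both copies of this header bump the current VHLO dialect version from 1.9.2 to 1.9.3; the minimum supported version is unchanged. A hedged sketch of how a consumer might gate deserialization on these accessors (the function name is illustrative, and Version's comparison operators are assumed):

    #include "stablehlo/dialect/Version.h"

    // Accept only payload versions in [minimum, current].
    bool IsSupportedVhloVersion(llvm::StringRef version_str) {
      auto version = mlir::vhlo::Version::fromString(version_str);
      if (mlir::failed(version)) return false;
      return !(*version < mlir::vhlo::Version::getMinimumVersion()) &&
             !(mlir::vhlo::Version::getCurrentVersion() < *version);
    }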

tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h
@@ -22,7 +22,7 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
-#include "xla/backends/cpu/runtime/concurrency.h"
+#include "xla/backends/cpu/runtime/work_queue.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/chain.h"
 #include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h"  // IWYU pragma: keep
@@ -30,7 +30,6 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 #include "Eigen/Core"
-#include "Eigen/ThreadPool"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace xla::cpu::internal {
@@ -384,8 +383,9 @@ void EigenGenericConv2D(
   auto num_tasks = Eigen::numext::div_ceil(feature_group_count, task_size);
 
   if (use_thunk_runtime) {
-    ScheduleAll(
-        &device, num_tasks, [=, &device](Eigen::Index task_index) mutable {
+    Worker::Parallelize(
+        &device, /*num_workers=*/num_tasks, num_tasks,
+        [=, &device](Eigen::Index task_index) mutable {
           Eigen::Index start = task_index * task_size;
           Eigen::Index end = std::min(start + task_size, feature_group_count);
           for (Eigen::Index i = start; i < end; ++i) {
@@ -395,18 +395,16 @@ void EigenGenericConv2D(
           }
         });
   } else {
-    Eigen::Barrier barrier(num_tasks);
-    ScheduleAll(
-        &device, num_tasks, [=, &device, &barrier](Eigen::Index task_index) {
+    tsl::BlockUntilReady(Worker::Parallelize(
+        &device, /*num_workers=*/num_tasks, num_tasks,
+        [=, &device](Eigen::Index task_index) {
           Eigen::Index start = task_index * task_size;
           Eigen::Index end = std::min(start + task_size, feature_group_count);
          for (Eigen::Index i = start; i < end; ++i) {
            auto [output, convolved] = convolve_group(i);
            output.device(device) = convolved;
          }
-          barrier.Notify();
-        });
-    barrier.Wait();
+        }));
   }
 
 } else {
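
Note: the scheduling change above swaps a hand-counted Eigen::Barrier for a handle returned by Worker::Parallelize that the caller blocks on via tsl::BlockUntilReady. A standalone sketch of the same pattern using only the C++ standard library (ParallelizeSketch is illustrative, not the XLA Worker API):

    #include <functional>
    #include <future>
    #include <thread>
    #include <vector>

    // Run task(0..num_tasks-1) on worker threads and return a handle the
    // caller can block on, instead of threading a barrier through each task.
    std::future<void> ParallelizeSketch(int num_tasks,
                                        std::function<void(int)> task) {
      return std::async(std::launch::async, [num_tasks, task] {
        std::vector<std::thread> workers;
        workers.reserve(num_tasks);
        for (int i = 0; i < num_tasks; ++i) workers.emplace_back(task, i);
        for (auto& w : workers) w.join();  // all tasks done when this returns
      });
    }

    // Caller side, analogous to tsl::BlockUntilReady(Worker::Parallelize(...)):
    //   ParallelizeSketch(num_tasks, run_group).wait();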

tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h
@@ -29,7 +29,6 @@ limitations under the License.
 #include "absl/base/attributes.h"
 #include "absl/base/optimization.h"
 #include "absl/container/fixed_array.h"
-#include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/chain.h"

tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h
@@ -1914,6 +1914,18 @@ class HloInstruction {
            result_accuracy().mode() != ResultAccuracy::DEFAULT);
   }
 
+  bool equal_result_accuracy(const HloInstruction* other) const {
+    return result_accuracy().has_tolerance() ==
+               other->result_accuracy().has_tolerance() &&
+           result_accuracy().tolerance().atol() ==
+               other->result_accuracy().tolerance().atol() &&
+           result_accuracy().tolerance().rtol() ==
+               other->result_accuracy().tolerance().rtol() &&
+           result_accuracy().tolerance().ulps() ==
+               other->result_accuracy().tolerance().ulps() &&
+           result_accuracy().mode() == other->result_accuracy().mode();
+  }
+
   void add_single_statistic(Statistic statistic) {
     *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
   }

tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h
@@ -35,9 +35,6 @@ limitations under the License.
 #include "xla/tsl/platform/logging.h"
 
 namespace tsl {
-
-class NotifierListNode;
-
 namespace internal {
 
 template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
  protected:
   friend class IndirectAsyncValue;
 
+  struct WaiterListNode;
+
   static constexpr uint16_t kUnknownTypeId = 0;
 
   // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {
 
   void NotifyAvailable(State available_state);
   void Destroy();
-  void RunWaiters(NotifierListNode* list);
+  void RunWaiters(WaiterListNode* list);
 
   // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
   // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
   // This is a 16-bit value that identifies the type.
   uint16_t type_id_ = 0;
 
+  // This is a singly linked list of nodes waiting for notification, hanging off
+  // of AsyncValue. When the value becomes available or if an error occurs, the
+  // callbacks are informed.
+  struct WaiterListNode {
+    virtual ~WaiterListNode() = default;
+    virtual void operator()() = 0;
+
+    WaiterListNode* next = nullptr;
+  };
+
   // The waiter list and the state are compacted into one single atomic word as
   // accesses to them are tightly related. To change the state from unavailable
   // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
   // Invariant: If the state is not available, then the waiter list must be
   // nullptr.
   struct WaitersAndState {
-    // We rely on the fact that all `NotifierListNode` values are aligned at
+    // We rely on the fact that all `WaiterListNode` values are aligned at
     // least to 4 bytes and we can encode state in the lowest 2 bits. We use
     // the conservative estimation of the minimal alignment of pointers returned
     // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
     static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
     static constexpr uintptr_t kPointerMask = ~kStateMask;
 
-    WaitersAndState(NotifierListNode* ptr, State state) {
+    WaitersAndState(WaiterListNode* ptr, State state) {
       value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
               (state & kStateMask);
     }
@@ -399,8 +408,8 @@ class AsyncValue {
       return State(static_cast<State::StateEnum>(value & kStateMask));
     }
 
-    NotifierListNode* waiter() const {
-      return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
+    WaiterListNode* waiter() const {
+      return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
     }
 
     uintptr_t value;
@@ -466,8 +475,26 @@ class AsyncValue {
     return (*type_info_table)[type_id_ - 1];
   }
 
-  void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
-                     WaitersAndState old_value);
+  // Adds a waiter list node to the waiter linked list. If the value is
+  // available or becomes available, this calls the waiter immediately.
+  // Otherwise, we add waiter to the list where it will be called when the value
+  // becomes available.
+  void EnqueueWaiterListNode(WaiterListNode* waiter,
+                             WaitersAndState waiters_and_state);
+
+  template <typename Waiter>
+  void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
+    static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
+
+    struct Node final : public WaiterListNode {
+      explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
+      void operator()() final { waiter(); }
+      Waiter waiter;
+    };
+
+    EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
+                          waiters_and_state);
+  }
 
   // This is a global counter of the number of AsyncValue instances currently
   // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     waiter();
     return;
   }
-  EnqueueWaiter(std::forward<Waiter>(waiter), old_value);
+
+  EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
 }
 
 template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     executor.Execute(std::forward<Waiter>(waiter));
     return;
   }
+
   EnqueueWaiter(
-      [&executor, waiter = std::forward<Waiter>(waiter)]() mutable {
+      [&executor, waiter = std::forward<Waiter>(waiter)] {
        executor.Execute(std::move(waiter));
      },
-      old_value);
+      waiters_and_state);
 }
 
 inline void AsyncValue::Destroy() {
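
Note: the net effect of this file's changes is to replace the forward-declared NotifierListNode plus a type-erased absl::AnyInvocable waiter with an intrusive WaiterListNode whose derived node stores the concrete callable, saving one allocation and one indirection per waiter. A simplified standalone sketch of the technique (single-threaded; the real class packs the list head and state into one atomic word):

    #include <utility>

    struct WaiterListNode {
      virtual ~WaiterListNode() = default;
      virtual void operator()() = 0;
      WaiterListNode* next = nullptr;
    };

    // The concrete callable type is preserved in the node itself, so no
    // extra type-erasing wrapper needs to be heap-allocated.
    template <typename Waiter>
    WaiterListNode* MakeWaiterNode(Waiter&& waiter) {
      struct Node final : WaiterListNode {
        explicit Node(Waiter w) : waiter(std::move(w)) {}
        void operator()() final { waiter(); }
        Waiter waiter;
      };
      return new Node{std::forward<Waiter>(waiter)};
    }

    // When the value becomes available, walk the list, invoke, and free.
    void RunWaiters(WaiterListNode* list) {
      while (list != nullptr) {
        WaiterListNode* node = list;
        list = list->next;
        (*node)();
        delete node;
      }
    }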

tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h
@@ -35,8 +35,6 @@ class GetElementAtIndexOp : public AsyncOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
   }
 
-  ~GetElementAtIndexOp() override {}
-
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
       ctx->SetStatus(DoCompute(ctx));

tensorflow/include/tensorflow/core/public/release_version.h
@@ -26,7 +26,7 @@ limitations under the License.
 
 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
 // "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX "-dev20250221"
+#define TF_VERSION_SUFFIX "-dev20250222"
 
 #define _TF_STR_HELPER(x) #x
 #define _TF_STR(x) _TF_STR_HELPER(x)
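
Note: TF_VERSION_SUFFIX feeds the full wheel version string through the stringize helpers visible in the hunk. A sketch of the composition (major/minor/patch values are illustrative for this nightly; the macro layout follows release_version.h):

    #define TF_MAJOR_VERSION 2
    #define TF_MINOR_VERSION 20
    #define TF_PATCH_VERSION 0
    #define TF_VERSION_SUFFIX "-dev20250222"

    #define _TF_STR_HELPER(x) #x
    #define _TF_STR(x) _TF_STR_HELPER(x)

    // Expands to "2.20.0-dev20250222", the version reported by the wheel.
    #define TF_VERSION_STRING                                         \
      (_TF_STR(TF_MAJOR_VERSION) "." _TF_STR(TF_MINOR_VERSION) "."    \
       _TF_STR(TF_PATCH_VERSION) TF_VERSION_SUFFIX)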

tensorflow/include/tensorflow/core/public/version.h
@@ -93,7 +93,7 @@ limitations under the License.
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 2144  // Updated: 2025/2/20
+#define TF_GRAPH_DEF_VERSION 2145  // Updated: 2025/2/21
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //

tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h (same change as entry 8)
@@ -22,7 +22,7 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
-#include "xla/backends/cpu/runtime/concurrency.h"
+#include "xla/backends/cpu/runtime/work_queue.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/chain.h"
 #include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h"  // IWYU pragma: keep
@@ -30,7 +30,6 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 #include "Eigen/Core"
-#include "Eigen/ThreadPool"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace xla::cpu::internal {
@@ -384,8 +383,9 @@ void EigenGenericConv2D(
   auto num_tasks = Eigen::numext::div_ceil(feature_group_count, task_size);
 
   if (use_thunk_runtime) {
-    ScheduleAll(
-        &device, num_tasks, [=, &device](Eigen::Index task_index) mutable {
+    Worker::Parallelize(
+        &device, /*num_workers=*/num_tasks, num_tasks,
+        [=, &device](Eigen::Index task_index) mutable {
           Eigen::Index start = task_index * task_size;
           Eigen::Index end = std::min(start + task_size, feature_group_count);
           for (Eigen::Index i = start; i < end; ++i) {
@@ -395,18 +395,16 @@ void EigenGenericConv2D(
           }
         });
   } else {
-    Eigen::Barrier barrier(num_tasks);
-    ScheduleAll(
-        &device, num_tasks, [=, &device, &barrier](Eigen::Index task_index) {
+    tsl::BlockUntilReady(Worker::Parallelize(
+        &device, /*num_workers=*/num_tasks, num_tasks,
+        [=, &device](Eigen::Index task_index) {
          Eigen::Index start = task_index * task_size;
          Eigen::Index end = std::min(start + task_size, feature_group_count);
          for (Eigen::Index i = start; i < end; ++i) {
            auto [output, convolved] = convolve_group(i);
            output.device(device) = convolved;
          }
-          barrier.Notify();
-        });
-    barrier.Wait();
+        }));
   }
 
 } else {

tensorflow/include/xla/backends/cpu/runtime/work_queue.h (same change as entry 9)
@@ -29,7 +29,6 @@ limitations under the License.
 #include "absl/base/attributes.h"
 #include "absl/base/optimization.h"
 #include "absl/container/fixed_array.h"
-#include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/chain.h"

tensorflow/include/xla/hlo/ir/hlo_instruction.h (same change as entry 10)
@@ -1914,6 +1914,18 @@ class HloInstruction {
            result_accuracy().mode() != ResultAccuracy::DEFAULT);
   }
 
+  bool equal_result_accuracy(const HloInstruction* other) const {
+    return result_accuracy().has_tolerance() ==
+               other->result_accuracy().has_tolerance() &&
+           result_accuracy().tolerance().atol() ==
+               other->result_accuracy().tolerance().atol() &&
+           result_accuracy().tolerance().rtol() ==
+               other->result_accuracy().tolerance().rtol() &&
+           result_accuracy().tolerance().ulps() ==
+               other->result_accuracy().tolerance().ulps() &&
+           result_accuracy().mode() == other->result_accuracy().mode();
+  }
+
   void add_single_statistic(Statistic statistic) {
     *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
   }

tensorflow/include/xla/tsl/concurrency/async_value.h (same change as entry 11)
@@ -35,9 +35,6 @@ limitations under the License.
 #include "xla/tsl/platform/logging.h"
 
 namespace tsl {
-
-class NotifierListNode;
-
 namespace internal {
 
 template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
  protected:
   friend class IndirectAsyncValue;
 
+  struct WaiterListNode;
+
   static constexpr uint16_t kUnknownTypeId = 0;
 
   // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {
 
   void NotifyAvailable(State available_state);
   void Destroy();
-  void RunWaiters(NotifierListNode* list);
+  void RunWaiters(WaiterListNode* list);
 
   // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
   // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
   // This is a 16-bit value that identifies the type.
   uint16_t type_id_ = 0;
 
+  // This is a singly linked list of nodes waiting for notification, hanging off
+  // of AsyncValue. When the value becomes available or if an error occurs, the
+  // callbacks are informed.
+  struct WaiterListNode {
+    virtual ~WaiterListNode() = default;
+    virtual void operator()() = 0;
+
+    WaiterListNode* next = nullptr;
+  };
+
   // The waiter list and the state are compacted into one single atomic word as
   // accesses to them are tightly related. To change the state from unavailable
   // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
   // Invariant: If the state is not available, then the waiter list must be
   // nullptr.
   struct WaitersAndState {
-    // We rely on the fact that all `NotifierListNode` values are aligned at
+    // We rely on the fact that all `WaiterListNode` values are aligned at
     // least to 4 bytes and we can encode state in the lowest 2 bits. We use
     // the conservative estimation of the minimal alignment of pointers returned
    // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
     static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
     static constexpr uintptr_t kPointerMask = ~kStateMask;
 
-    WaitersAndState(NotifierListNode* ptr, State state) {
+    WaitersAndState(WaiterListNode* ptr, State state) {
       value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
               (state & kStateMask);
     }
@@ -399,8 +408,8 @@ class AsyncValue {
      return State(static_cast<State::StateEnum>(value & kStateMask));
    }
 
-    NotifierListNode* waiter() const {
-      return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
+    WaiterListNode* waiter() const {
+      return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
     }
 
     uintptr_t value;
@@ -466,8 +475,26 @@ class AsyncValue {
     return (*type_info_table)[type_id_ - 1];
   }
 
-  void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
-                     WaitersAndState old_value);
+  // Adds a waiter list node to the waiter linked list. If the value is
+  // available or becomes available, this calls the waiter immediately.
+  // Otherwise, we add waiter to the list where it will be called when the value
+  // becomes available.
+  void EnqueueWaiterListNode(WaiterListNode* waiter,
+                             WaitersAndState waiters_and_state);
+
+  template <typename Waiter>
+  void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
+    static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
+
+    struct Node final : public WaiterListNode {
+      explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
+      void operator()() final { waiter(); }
+      Waiter waiter;
+    };
+
+    EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
+                          waiters_and_state);
+  }
 
   // This is a global counter of the number of AsyncValue instances currently
   // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     waiter();
     return;
   }
-  EnqueueWaiter(std::forward<Waiter>(waiter), old_value);
+
+  EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
 }
 
 template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     executor.Execute(std::forward<Waiter>(waiter));
     return;
   }
+
   EnqueueWaiter(
-      [&executor, waiter = std::forward<Waiter>(waiter)]() mutable {
+      [&executor, waiter = std::forward<Waiter>(waiter)] {
        executor.Execute(std::move(waiter));
      },
-      old_value);
+      waiters_and_state);
 }
 
 inline void AsyncValue::Destroy() {
Binary files (no textual diff).

tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 20)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 21)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
Binary files (no textual diff).