tf-nightly-cpu 2.20.0.dev20250221__cp311-cp311-win_amd64.whl → 2.20.0.dev20250222__cp311-cp311-win_amd64.whl
Sign up to get free protection for your applications and to get access to all the features.
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
- tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
- tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +0 -1
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
- tensorflow/include/tensorflow/core/public/release_version.h +1 -1
- tensorflow/include/tensorflow/core/public/version.h +1 -1
- tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +0 -1
- tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +1 -1
- tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
- {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +67 -69
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
- tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
- {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250221.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
Binary file
|
Binary file
|
Binary file
|
@@ -5711,6 +5711,7 @@ public:
|
|
5711
5711
|
static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
|
5712
5712
|
::llvm::LogicalResult verifyInvariantsImpl();
|
5713
5713
|
::llvm::LogicalResult verifyInvariants();
|
5714
|
+
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
|
5714
5715
|
static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
|
5715
5716
|
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
|
5716
5717
|
void print(::mlir::OpAsmPrinter &_odsPrinter);
|
@@ -5925,6 +5926,7 @@ public:
|
|
5925
5926
|
static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
|
5926
5927
|
::llvm::LogicalResult verifyInvariantsImpl();
|
5927
5928
|
::llvm::LogicalResult verifyInvariants();
|
5929
|
+
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
|
5928
5930
|
static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
|
5929
5931
|
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
|
5930
5932
|
void print(::mlir::OpAsmPrinter &_odsPrinter);
|
@@ -6139,6 +6141,7 @@ public:
|
|
6139
6141
|
static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
|
6140
6142
|
::llvm::LogicalResult verifyInvariantsImpl();
|
6141
6143
|
::llvm::LogicalResult verifyInvariants();
|
6144
|
+
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
|
6142
6145
|
static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
|
6143
6146
|
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
|
6144
6147
|
void print(::mlir::OpAsmPrinter &_odsPrinter);
|
@@ -6353,6 +6356,7 @@ public:
|
|
6353
6356
|
static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
|
6354
6357
|
::llvm::LogicalResult verifyInvariantsImpl();
|
6355
6358
|
::llvm::LogicalResult verifyInvariants();
|
6359
|
+
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
|
6356
6360
|
static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
|
6357
6361
|
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
|
6358
6362
|
void print(::mlir::OpAsmPrinter &_odsPrinter);
|
@@ -81,6 +81,15 @@ class AggressiveFactorPropagation : public BasicFactorPropagation {
|
|
81
81
|
PropagationDirectionAlongFactor directionAlongFactor,
|
82
82
|
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
|
83
83
|
bool conservativePropagation) const override;
|
84
|
+
|
85
|
+
private:
|
86
|
+
// Returns the axes to propagate to an individual factor in the given
|
87
|
+
// `tensorFactorShardings` of a tensor.
|
88
|
+
SmallVector<AxisRefAttr> getPropagatedFactorSharding(
|
89
|
+
int64_t factorIndex, const TensorFactorShardings& tensorFactorShardings,
|
90
|
+
const FactorIndexToSharding& factorIndexToSharding,
|
91
|
+
AxesPerFactorRef axesPerFactor, MeshAttr mesh,
|
92
|
+
bool conservativePropagation, ArrayRef<int64_t> factorSizes) const;
|
84
93
|
};
|
85
94
|
|
86
95
|
} // namespace sdy
|
@@ -38,7 +38,7 @@ class Version {
|
|
38
38
|
static FailureOr<Version> fromString(llvm::StringRef versionRef);
|
39
39
|
|
40
40
|
/// Return a Version representing the current VHLO dialect version.
|
41
|
-
static Version getCurrentVersion() { return Version(1, 9,
|
41
|
+
static Version getCurrentVersion() { return Version(1, 9, 3); }
|
42
42
|
|
43
43
|
/// Return a Version representing the minimum supported VHLO dialect version.
|
44
44
|
static Version getMinimumVersion() { return Version(0, 9, 0); }
|
@@ -38,7 +38,7 @@ class Version {
|
|
38
38
|
static FailureOr<Version> fromString(llvm::StringRef versionRef);
|
39
39
|
|
40
40
|
/// Return a Version representing the current VHLO dialect version.
|
41
|
-
static Version getCurrentVersion() { return Version(1, 9,
|
41
|
+
static Version getCurrentVersion() { return Version(1, 9, 3); }
|
42
42
|
|
43
43
|
/// Return a Version representing the minimum supported VHLO dialect version.
|
44
44
|
static Version getMinimumVersion() { return Version(0, 9, 0); }
|
tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h
CHANGED
@@ -22,7 +22,7 @@ limitations under the License.
|
|
22
22
|
#include <memory>
|
23
23
|
#include <utility>
|
24
24
|
|
25
|
-
#include "xla/backends/cpu/runtime/
|
25
|
+
#include "xla/backends/cpu/runtime/work_queue.h"
|
26
26
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
27
27
|
#include "xla/tsl/concurrency/chain.h"
|
28
28
|
#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep
|
@@ -30,7 +30,6 @@ limitations under the License.
|
|
30
30
|
|
31
31
|
#define EIGEN_USE_THREADS
|
32
32
|
#include "Eigen/Core"
|
33
|
-
#include "Eigen/ThreadPool"
|
34
33
|
#include "unsupported/Eigen/CXX11/Tensor"
|
35
34
|
|
36
35
|
namespace xla::cpu::internal {
|
@@ -384,8 +383,9 @@ void EigenGenericConv2D(
|
|
384
383
|
auto num_tasks = Eigen::numext::div_ceil(feature_group_count, task_size);
|
385
384
|
|
386
385
|
if (use_thunk_runtime) {
|
387
|
-
|
388
|
-
&device, num_tasks,
|
386
|
+
Worker::Parallelize(
|
387
|
+
&device, /*num_workers=*/num_tasks, num_tasks,
|
388
|
+
[=, &device](Eigen::Index task_index) mutable {
|
389
389
|
Eigen::Index start = task_index * task_size;
|
390
390
|
Eigen::Index end = std::min(start + task_size, feature_group_count);
|
391
391
|
for (Eigen::Index i = start; i < end; ++i) {
|
@@ -395,18 +395,16 @@ void EigenGenericConv2D(
|
|
395
395
|
}
|
396
396
|
});
|
397
397
|
} else {
|
398
|
-
|
399
|
-
|
400
|
-
|
398
|
+
tsl::BlockUntilReady(Worker::Parallelize(
|
399
|
+
&device, /*num_workers=*/num_tasks, num_tasks,
|
400
|
+
[=, &device](Eigen::Index task_index) {
|
401
401
|
Eigen::Index start = task_index * task_size;
|
402
402
|
Eigen::Index end = std::min(start + task_size, feature_group_count);
|
403
403
|
for (Eigen::Index i = start; i < end; ++i) {
|
404
404
|
auto [output, convolved] = convolve_group(i);
|
405
405
|
output.device(device) = convolved;
|
406
406
|
}
|
407
|
-
|
408
|
-
});
|
409
|
-
barrier.Wait();
|
407
|
+
}));
|
410
408
|
}
|
411
409
|
|
412
410
|
} else {
|
@@ -29,7 +29,6 @@ limitations under the License.
|
|
29
29
|
#include "absl/base/attributes.h"
|
30
30
|
#include "absl/base/optimization.h"
|
31
31
|
#include "absl/container/fixed_array.h"
|
32
|
-
#include "absl/log/check.h"
|
33
32
|
#include "absl/status/status.h"
|
34
33
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
35
34
|
#include "xla/tsl/concurrency/chain.h"
|
@@ -1914,6 +1914,18 @@ class HloInstruction {
|
|
1914
1914
|
result_accuracy().mode() != ResultAccuracy::DEFAULT);
|
1915
1915
|
}
|
1916
1916
|
|
1917
|
+
bool equal_result_accuracy(const HloInstruction* other) const {
|
1918
|
+
return result_accuracy().has_tolerance() ==
|
1919
|
+
other->result_accuracy().has_tolerance() &&
|
1920
|
+
result_accuracy().tolerance().atol() ==
|
1921
|
+
other->result_accuracy().tolerance().atol() &&
|
1922
|
+
result_accuracy().tolerance().rtol() ==
|
1923
|
+
other->result_accuracy().tolerance().rtol() &&
|
1924
|
+
result_accuracy().tolerance().ulps() ==
|
1925
|
+
other->result_accuracy().tolerance().ulps() &&
|
1926
|
+
result_accuracy().mode() == other->result_accuracy().mode();
|
1927
|
+
}
|
1928
|
+
|
1917
1929
|
void add_single_statistic(Statistic statistic) {
|
1918
1930
|
*mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
|
1919
1931
|
}
|
@@ -35,9 +35,6 @@ limitations under the License.
|
|
35
35
|
#include "xla/tsl/platform/logging.h"
|
36
36
|
|
37
37
|
namespace tsl {
|
38
|
-
|
39
|
-
class NotifierListNode;
|
40
|
-
|
41
38
|
namespace internal {
|
42
39
|
|
43
40
|
template <typename T>
|
@@ -277,6 +274,8 @@ class AsyncValue {
|
|
277
274
|
protected:
|
278
275
|
friend class IndirectAsyncValue;
|
279
276
|
|
277
|
+
struct WaiterListNode;
|
278
|
+
|
280
279
|
static constexpr uint16_t kUnknownTypeId = 0;
|
281
280
|
|
282
281
|
// Utility template for tag dispatching.
|
@@ -311,7 +310,7 @@ class AsyncValue {
|
|
311
310
|
|
312
311
|
void NotifyAvailable(State available_state);
|
313
312
|
void Destroy();
|
314
|
-
void RunWaiters(
|
313
|
+
void RunWaiters(WaiterListNode* list);
|
315
314
|
|
316
315
|
// IsTypeIdCompatible returns true if the type value stored in this AsyncValue
|
317
316
|
// instance can be safely cast to `T`. This is a conservative check. I.e.
|
@@ -369,6 +368,16 @@ class AsyncValue {
|
|
369
368
|
// This is a 16-bit value that identifies the type.
|
370
369
|
uint16_t type_id_ = 0;
|
371
370
|
|
371
|
+
// This is a singly linked list of nodes waiting for notification, hanging off
|
372
|
+
// of AsyncValue. When the value becomes available or if an error occurs, the
|
373
|
+
// callbacks are informed.
|
374
|
+
struct WaiterListNode {
|
375
|
+
virtual ~WaiterListNode() = default;
|
376
|
+
virtual void operator()() = 0;
|
377
|
+
|
378
|
+
WaiterListNode* next = nullptr;
|
379
|
+
};
|
380
|
+
|
372
381
|
// The waiter list and the state are compacted into one single atomic word as
|
373
382
|
// accesses to them are tightly related. To change the state from unavailable
|
374
383
|
// (i.e. kUnconstructed or kConstructed) to available
|
@@ -379,7 +388,7 @@ class AsyncValue {
|
|
379
388
|
// Invariant: If the state is not available, then the waiter list must be
|
380
389
|
// nullptr.
|
381
390
|
struct WaitersAndState {
|
382
|
-
// We rely on the fact that all `
|
391
|
+
// We rely on the fact that all `WaiterListNode` values are aligned at
|
383
392
|
// least to 4 bytes and we can encode state in the lowest 2 bits. We use
|
384
393
|
// the conservative estimation of the minimal alignment of pointers returned
|
385
394
|
// from memory allocation functions.
|
@@ -390,7 +399,7 @@ class AsyncValue {
|
|
390
399
|
static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
|
391
400
|
static constexpr uintptr_t kPointerMask = ~kStateMask;
|
392
401
|
|
393
|
-
WaitersAndState(
|
402
|
+
WaitersAndState(WaiterListNode* ptr, State state) {
|
394
403
|
value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
|
395
404
|
(state & kStateMask);
|
396
405
|
}
|
@@ -399,8 +408,8 @@ class AsyncValue {
|
|
399
408
|
return State(static_cast<State::StateEnum>(value & kStateMask));
|
400
409
|
}
|
401
410
|
|
402
|
-
|
403
|
-
return reinterpret_cast<
|
411
|
+
WaiterListNode* waiter() const {
|
412
|
+
return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
|
404
413
|
}
|
405
414
|
|
406
415
|
uintptr_t value;
|
@@ -466,8 +475,26 @@ class AsyncValue {
|
|
466
475
|
return (*type_info_table)[type_id_ - 1];
|
467
476
|
}
|
468
477
|
|
469
|
-
|
470
|
-
|
478
|
+
// Adds a waiter list node to the waiter linked list. If the value is
|
479
|
+
// available or becomes available, this calls the waiter immediately.
|
480
|
+
// Otherwise, we add waiter to the list where it will be called when the value
|
481
|
+
// becomes available.
|
482
|
+
void EnqueueWaiterListNode(WaiterListNode* waiter,
|
483
|
+
WaitersAndState waiters_and_state);
|
484
|
+
|
485
|
+
template <typename Waiter>
|
486
|
+
void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
|
487
|
+
static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
|
488
|
+
|
489
|
+
struct Node final : public WaiterListNode {
|
490
|
+
explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
|
491
|
+
void operator()() final { waiter(); }
|
492
|
+
Waiter waiter;
|
493
|
+
};
|
494
|
+
|
495
|
+
EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
|
496
|
+
waiters_and_state);
|
497
|
+
}
|
471
498
|
|
472
499
|
// This is a global counter of the number of AsyncValue instances currently
|
473
500
|
// live in the process. This is intended to be used for debugging only, and
|
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
|
|
983
1010
|
// Clients generally want to use AndThen without them each having to check
|
984
1011
|
// to see if the value is present. Check for them, and immediately run the
|
985
1012
|
// waiter if it is already here.
|
986
|
-
auto
|
987
|
-
if (
|
988
|
-
|
989
|
-
DCHECK_EQ(
|
1013
|
+
auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
|
1014
|
+
if (waiters_and_state.state() == State::kConcrete ||
|
1015
|
+
waiters_and_state.state() == State::kError) {
|
1016
|
+
DCHECK_EQ(waiters_and_state.waiter(), nullptr);
|
990
1017
|
waiter();
|
991
1018
|
return;
|
992
1019
|
}
|
993
|
-
|
1020
|
+
|
1021
|
+
EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
|
994
1022
|
}
|
995
1023
|
|
996
1024
|
template <typename Waiter>
|
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
|
|
998
1026
|
// Clients generally want to use AndThen without them each having to check
|
999
1027
|
// to see if the value is present. Check for them, and immediately run the
|
1000
1028
|
// waiter if it is already here.
|
1001
|
-
auto
|
1002
|
-
if (
|
1003
|
-
|
1004
|
-
DCHECK_EQ(
|
1029
|
+
auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
|
1030
|
+
if (waiters_and_state.state() == State::kConcrete ||
|
1031
|
+
waiters_and_state.state() == State::kError) {
|
1032
|
+
DCHECK_EQ(waiters_and_state.waiter(), nullptr);
|
1005
1033
|
executor.Execute(std::forward<Waiter>(waiter));
|
1006
1034
|
return;
|
1007
1035
|
}
|
1036
|
+
|
1008
1037
|
EnqueueWaiter(
|
1009
|
-
[&executor, waiter = std::forward<Waiter>(waiter)]
|
1038
|
+
[&executor, waiter = std::forward<Waiter>(waiter)] {
|
1010
1039
|
executor.Execute(std::move(waiter));
|
1011
1040
|
},
|
1012
|
-
|
1041
|
+
waiters_and_state);
|
1013
1042
|
}
|
1014
1043
|
|
1015
1044
|
inline void AsyncValue::Destroy() {
|
@@ -35,8 +35,6 @@ class GetElementAtIndexOp : public AsyncOpKernel {
|
|
35
35
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
36
36
|
}
|
37
37
|
|
38
|
-
~GetElementAtIndexOp() override {}
|
39
|
-
|
40
38
|
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
41
39
|
unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
|
42
40
|
ctx->SetStatus(DoCompute(ctx));
|
@@ -26,7 +26,7 @@ limitations under the License.
|
|
26
26
|
|
27
27
|
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
|
28
28
|
// "-beta", "-rc", "-rc.1")
|
29
|
-
#define TF_VERSION_SUFFIX "-
|
29
|
+
#define TF_VERSION_SUFFIX "-dev20250222"
|
30
30
|
|
31
31
|
#define _TF_STR_HELPER(x) #x
|
32
32
|
#define _TF_STR(x) _TF_STR_HELPER(x)
|
@@ -93,7 +93,7 @@ limitations under the License.
|
|
93
93
|
|
94
94
|
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
95
95
|
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
96
|
-
#define TF_GRAPH_DEF_VERSION
|
96
|
+
#define TF_GRAPH_DEF_VERSION 2145 // Updated: 2025/2/21
|
97
97
|
|
98
98
|
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
99
99
|
//
|
@@ -22,7 +22,7 @@ limitations under the License.
|
|
22
22
|
#include <memory>
|
23
23
|
#include <utility>
|
24
24
|
|
25
|
-
#include "xla/backends/cpu/runtime/
|
25
|
+
#include "xla/backends/cpu/runtime/work_queue.h"
|
26
26
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
27
27
|
#include "xla/tsl/concurrency/chain.h"
|
28
28
|
#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep
|
@@ -30,7 +30,6 @@ limitations under the License.
|
|
30
30
|
|
31
31
|
#define EIGEN_USE_THREADS
|
32
32
|
#include "Eigen/Core"
|
33
|
-
#include "Eigen/ThreadPool"
|
34
33
|
#include "unsupported/Eigen/CXX11/Tensor"
|
35
34
|
|
36
35
|
namespace xla::cpu::internal {
|
@@ -384,8 +383,9 @@ void EigenGenericConv2D(
|
|
384
383
|
auto num_tasks = Eigen::numext::div_ceil(feature_group_count, task_size);
|
385
384
|
|
386
385
|
if (use_thunk_runtime) {
|
387
|
-
|
388
|
-
&device, num_tasks,
|
386
|
+
Worker::Parallelize(
|
387
|
+
&device, /*num_workers=*/num_tasks, num_tasks,
|
388
|
+
[=, &device](Eigen::Index task_index) mutable {
|
389
389
|
Eigen::Index start = task_index * task_size;
|
390
390
|
Eigen::Index end = std::min(start + task_size, feature_group_count);
|
391
391
|
for (Eigen::Index i = start; i < end; ++i) {
|
@@ -395,18 +395,16 @@ void EigenGenericConv2D(
|
|
395
395
|
}
|
396
396
|
});
|
397
397
|
} else {
|
398
|
-
|
399
|
-
|
400
|
-
|
398
|
+
tsl::BlockUntilReady(Worker::Parallelize(
|
399
|
+
&device, /*num_workers=*/num_tasks, num_tasks,
|
400
|
+
[=, &device](Eigen::Index task_index) {
|
401
401
|
Eigen::Index start = task_index * task_size;
|
402
402
|
Eigen::Index end = std::min(start + task_size, feature_group_count);
|
403
403
|
for (Eigen::Index i = start; i < end; ++i) {
|
404
404
|
auto [output, convolved] = convolve_group(i);
|
405
405
|
output.device(device) = convolved;
|
406
406
|
}
|
407
|
-
|
408
|
-
});
|
409
|
-
barrier.Wait();
|
407
|
+
}));
|
410
408
|
}
|
411
409
|
|
412
410
|
} else {
|
@@ -29,7 +29,6 @@ limitations under the License.
|
|
29
29
|
#include "absl/base/attributes.h"
|
30
30
|
#include "absl/base/optimization.h"
|
31
31
|
#include "absl/container/fixed_array.h"
|
32
|
-
#include "absl/log/check.h"
|
33
32
|
#include "absl/status/status.h"
|
34
33
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
35
34
|
#include "xla/tsl/concurrency/chain.h"
|
@@ -1914,6 +1914,18 @@ class HloInstruction {
|
|
1914
1914
|
result_accuracy().mode() != ResultAccuracy::DEFAULT);
|
1915
1915
|
}
|
1916
1916
|
|
1917
|
+
bool equal_result_accuracy(const HloInstruction* other) const {
|
1918
|
+
return result_accuracy().has_tolerance() ==
|
1919
|
+
other->result_accuracy().has_tolerance() &&
|
1920
|
+
result_accuracy().tolerance().atol() ==
|
1921
|
+
other->result_accuracy().tolerance().atol() &&
|
1922
|
+
result_accuracy().tolerance().rtol() ==
|
1923
|
+
other->result_accuracy().tolerance().rtol() &&
|
1924
|
+
result_accuracy().tolerance().ulps() ==
|
1925
|
+
other->result_accuracy().tolerance().ulps() &&
|
1926
|
+
result_accuracy().mode() == other->result_accuracy().mode();
|
1927
|
+
}
|
1928
|
+
|
1917
1929
|
void add_single_statistic(Statistic statistic) {
|
1918
1930
|
*mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
|
1919
1931
|
}
|
@@ -35,9 +35,6 @@ limitations under the License.
|
|
35
35
|
#include "xla/tsl/platform/logging.h"
|
36
36
|
|
37
37
|
namespace tsl {
|
38
|
-
|
39
|
-
class NotifierListNode;
|
40
|
-
|
41
38
|
namespace internal {
|
42
39
|
|
43
40
|
template <typename T>
|
@@ -277,6 +274,8 @@ class AsyncValue {
|
|
277
274
|
protected:
|
278
275
|
friend class IndirectAsyncValue;
|
279
276
|
|
277
|
+
struct WaiterListNode;
|
278
|
+
|
280
279
|
static constexpr uint16_t kUnknownTypeId = 0;
|
281
280
|
|
282
281
|
// Utility template for tag dispatching.
|
@@ -311,7 +310,7 @@ class AsyncValue {
|
|
311
310
|
|
312
311
|
void NotifyAvailable(State available_state);
|
313
312
|
void Destroy();
|
314
|
-
void RunWaiters(
|
313
|
+
void RunWaiters(WaiterListNode* list);
|
315
314
|
|
316
315
|
// IsTypeIdCompatible returns true if the type value stored in this AsyncValue
|
317
316
|
// instance can be safely cast to `T`. This is a conservative check. I.e.
|
@@ -369,6 +368,16 @@ class AsyncValue {
|
|
369
368
|
// This is a 16-bit value that identifies the type.
|
370
369
|
uint16_t type_id_ = 0;
|
371
370
|
|
371
|
+
// This is a singly linked list of nodes waiting for notification, hanging off
|
372
|
+
// of AsyncValue. When the value becomes available or if an error occurs, the
|
373
|
+
// callbacks are informed.
|
374
|
+
struct WaiterListNode {
|
375
|
+
virtual ~WaiterListNode() = default;
|
376
|
+
virtual void operator()() = 0;
|
377
|
+
|
378
|
+
WaiterListNode* next = nullptr;
|
379
|
+
};
|
380
|
+
|
372
381
|
// The waiter list and the state are compacted into one single atomic word as
|
373
382
|
// accesses to them are tightly related. To change the state from unavailable
|
374
383
|
// (i.e. kUnconstructed or kConstructed) to available
|
@@ -379,7 +388,7 @@ class AsyncValue {
|
|
379
388
|
// Invariant: If the state is not available, then the waiter list must be
|
380
389
|
// nullptr.
|
381
390
|
struct WaitersAndState {
|
382
|
-
// We rely on the fact that all `
|
391
|
+
// We rely on the fact that all `WaiterListNode` values are aligned at
|
383
392
|
// least to 4 bytes and we can encode state in the lowest 2 bits. We use
|
384
393
|
// the conservative estimation of the minimal alignment of pointers returned
|
385
394
|
// from memory allocation functions.
|
@@ -390,7 +399,7 @@ class AsyncValue {
|
|
390
399
|
static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
|
391
400
|
static constexpr uintptr_t kPointerMask = ~kStateMask;
|
392
401
|
|
393
|
-
WaitersAndState(
|
402
|
+
WaitersAndState(WaiterListNode* ptr, State state) {
|
394
403
|
value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
|
395
404
|
(state & kStateMask);
|
396
405
|
}
|
@@ -399,8 +408,8 @@ class AsyncValue {
|
|
399
408
|
return State(static_cast<State::StateEnum>(value & kStateMask));
|
400
409
|
}
|
401
410
|
|
402
|
-
|
403
|
-
return reinterpret_cast<
|
411
|
+
WaiterListNode* waiter() const {
|
412
|
+
return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
|
404
413
|
}
|
405
414
|
|
406
415
|
uintptr_t value;
|
@@ -466,8 +475,26 @@ class AsyncValue {
|
|
466
475
|
return (*type_info_table)[type_id_ - 1];
|
467
476
|
}
|
468
477
|
|
469
|
-
|
470
|
-
|
478
|
+
// Adds a waiter list node to the waiter linked list. If the value is
|
479
|
+
// available or becomes available, this calls the waiter immediately.
|
480
|
+
// Otherwise, we add waiter to the list where it will be called when the value
|
481
|
+
// becomes available.
|
482
|
+
void EnqueueWaiterListNode(WaiterListNode* waiter,
|
483
|
+
WaitersAndState waiters_and_state);
|
484
|
+
|
485
|
+
template <typename Waiter>
|
486
|
+
void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
|
487
|
+
static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
|
488
|
+
|
489
|
+
struct Node final : public WaiterListNode {
|
490
|
+
explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
|
491
|
+
void operator()() final { waiter(); }
|
492
|
+
Waiter waiter;
|
493
|
+
};
|
494
|
+
|
495
|
+
EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
|
496
|
+
waiters_and_state);
|
497
|
+
}
|
471
498
|
|
472
499
|
// This is a global counter of the number of AsyncValue instances currently
|
473
500
|
// live in the process. This is intended to be used for debugging only, and
|
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
|
|
983
1010
|
// Clients generally want to use AndThen without them each having to check
|
984
1011
|
// to see if the value is present. Check for them, and immediately run the
|
985
1012
|
// waiter if it is already here.
|
986
|
-
auto
|
987
|
-
if (
|
988
|
-
|
989
|
-
DCHECK_EQ(
|
1013
|
+
auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
|
1014
|
+
if (waiters_and_state.state() == State::kConcrete ||
|
1015
|
+
waiters_and_state.state() == State::kError) {
|
1016
|
+
DCHECK_EQ(waiters_and_state.waiter(), nullptr);
|
990
1017
|
waiter();
|
991
1018
|
return;
|
992
1019
|
}
|
993
|
-
|
1020
|
+
|
1021
|
+
EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
|
994
1022
|
}
|
995
1023
|
|
996
1024
|
template <typename Waiter>
|
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
|
|
998
1026
|
// Clients generally want to use AndThen without them each having to check
|
999
1027
|
// to see if the value is present. Check for them, and immediately run the
|
1000
1028
|
// waiter if it is already here.
|
1001
|
-
auto
|
1002
|
-
if (
|
1003
|
-
|
1004
|
-
DCHECK_EQ(
|
1029
|
+
auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
|
1030
|
+
if (waiters_and_state.state() == State::kConcrete ||
|
1031
|
+
waiters_and_state.state() == State::kError) {
|
1032
|
+
DCHECK_EQ(waiters_and_state.waiter(), nullptr);
|
1005
1033
|
executor.Execute(std::forward<Waiter>(waiter));
|
1006
1034
|
return;
|
1007
1035
|
}
|
1036
|
+
|
1008
1037
|
EnqueueWaiter(
|
1009
|
-
[&executor, waiter = std::forward<Waiter>(waiter)]
|
1038
|
+
[&executor, waiter = std::forward<Waiter>(waiter)] {
|
1010
1039
|
executor.Execute(std::move(waiter));
|
1011
1040
|
},
|
1012
|
-
|
1041
|
+
waiters_and_state);
|
1013
1042
|
}
|
1014
1043
|
|
1015
1044
|
inline void AsyncValue::Destroy() {
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|
29
29
|
# This value changes every day with an automatic CL. It can be modified in code
|
30
30
|
# via `forward_compatibility_horizon()` or with the environment variable
|
31
31
|
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
|
32
|
-
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2,
|
32
|
+
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 21)
|
33
33
|
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
|
34
34
|
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
|
35
35
|
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|