tf-nightly-cpu 2.20.0.dev20250220-cp311-cp311-win_amd64.whl → 2.20.0.dev20250222-cp311-cp311-win_amd64.whl
- tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
- tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
- tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
- tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
- tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
- tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
- tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
- tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
- tensorflow/include/tensorflow/core/public/release_version.h +39 -0
- tensorflow/include/tensorflow/core/public/version.h +112 -127
- tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
- tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
- tensorflow/include/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/xla/service/constant_value.h +1 -0
- tensorflow/include/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
- tensorflow/python/eager/imperative_grad.py +5 -5
- tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
- tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
- tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
- tensorflow/python/eager/tape.py +2 -2
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/ops/summary_ops_v2.py +5 -1
- tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
- tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +2 -2
- tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +128 -123
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
- tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
tensorflow/include/xla/service/hlo_module_util.h
CHANGED
@@ -19,16 +19,67 @@ limitations under the License.
 #include <functional>
 #include <memory>
 #include <optional>
+#include <string>
 
-#include "absl/
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/service/compiler.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
+#include "xla/util.h"
 
 namespace xla {
 
+// Converts an HloModule from the given hlo textual IR string (in
+// HloModule::ToString format).
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
+    absl::string_view hlo_string,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Creates an HloModule from the given proto.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Create an HLO state from serialized representation. In addition to
+// creating the proto with HloModule::CreateFromProto(...) it also
+// uses HloVerifier to ensure basic invariants are held.
+// The HLO module could be a pre-optimizations (default) or post-optimizations
+// module, which affects how the HLO module is verified, e.g., mixed-precision
+// is allowed in post-optimizations HLOs.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto, const HloModuleConfig& module_config,
+    bool is_module_post_optimizations = false);
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Reads the proto file in xla.HloModule format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
+    absl::string_view filename, const DebugOptions& debug_options);
+
+// Reads the HLO text dump file in HloModule::ToString format, creates and
+// returns the HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance(),
+    const HloParserOptions& options = HloParserOptions());
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
+    absl::string_view hlo_file,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
 // Creates an HloModuleConfig for a given program shape and arguments.
 // If execution_options does not set num_replicas, default_num_replicas is used.
 // num_threads is optional; if not given, intra_op_parallelism_threads not set.
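Taken together, these additions centralize HloModule construction in one header: a module can now be built from textual IR (`CreateModuleFromString`), from an `HloProto` (two `CreateModuleFromProto` overloads, the config-taking one verifying invariants with HloVerifier), or read from binary-proto, text-proto, and HLO-text files, with `DebugOptions::default_instance()` as the default throughout.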
tensorflow/include/xla/service/hlo_proto_util.h
CHANGED
@@ -18,8 +18,6 @@ limitations under the License.
 #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
 #define XLA_SERVICE_HLO_PROTO_UTIL_H_
 
-#include <string>
-
 #include "absl/status/status.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
 // will not be included in the output.
 HloProto MakeHloProto(const HloModule& module);
 
-// Create an HLO state from serialized representation. In addition to
-// creating the proto with HloModule::CreateFromProto(...) it also
-// uses HloVerifier to ensure basic invariants are held.
-// The HLO module could be a pre-optimizations (default) or post-optimizations
-// module, which affects how the HLO module is verified, e.g., mixed-precision
-// is allowed in post-optimizations HLOs.
-absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
-    const HloModuleProto& proto, const HloModuleConfig& module_config,
-    bool is_module_post_optimizations = false);
-
 // Returns the shapes of the parameters of the entry computation. Shape pointers
 // refer to shapes inside of the given HloProto.
 absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
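Note how this removal pairs with the previous file: the `CreateModuleFromProto` overload (and its comment block) deleted here is the same one added to hlo_module_util.h above, so the helper was moved rather than dropped.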
tensorflow/include/xla/tsl/concurrency/async_value.h
CHANGED
@@ -35,9 +35,6 @@ limitations under the License.
 #include "xla/tsl/platform/logging.h"
 
 namespace tsl {
-
-class NotifierListNode;
-
 namespace internal {
 
 template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
  protected:
  friend class IndirectAsyncValue;
 
+  struct WaiterListNode;
+
  static constexpr uint16_t kUnknownTypeId = 0;
 
  // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {
 
  void NotifyAvailable(State available_state);
  void Destroy();
-  void RunWaiters(NotifierListNode* list);
+  void RunWaiters(WaiterListNode* list);
 
  // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
  // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
  // This is a 16-bit value that identifies the type.
  uint16_t type_id_ = 0;
 
+  // This is a singly linked list of nodes waiting for notification, hanging off
+  // of AsyncValue. When the value becomes available or if an error occurs, the
+  // callbacks are informed.
+  struct WaiterListNode {
+    virtual ~WaiterListNode() = default;
+    virtual void operator()() = 0;
+
+    WaiterListNode* next = nullptr;
+  };
+
  // The waiter list and the state are compacted into one single atomic word as
  // accesses to them are tightly related. To change the state from unavailable
  // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
  // Invariant: If the state is not available, then the waiter list must be
  // nullptr.
  struct WaitersAndState {
-    // We rely on the fact that all `NotifierListNode` values are aligned at
+    // We rely on the fact that all `WaiterListNode` values are aligned at
    // least to 4 bytes and we can encode state in the lowest 2 bits. We use
    // the conservative estimation of the minimal alignment of pointers returned
    // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
    static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
    static constexpr uintptr_t kPointerMask = ~kStateMask;
 
-    WaitersAndState(NotifierListNode* ptr, State state) {
+    WaitersAndState(WaiterListNode* ptr, State state) {
      value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
              (state & kStateMask);
    }
@@ -399,8 +408,8 @@ class AsyncValue {
      return State(static_cast<State::StateEnum>(value & kStateMask));
    }
 
-    NotifierListNode* waiter() const {
-      return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
+    WaiterListNode* waiter() const {
+      return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
    }
 
    uintptr_t value;
@@ -466,8 +475,26 @@ class AsyncValue {
    return (*type_info_table)[type_id_ - 1];
  }
 
-
-
+  // Adds a waiter list node to the waiter linked list. If the value is
+  // available or becomes available, this calls the waiter immediately.
+  // Otherwise, we add waiter to the list where it will be called when the value
+  // becomes available.
+  void EnqueueWaiterListNode(WaiterListNode* waiter,
+                             WaitersAndState waiters_and_state);
+
+  template <typename Waiter>
+  void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
+    static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
+
+    struct Node final : public WaiterListNode {
+      explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
+      void operator()() final { waiter(); }
+      Waiter waiter;
+    };
+
+    EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
+                          waiters_and_state);
+  }
 
  // This is a global counter of the number of AsyncValue instances currently
  // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
  // Clients generally want to use AndThen without them each having to check
  // to see if the value is present. Check for them, and immediately run the
  // waiter if it is already here.
-  auto
-  if (
-
-  DCHECK_EQ(
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
    waiter();
    return;
  }
-
+
+  EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
 }
 
 template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
  // Clients generally want to use AndThen without them each having to check
  // to see if the value is present. Check for them, and immediately run the
  // waiter if it is already here.
-  auto
-  if (
-
-  DCHECK_EQ(
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
    executor.Execute(std::forward<Waiter>(waiter));
    return;
  }
+
  EnqueueWaiter(
-      [&executor, waiter = std::forward<Waiter>(waiter)]
+      [&executor, waiter = std::forward<Waiter>(waiter)] {
        executor.Execute(std::move(waiter));
      },
-
+      waiters_and_state);
 }
 
 inline void AsyncValue::Destroy() {
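The shape of this refactoring: the forward-declared `NotifierListNode` is replaced by a `WaiterListNode` struct nested in `AsyncValue` with a virtual call operator, and `EnqueueWaiter` becomes a header-level template that wraps any invocable in a concrete `Node` subclass before handing it to the non-template `EnqueueWaiterListNode`. That moves the type erasure of waiters into the header, which presumably accounts for the shrinkage of xla/tsl/concurrency/async_value.cc (+26 -51) in the file list above.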
tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h
CHANGED
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
                    Index padding_left = 0, Index padding_right = 0) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   typedef typename internal::traits<Input>::Scalar InputScalar;
-  TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex> >
+  TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
+                         internal::traits<Input>::Layout, TensorIndex> >
       in(input);
-  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
-                   internal::traits<Kernel>::NumDimensions,
-                   internal::traits<Kernel>::Layout, TensorIndex> >
+  TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
+                         internal::traits<Kernel>::NumDimensions,
+                         internal::traits<Kernel>::Layout, TensorIndex> >
       kern(kernel);
 
   EIGEN_STATIC_ASSERT(
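The only change in this hunk is const-correctness: binding the `TensorRef` views to `const Tensor<...>` lets `SpatialConvolution` accept const-qualified input and kernel expressions without casting, with no behavioral change. The same +5 -5 edit appears in both vendored copies of this header in the file list.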
tensorflow/python/compat/compat.py
CHANGED
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2,
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 21)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
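For context, a minimal sketch of how this daily horizon is consumed downstream, via the public `tf.compat.forward_compatible` API (the gated behavior below is illustrative):

    import tensorflow as tf

    # True once the horizon (plus any TF_FORWARD_COMPATIBILITY_DELTA_DAYS
    # offset) has passed the given date, so graph producers switch to new op
    # semantics only when all expected consumers can handle them.
    if tf.compat.forward_compatible(2025, 2, 21):
      pass  # emit the new op or attribute
    else:
      pass  # keep the legacy graph construction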
tensorflow/python/eager/imperative_grad.py
CHANGED
@@ -43,13 +43,13 @@ def imperative_grad(tape,
     target: either a Tensor or list of Tensors to be differentiated.
     sources: list of Tensors for which we want gradients
     output_gradients: if not None, a list of gradient provided for each Target,
-
+      or None if we are to use the target's computed downstream gradient.
     sources_raw: if not None, a list of the source python objects from which the
-
-
+      sources were generated. Should have the same length as sources. Only needs
+      to be populated if unconnected_gradients is 'zero'.
     unconnected_gradients: determines the value returned if the target and
-
-
+      sources are unconnected. When 'none' the value returned is None whereas
+      when 'zero' a zero tensor in the same shape as the sources is returned.
 
   Returns:
     the gradient wrt each of the sources.
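The `unconnected_gradients` behavior documented above surfaces through the public `tf.GradientTape` API, which ultimately calls `imperative_grad`; a minimal sketch with illustrative values:

    import tensorflow as tf

    x = tf.constant(2.0)
    y = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch([x, y])
      z = x * x  # z does not depend on y

    # With 'none' (the default) the gradient for the unconnected source y is
    # None; with 'zero' a zero tensor shaped like y is returned instead.
    dz_dy = tape.gradient(z, y, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    print(dz_dy)  # tf.Tensor(0.0, shape=(), dtype=float32)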
tensorflow/python/eager/polymorphic_function/atomic_function.py
CHANGED
@@ -55,7 +55,7 @@ class CallOptions:
   # Used by ACD to list Ops/Tensors/Callables that must be called in advance.
   control_captures: List[Any] = dataclasses.field(default_factory=list)
 
-  # Determines what kind of
+  # Determines what kind of partitioned call is used for this function.
   is_stateful: bool = False
 
 
tensorflow/python/eager/polymorphic_function/compiler_ir.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""Implementation for defining get_compiler_ir."""
 from typing import List, Optional
 import warnings
 
tensorflow/python/eager/polymorphic_function/polymorphic_function.py
CHANGED
@@ -966,7 +966,7 @@ class Function(core.PolymorphicFunction, trackable.Trackable):
 
 def _check_inputs(args, kwargs):
   all_inputs = list(args) + list(kwargs.values())
-  #
+  # Empty input is okay.
   if not all_inputs:
     return
   if any(map(is_tensor_spec, all_inputs)) and any(
@@ -1423,7 +1423,8 @@ def function(
   thought of as compile-time constants), and builds a separate `tf.Graph` for
   each set of Python arguments that it encounters.
   For more information, see the
-  [tf.function
+  [tf.function
+  guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
 
   Executing a `PolymorphicFunction` will select and execute the appropriate
   `ConcreteFunction` based on the argument types and values.
@@ -1440,14 +1441,17 @@ def function(
   >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
   True
 
-  `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
-
+  `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
+  their
+  input is restricted to the types to which they're specialized.
 
   ## Retracing
 
-  `ConcreteFunctions` are built (traced) on the fly, as the
+  `ConcreteFunctions` are built (traced) on the fly, as the
+  `PolymorphicFunction` is
   called with new TensorFlow types or shapes, or with new Python values as
-  arguments. When `PolymorphicFunction` builds a new trace, it is said that
+  arguments. When `PolymorphicFunction` builds a new trace, it is said that
+  `func`
   is retraced. Retracing is a frequent performance concern for `tf.function` as
   it can be considerably slower than executing a graph that's already been
   traced. It is ideal to minimize the amount of retracing in your code.
@@ -1473,7 +1477,8 @@ def function(
 
   ## Input signatures
 
-  For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
+  For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
+  for
   every unique set of input shapes and datatypes. The example below creates two
   separate `ConcreteFunction`s, each specialized to a different shape:
 
@@ -1580,59 +1585,58 @@ def function(
       `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
     autograph: Whether autograph should be applied on `func` before tracing a
       graph. Data-dependent Python control flow statements require
-      `autograph=True`. For more information, see the
-
+      `autograph=True`. For more information, see the [tf.function and AutoGraph
+      guide](
       https://www.tensorflow.org/guide/function#autograph_transformations).
     jit_compile: If `True`, compiles the function using
       [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
       such as fusion, and attempts to emit more efficient code. This may
-      drastically improve the performance. If set to `True`,
-
-
-
-
-
-
-
-
-
-
-
-      amount of retracing, for example by using more generic shapes. This
-      can be controlled for user objects by customizing their associated
+      drastically improve the performance. If set to `True`, the whole function
+      needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
+      thrown. If `None` (default), compiles the function with XLA when running
+      on TPU and goes through the regular function execution path when running
+      on other devices. If `False`, executes the function without XLA
+      compilation. Set this value to `False` when directly running a
+      multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
+      host CPU). Not all functions are compilable, see a list of [sharp
+      corners](https://tensorflow.org/xla/known_issues).
+    reduce_retracing: When True, `tf.function` attempts to reduce the amount of
+      retracing, for example by using more generic shapes. This can be
+      controlled for user objects by customizing their associated
       `tf.types.experimental.TraceType`.
     experimental_implements: If provided, contains a name of a "known" function
-      this implements. For example "mycompany.my_recurrent_cell".
-
-
-
-      for details. For an example of utilizing
+      this implements. For example "mycompany.my_recurrent_cell". This is stored
+      as an attribute in inference function, which can then be detected when
+      processing serialized function. See [standardizing composite
+      ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)
+      # pylint: disable=line-too-long for details. For an example of utilizing
+      this attribute see this
      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
      The code above automatically detects and substitutes function that
      implements "embedded_matmul" and allows TFLite to substitute its own
-      implementations. For instance, a tensorflow user can use this
-
-
-
-
-
-
-
-
-
+      implementations. For instance, a tensorflow user can use this attribute to
+      mark that their function also implements `embedded_matmul` (perhaps more
+      efficiently!) by specifying it using this parameter:
+      `@tf.function(experimental_implements="embedded_matmul")` This can either
+      be specified as just the string name of the function or a NameAttrList
+      corresponding to a list of key-value attributes associated with the
+      function name. The name of the function will be in the 'name' field of the
+      NameAttrList. To define a formal TF op for this function implements, try
+      the experimental [composite
+      TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
      project.
    experimental_autograph_options: Optional tuple of
      `tf.autograph.experimental.Feature` values.
    experimental_attributes: Optional dictionary of attributes to include in the
      generated FunctionDefs.
-    experimental_relax_shapes: Deprecated. Use `reduce_retracing`
-      instead.
+    experimental_relax_shapes: Deprecated. Use `reduce_retracing` instead.
    experimental_compile: Deprecated alias to 'jit_compile'.
    experimental_follow_type_hints: Deprecated. Please use input_signature or
      reduce_retracing instead.
 
   Returns:
-    If `func` is not None, returns a
+    If `func` is not None, returns a
+    `tf.types.experimental.PolymorphicFunction`.
     If `func` is None, returns a decorator that, when invoked with a single
     `func` argument, returns a `tf.types.experimental.PolymorphicFunction`.
 
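As a quick illustration of the `jit_compile` and `reduce_retracing` arguments whose documentation is rewrapped above (standard `tf.function` API; shapes and values are illustrative):

    import tensorflow as tf

    @tf.function(jit_compile=True, reduce_retracing=True)
    def dense(x, w):
      # The whole body must be XLA-compilable because jit_compile=True.
      return tf.nn.relu(tf.matmul(x, w))

    x = tf.random.normal([8, 16])
    w = tf.random.normal([16, 4])
    print(dense(x, w).shape)  # (8, 4); the first call traces and compiles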
tensorflow/python/eager/tape.py
CHANGED
@@ -48,9 +48,9 @@ def watch(tape, tensor):
 def default_get_variables(variable):
   return [variable]
 
-# Gets a list of changed variables. Can be
+# Gets a list of changed variables. Can be overridden using
 # register_variables_override. An example of overriding is for getting the
-#
+# variables within a distributed context.
 _variables_override = default_get_variables
 
 
tensorflow/python/ops/summary_ops_v2.py
CHANGED
@@ -151,7 +151,11 @@ def _legacy_contrib_should_record_summaries():
 
 def is_recording_summaries():
   """Returns non-Tensor boolean indicating if summaries are being recorded."""
-
+  if _summary_state.writer is None:
+    return False
+  if _summary_state.is_recording is None:
+    return False
+  return _summary_state.is_recording
 
 
 @tf_export("summary.record_if", v1=[])
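For context on why `is_recording_summaries()` needs these guards, and why its callers still check for non-bool values: `tf.summary.record_if` accepts a constant, a `Tensor`, or a callable predicate, so the recording state is not always a plain bool. A sketch with a callable condition (standard `tf.summary` API; the log directory is illustrative):

    import tensorflow as tf

    writer = tf.summary.create_file_writer("/tmp/logs")
    step = tf.Variable(0, dtype=tf.int64)

    with writer.as_default():
      # Record only every 10th step. The condition is a callable, which is
      # exactly the non-bool case the TPU-embedding change below must handle.
      with tf.summary.record_if(lambda: step % 10 == 0):
        tf.summary.scalar("loss", 0.5, step=step)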
tensorflow/python/tpu/tpu_embedding_v3.py
CHANGED
@@ -67,6 +67,7 @@ from tensorflow.python.util.tf_export import tf_export
 _PIPELINE_ATTRIBUTE = "_embedding_pipelining"
 _PIPELINE_MODE_FORWARD = "forward"
 _PIPELINE_MODE_BACKWARD = "backward"
+_PIPELINE_MODEL_SEQUENTIAL = "_sequential"
 
 
 TableConfig = tpu_embedding_v2_utils.TableConfig
@@ -95,15 +96,21 @@ class EmbeddingPipeliningContext(control_flow_ops.ControlFlowContext):
     super().__init__()
     self._name = "EmbeddingPipelinigContext"
     self._mode = attr_value_pb2.AttrValue(s=compat.as_bytes(mode))
+    self._enable = enable
     recording_summaries = summary_ops_v2.is_recording_summaries()
+    if not isinstance(recording_summaries, bool):
+      # We can't handle predicate functions at this point. So, we'll ignore the
+      # special casing of summary recording because, presumably, this is not
+      # a single step loop so pipelining is still valid.
+      recording_summaries = False
     if enable and recording_summaries:
-
-
-
+      # We'll still flag these ops for the SC forward/backward pass, but we'll
+      # run them sequentially. This has to be handled in the MLIR passes
+      # embedding_pipelining.cc and embedding_sequencing.cc.
+      logging.info("Summary recording detected, disabling pipelining.")
+      self._mode = attr_value_pb2.AttrValue(
+          s=compat.as_bytes(mode + _PIPELINE_MODEL_SEQUENTIAL)
       )
-      self._enable = False
-    else:
-      self._enable = enable
 
   def to_control_flow_context_def(
       self, context_def: Any, export_scope: Any = None
@@ -1637,7 +1644,7 @@ class TPUEmbeddingV2(tpu_embedding_base.TPUEmbeddingBase):
       row_offset: int,
       col_offset: int,
       col_shift: int,
-
+      unused_vocab_size: int,
       num_sc_per_chip: int,
       num_sc_shards: int,
       stacked_table_sample_count: int,
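Net effect of the pipelining change: when summaries are being recorded, the context no longer disables the embedding ops outright (the removed `self._enable = False` branch); instead it keeps them flagged and appends the new suffix to the mode attribute, so `"forward"` becomes `"forward_sequential"`, and the MLIR passes named in the comment run the SparseCore ops sequentially instead of pipelined.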