torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from functools import wraps
|
|
10
|
+
from typing import overload, TypeVar
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from tensordict import is_tensor_collection
|
|
14
|
+
from tensordict.base import TensorDictBase
|
|
15
|
+
|
|
16
|
+
from torchrl.data.tensor_specs import DEVICE_TYPING, TensorSpec
|
|
17
|
+
from torchrl.envs.common import EnvBase
|
|
18
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
19
|
+
|
|
20
|
+
T = TypeVar("T")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@overload
|
|
24
|
+
def _maybe_to_device(r: tuple, device: DEVICE_TYPING) -> tuple:
|
|
25
|
+
...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@overload
|
|
29
|
+
def _maybe_to_device(r: list, device: DEVICE_TYPING) -> list:
|
|
30
|
+
...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@overload
|
|
34
|
+
def _maybe_to_device(r: dict, device: DEVICE_TYPING) -> dict:
|
|
35
|
+
...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@overload
|
|
39
|
+
def _maybe_to_device(r: TensorDictBase, device: DEVICE_TYPING) -> TensorDictBase:
|
|
40
|
+
...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@overload
|
|
44
|
+
def _maybe_to_device(r: T, device: DEVICE_TYPING) -> T:
|
|
45
|
+
...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _maybe_to_device(r, device):
|
|
49
|
+
if isinstance(r, tuple):
|
|
50
|
+
return tuple(_maybe_to_device(r_i, device) for r_i in r)
|
|
51
|
+
if isinstance(r, list):
|
|
52
|
+
return [_maybe_to_device(r_i, device) for r_i in r]
|
|
53
|
+
if isinstance(r, dict):
|
|
54
|
+
return {k: _maybe_to_device(v, device) for k, v in r.items()}
|
|
55
|
+
if hasattr(r, "to"):
|
|
56
|
+
return r.to(device)
|
|
57
|
+
return r
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@overload
def _maybe_clear_device(r: tuple) -> tuple:
    ...


@overload
def _maybe_clear_device(r: list) -> list:
    ...


@overload
def _maybe_clear_device(r: dict) -> dict:
    ...


@overload
def _maybe_clear_device(r: TensorDictBase) -> TensorDictBase:
    ...


@overload
def _maybe_clear_device(r: T) -> T:
    ...


def _maybe_clear_device(r):
    """Recursively strip device metadata before shipping ``r`` to a Ray actor.

    Tuples, lists and dicts are rebuilt element-wise. Tensor collections and
    specs are cloned first (``clear_device_`` mutates in place, and the
    caller's object must stay untouched), then moved to CPU with their
    device metadata cleared. Anything else is returned unchanged.
    """
    if isinstance(r, tuple):
        return tuple(_maybe_clear_device(item) for item in r)
    if isinstance(r, list):
        return [_maybe_clear_device(item) for item in r]
    if isinstance(r, dict):
        return {key: _maybe_clear_device(val) for key, val in r.items()}
    if isinstance(r, TensorSpec) or is_tensor_collection(r):
        # Clone before the in-place device clear so the caller keeps its copy.
        return r.clone().cpu().clear_device_()
    return r
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _map_input_output_device(func: Callable):
    """Decorator mapping method inputs to CPU and outputs to the local device.

    Before the wrapped call, positional and keyword arguments have their
    device metadata cleared (via ``_maybe_clear_device``) so they are safe
    to ship to a remote Ray actor. After the call, the result is moved to
    ``self._device`` when that attribute exists and is set; when it exists
    but is ``None``, the result's device metadata is cleared instead.

    Args:
        func: the method to wrap.

    Returns:
        The wrapped method.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        cleaned_args = _maybe_clear_device(args)
        cleaned_kwargs = _maybe_clear_device(kwargs)
        result = func(self, *cleaned_args, **cleaned_kwargs)
        if hasattr(self, "_device"):
            local_device = self._device
            if local_device is None:
                result = _maybe_clear_device(result)
            else:
                result = _maybe_to_device(result, local_device)
        return result

    return wrapper
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class RayTransform(Transform, ABC):
|
|
128
|
+
"""Base class for transforms that delegate operations to Ray remote actors.
|
|
129
|
+
|
|
130
|
+
This class provides a framework for creating transforms that offload their operations
|
|
131
|
+
to Ray remote actors, enabling:
|
|
132
|
+
- Resource isolation and dedicated CPU/GPU allocation
|
|
133
|
+
- Shared state across multiple environment instances
|
|
134
|
+
- Distributed computation for expensive operations
|
|
135
|
+
|
|
136
|
+
The class automatically handles:
|
|
137
|
+
- Ray actor lifecycle management (creation, reuse, cleanup)
|
|
138
|
+
- Device mapping between local client and remote actor contexts
|
|
139
|
+
- Transparent method delegation with proper error handling
|
|
140
|
+
- Local management of parent/container relationships
|
|
141
|
+
|
|
142
|
+
Subclasses only need to implement `_create_actor()` to specify how their
|
|
143
|
+
specific Ray actor should be created and configured.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
num_cpus: CPU cores to allocate to the Ray actor
|
|
147
|
+
num_gpus: GPU devices to allocate to the Ray actor
|
|
148
|
+
device: Local device for tensor operations (client-side)
|
|
149
|
+
actor_name: Optional name for actor reuse across instances
|
|
150
|
+
**kwargs: Additional arguments passed to Transform base class
|
|
151
|
+
|
|
152
|
+
Example:
|
|
153
|
+
```python
|
|
154
|
+
class MyRayTransform(RayTransform):
|
|
155
|
+
def _create_actor(self, **kwargs):
|
|
156
|
+
RemoteClass = self._ray.remote(num_cpus=self._num_cpus)(MyClass)
|
|
157
|
+
return RemoteClass.remote(**kwargs)
|
|
158
|
+
```
|
|
159
|
+
"""
|
|
160
|
+
|
|
161
|
+
@property
def _ray(self):
    """Lazily import and cache the ``ray`` module.

    The import is deferred so that ray is only required when a
    :class:`RayTransform` is actually used. The module object is stored
    directly in the instance ``__dict__`` under ``_ray_val``; it is
    stripped from the pickled state by ``__getstate__`` and re-imported
    on first access after unpickling.

    Raises:
        ImportError: if ray is not installed.
    """
    ray = self.__dict__.get("_ray_val", None)
    if ray is not None:
        return ray
    # Import ray here to avoid requiring it as a dependency
    try:
        import ray
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Ray is required for RayTransform. Install with: pip install ray"
        ) from err
    self.__dict__["_ray_val"] = ray
    return ray
|
|
175
|
+
|
|
176
|
+
@_ray.setter
def _ray(self, value):
    """Store the ray module handle directly in the instance dict."""
    vars(self)["_ray_val"] = value
|
|
179
|
+
|
|
180
|
+
def __getstate__(self):
    """Drop the cached ray module from the pickled state.

    The ``_ray`` property re-imports ray lazily after unpickling, so the
    module handle itself must not travel with the state.
    """
    state = super().__getstate__()
    if "_ray_val" in state:
        del state["_ray_val"]
    return state
|
|
184
|
+
|
|
185
|
+
def __init__(
    self,
    *,
    num_cpus: int | None = None,
    num_gpus: int | None = None,
    device: DEVICE_TYPING | None = None,
    actor_name: str | None = None,
    **kwargs,
):
    """Initialize the RayTransform.

    Args:
        num_cpus: Number of CPUs to allocate to the Ray actor.
        num_gpus: Number of GPUs to allocate to the Ray actor.
        device: Local (client-side) device outputs are mapped to.
        actor_name: Name of the Ray actor. When given, an already-running
            actor with that name is reused instead of creating a new one.
        **kwargs: Additional arguments. ``in_keys``/``out_keys`` are read
            from here and forwarded to the Transform base class.

    NOTE(review): ``in_keys``/``out_keys`` are read with ``kwargs.get`` and
    not popped, so they are also forwarded to ``_initialize_actor`` /
    ``_create_actor`` below — confirm subclasses expect (or tolerate) that.
    """
    super().__init__(in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys"))

    self._num_cpus = num_cpus
    self._num_gpus = num_gpus
    self._device = device
    self._actor_name = actor_name
    # Actor handle; populated by _initialize_actor below.
    self._actor = None

    # Initialize the Ray actor (reuse a named actor when possible).
    self._initialize_actor(**kwargs)
|
|
213
|
+
|
|
214
|
+
def _initialize_actor(self, **kwargs):
    """Attach to an existing named actor if possible, else create a new one."""
    name = self._actor_name
    if name is not None:
        try:
            self._actor = self._ray.get_actor(name)
        except ValueError:
            # No actor registered under this name yet; fall through to creation.
            pass
        else:
            return
    self._actor = self._create_actor(**kwargs)
|
|
227
|
+
|
|
228
|
+
@abstractmethod
def _create_actor(self, **kwargs):
    """Create and return a Ray actor.

    This method should be implemented by subclasses to create the specific
    Ray actor needed for their operations (typically something like
    ``self._ray.remote(num_cpus=self._num_cpus)(Cls).remote(**kwargs)``).

    Args:
        **kwargs: Additional arguments for actor creation, forwarded from
            ``__init__`` via ``_initialize_actor``. Note that
            ``in_keys``/``out_keys`` may still be present here, since
            ``__init__`` reads them without popping.

    Returns:
        The created Ray actor handle.
    """
|
|
241
|
+
|
|
242
|
+
# Container management - handled locally, not delegated to remote actor
|
|
243
|
+
def set_container(self, container: Transform | EnvBase) -> None:
    """Register the container locally, then sync shape info to the remote actor.

    After the local bookkeeping, the parent's ``batch_size`` is pushed to
    the remote actor so remote shape handling matches, and remote
    validation is disabled because consistency is enforced on the client.
    """
    result = super().set_container(container)

    parent = self.parent
    if parent is not None:
        ray = self._ray
        remote_set_attr = self._actor._set_attr
        # Override the remote actor's batch size with the parent's so
        # shapes line up on both sides.
        ray.get(remote_set_attr.remote("batch_size", parent.batch_size))
        # Consistency is checked locally; skip validation on the actor.
        ray.get(remote_set_attr.remote("_validated", True))

    return result
|
|
259
|
+
|
|
260
|
+
def reset_parent(self) -> None:
    """Clear the parent reference; purely local bookkeeping, no remote call."""
    return super().reset_parent()
|
|
263
|
+
|
|
264
|
+
def clone(self):
    """Return a copy of this transform that shares the same Ray actor.

    ``super().clone()`` duplicates the generic Transform state; the
    Ray-specific attributes are then copied across so both instances talk
    to the same remote actor.
    """
    duplicate = super().clone()
    duplicate._actor = self._actor
    duplicate._ray = self._ray
    duplicate._num_cpus = self._num_cpus
    duplicate._num_gpus = self._num_gpus
    duplicate._actor_name = self._actor_name
    duplicate._device = getattr(self, "_device", None)
    return duplicate
|
|
276
|
+
|
|
277
|
+
def empty_cache(self):
    """Clear local caches, then forward the call to the remote actor."""
    super().empty_cache()
    handle = self._actor.empty_cache.remote()
    return self._ray.get(handle)
|
|
281
|
+
|
|
282
|
+
@property
def container(self) -> EnvBase | None:
    """The env containing this transform; resolved locally, never remotely."""
    return super().container
|
|
286
|
+
|
|
287
|
+
@property
def parent(self) -> EnvBase | None:
    """The parent env of this transform; resolved locally, never remotely."""
    return super().parent
|
|
291
|
+
|
|
292
|
+
@property
def base_env(self):
    """The innermost environment, found by walking the local parent chain."""
    parent = self.parent
    if parent is None:
        return None
    return getattr(parent, "base_env", None)
|
|
298
|
+
|
|
299
|
+
def __repr__(self):
    """Delegate repr to the remote actor, with safe local fallbacks."""
    try:
        actor = getattr(self, "_actor", None)
        if actor is None:
            return f"{self.__class__.__name__}(actor=None)"
        return self._ray.get(actor.__repr__.remote())
    except Exception:
        # Ray may be unreachable or the actor dead; repr must never raise.
        return f"{self.__class__.__name__}(actor={getattr(self, '_actor', 'None')})"
|
|
308
|
+
|
|
309
|
+
# Properties - access via generic attribute getter since Ray doesn't support direct property access
|
|
310
|
+
@property
def device(self):
    """Local (client-side) device outputs are mapped to, or ``None``."""
    try:
        return self._device
    except AttributeError:
        return None
|
|
314
|
+
|
|
315
|
+
@device.setter
def device(self, value):
    """Reject direct device assignment; ``transform.to()`` is the supported path."""
    msg = f"device setter is not implemented for {self.__class__.__name__}. Use transform.to() instead."
    raise NotImplementedError(msg)
|
|
321
|
+
|
|
322
|
+
# TensorDictPrimer methods
|
|
323
|
+
def init(self, tensordict: TensorDictBase | None):
    """Run the remote actor's ``init`` and block until it completes."""
    handle = self._actor.init.remote(tensordict)
    return self._ray.get(handle)
|
|
326
|
+
|
|
327
|
+
@_map_input_output_device
def _reset_func(
    self, tensordict: TensorDictBase | None, tensordict_reset: TensorDictBase | None
) -> TensorDictBase | None:
    """Delegate the reset function to the remote actor and wait for the result."""
    handle = self._actor._reset_func.remote(tensordict, tensordict_reset)
    return self._ray.get(handle)
|
|
336
|
+
|
|
337
|
+
@_map_input_output_device
def _reset(
    self, tensordict: TensorDictBase | None, tensordict_reset: TensorDictBase | None
) -> TensorDictBase | None:
    """Delegate ``_reset`` to the remote actor and wait for the result."""
    handle = self._actor._reset.remote(tensordict, tensordict_reset)
    return self._ray.get(handle)
|
|
343
|
+
|
|
344
|
+
@_map_input_output_device
def _reset_env_preprocess(
    self, tensordict: TensorDictBase | None
) -> TensorDictBase | None:
    """Delegate the pre-reset preprocessing to the remote actor.

    Needed for the ``call_before_env_reset=True`` code path.
    """
    handle = self._actor._reset_env_preprocess.remote(tensordict)
    return self._ray.get(handle)
|
|
350
|
+
|
|
351
|
+
def close(self):
    """Close the wrapped transform on the remote actor."""
    ref = self._actor.close.remote()
    return self._ray.get(ref)
|
|
354
|
+
|
|
355
|
+
@_map_input_output_device
def _apply_transform(self, obs: torch.Tensor | None) -> torch.Tensor | None:
    """Apply the transform to an observation on the remote actor."""
    ref = self._actor._apply_transform.remote(obs)
    return self._ray.get(ref)
|
|
359
|
+
|
|
360
|
+
@_map_input_output_device
def _call(self, next_tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Delegate ``_call`` to the remote actor."""
    ref = self._actor._call.remote(next_tensordict)
    return self._ray.get(ref)
|
|
364
|
+
|
|
365
|
+
@_map_input_output_device
def forward(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the forward pass on the remote actor."""
    ref = self._actor.forward.remote(tensordict)
    return self._ray.get(ref)
|
|
369
|
+
|
|
370
|
+
@_map_input_output_device
def _inv_apply_transform(
    self, state: TensorDictBase | None
) -> TensorDictBase | None:
    """Apply the inverse transform on the remote actor."""
    ref = self._actor._inv_apply_transform.remote(state)
    return self._ray.get(ref)
|
|
376
|
+
|
|
377
|
+
@_map_input_output_device
def _inv_call(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Delegate the inverse call to the remote actor."""
    ref = self._actor._inv_call.remote(tensordict)
    return self._ray.get(ref)
|
|
381
|
+
|
|
382
|
+
@_map_input_output_device
def inv(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the inverse transform on the remote actor."""
    ref = self._actor.inv.remote(tensordict)
    return self._ray.get(ref)
|
|
386
|
+
|
|
387
|
+
@_map_input_output_device
def _step(
    self, tensordict: TensorDictBase | None, next_tensordict: TensorDictBase | None
) -> TensorDictBase | None:
    """Delegate the step transform to the remote actor."""
    ref = self._actor._step.remote(tensordict, next_tensordict)
    return self._ray.get(ref)
|
|
393
|
+
|
|
394
|
+
def transform_env_device(self, device):
    """Transform the environment device on the remote actor."""
    ref = self._actor.transform_env_device.remote(device)
    return self._ray.get(ref)
|
|
397
|
+
|
|
398
|
+
def transform_env_batch_size(self, batch_size):
    """Transform the environment batch size on the remote actor."""
    ref = self._actor.transform_env_batch_size.remote(batch_size)
    return self._ray.get(ref)
|
|
401
|
+
|
|
402
|
+
@_map_input_output_device
def transform_output_spec(self, output_spec):
    """Transform the output spec on the remote actor."""
    ref = self._actor.transform_output_spec.remote(output_spec)
    return self._ray.get(ref)
|
|
406
|
+
|
|
407
|
+
@_map_input_output_device
def transform_input_spec(self, input_spec):
    """Transform the input spec on the remote actor."""
    ref = self._actor.transform_input_spec.remote(input_spec)
    return self._ray.get(ref)
|
|
411
|
+
|
|
412
|
+
@_map_input_output_device
def transform_observation_spec(self, observation_spec):
    """Transform the observation spec on the remote actor."""
    ref = self._actor.transform_observation_spec.remote(observation_spec)
    return self._ray.get(ref)
|
|
418
|
+
|
|
419
|
+
@_map_input_output_device
def transform_reward_spec(self, reward_spec):
    """Transform the reward spec on the remote actor."""
    ref = self._actor.transform_reward_spec.remote(reward_spec)
    return self._ray.get(ref)
|
|
423
|
+
|
|
424
|
+
@_map_input_output_device
def transform_done_spec(self, done_spec):
    """Transform the done spec on the remote actor."""
    ref = self._actor.transform_done_spec.remote(done_spec)
    return self._ray.get(ref)
|
|
428
|
+
|
|
429
|
+
@_map_input_output_device
def transform_action_spec(self, action_spec):
    """Transform the action spec on the remote actor."""
    ref = self._actor.transform_action_spec.remote(action_spec)
    return self._ray.get(ref)
|
|
433
|
+
|
|
434
|
+
@_map_input_output_device
def transform_state_spec(self, state_spec):
    """Transform the state spec on the remote actor."""
    ref = self._actor.transform_state_spec.remote(state_spec)
    return self._ray.get(ref)
|
|
438
|
+
|
|
439
|
+
def dump(self, **kwargs):
    """Forward ``dump`` (with its keyword arguments) to the remote actor."""
    ref = self._actor.dump.remote(**kwargs)
    return self._ray.get(ref)
|
|
442
|
+
|
|
443
|
+
def set_missing_tolerance(self, mode=False):
    """Set the missing-tolerance flag on the remote actor."""
    ref = self._actor.set_missing_tolerance.remote(mode)
    return self._ray.get(ref)
|
|
446
|
+
|
|
447
|
+
@property
def missing_tolerance(self):
    """Get the missing-tolerance flag from the remote actor.

    Fix: ``missing_tolerance`` is a property on the wrapped transform, not a
    method, so ``self._actor.missing_tolerance.remote()`` is not a valid Ray
    call — as the surrounding code notes, Ray does not support direct property
    access on actors. Fetch it through the generic attribute getter, exactly
    like the other remote properties (``primers``, ``in_keys``, ...).
    """
    return self._ray.get(self._actor.__getattribute__.remote("missing_tolerance"))
|
|
451
|
+
|
|
452
|
+
@property
def primers(self):
    """Fetch the primers from the remote actor (properties need the generic getter)."""
    return self._ray.get(self._actor.__getattribute__.remote("primers"))

@primers.setter
def primers(self, value):
    """Store primers locally and mirror them onto the remote actor when present."""
    self.__dict__["_primers"] = value
    if hasattr(self, "_actor"):
        self._ray.get(self._actor._set_attr.remote("primers", value))
|
|
463
|
+
|
|
464
|
+
def to(self, *args, **kwargs):
    """Move to device.

    Only records the target device locally; the remote actor is deliberately
    left untouched.
    """
    # Parse the device from args/kwargs like torch does
    # NOTE(review): torch._C._nn._parse_to is a private torch API — confirm it
    # is stable across the torch versions this project supports.
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
        *args, **kwargs
    )
    if device is not None:
        # "_device" is whitelisted as a local attribute by __setattr__, so this
        # stores on the wrapper itself rather than being forwarded to the actor.
        self._device = device
    # Don't delegate to remote actor - just register device locally
    return super().to(*args, **kwargs)
|
|
474
|
+
|
|
475
|
+
# Properties that should be accessed from the remote actor
|
|
476
|
+
@property
def in_keys(self):
    """Fetch ``in_keys`` from the remote actor via the generic attribute getter."""
    return self._ray.get(self._actor.__getattribute__.remote("in_keys"))

@in_keys.setter
def in_keys(self, value):
    """Store ``in_keys`` locally and mirror onto the remote actor when present."""
    self.__dict__["_in_keys"] = value
    if hasattr(self, "_actor"):
        self._ray.get(self._actor._set_attr.remote("in_keys", value))
|
|
487
|
+
|
|
488
|
+
@property
def out_keys(self):
    """Fetch ``out_keys`` from the remote actor via the generic attribute getter."""
    return self._ray.get(self._actor.__getattribute__.remote("out_keys"))

@out_keys.setter
def out_keys(self, value):
    """Store ``out_keys`` locally and mirror onto the remote actor when present."""
    self.__dict__["_out_keys"] = value
    if hasattr(self, "_actor"):
        self._ray.get(self._actor._set_attr.remote("out_keys", value))
|
|
499
|
+
|
|
500
|
+
@property
def in_keys_inv(self):
    """Fetch ``in_keys_inv`` from the remote actor via the generic attribute getter."""
    return self._ray.get(self._actor.__getattribute__.remote("in_keys_inv"))

@in_keys_inv.setter
def in_keys_inv(self, value):
    """Store ``in_keys_inv`` locally and mirror onto the remote actor when present."""
    self.__dict__["_in_keys_inv"] = value
    if hasattr(self, "_actor"):
        self._ray.get(self._actor._set_attr.remote("in_keys_inv", value))
|
|
511
|
+
|
|
512
|
+
@property
def out_keys_inv(self):
    """Fetch ``out_keys_inv`` from the remote actor via the generic attribute getter."""
    return self._ray.get(self._actor.__getattribute__.remote("out_keys_inv"))

@out_keys_inv.setter
def out_keys_inv(self, value):
    """Store ``out_keys_inv`` locally and mirror onto the remote actor when present."""
    self.__dict__["_out_keys_inv"] = value
    if hasattr(self, "_actor"):
        self._ray.get(self._actor._set_attr.remote("out_keys_inv", value))
|
|
523
|
+
|
|
524
|
+
# Generic attribute access for any remaining attributes
|
|
525
|
+
def __getattr__(self, name: str):
    """Get attribute from the remote actor.

    This method should only be called for attributes that don't exist locally
    and should be delegated to the remote actor (inward-facing).

    Outward-facing attributes (parent, container, base_env, etc.) should be handled
    by the Transform base class and never reach this method.
    """
    # Upward-facing attributes that should never be delegated to remote actor
    upward_attrs = {"parent", "container", "base_env", "_parent", "_container"}

    if name in upward_attrs:
        # These should be handled by the local Transform implementation
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    # Only delegate to remote actor if we're sure this is an inward-facing attribute
    # and the actor is properly initialized.
    # Reading through __dict__ avoids recursing back into __getattr__.
    actor = self.__dict__.get("_actor", None)
    if actor is None:
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    # Only delegate specific DataLoadingPrimer methods/attributes to the remote actor
    # This is a whitelist approach to be more conservative
    delegated_methods = {
        # DataLoadingPrimer methods that should be called on the remote actor
        "_call",
        "_reset",
        "_inv_call",
        "forward",
        "inv",
        "_apply_transform",
        "_inv_apply_transform",
        "_reset_func",
        "init",  # TensorDictPrimer specific methods
        "primers",
        "dataloader",  # Properties
        # Add other specific methods that should be delegated as needed
    }

    if name in delegated_methods:
        try:
            # NOTE(review): this probe invokes the remote attribute with no
            # arguments, which actually executes it once on the actor if it is
            # a method — confirm this is side-effect free (or fails fast with
            # TypeError) for every whitelisted name.
            result = self._ray.get(getattr(actor, name).remote())
            # If it's a method, wrap it to make remote calls
            if callable(result):
                return lambda *args, **kwargs: self._ray.get(
                    getattr(actor, name).remote(*args, **kwargs)
                )
            return result
        except (AttributeError, TypeError):
            # If that fails, it might be a callable method
            try:
                remote_method = getattr(actor, name)
                return lambda *args, **kwargs: self._ray.get(
                    remote_method.remote(*args, **kwargs)
                )
            except AttributeError:
                pass

    # If not in our whitelist, don't delegate to remote actor
    raise AttributeError(
        f"'{self.__class__.__name__}' object has no attribute '{name}'"
    )
|
|
592
|
+
|
|
593
|
+
def __setattr__(self, name, value):
    """Set attribute on the remote actor or locally.

    Names in ``local_attrs`` are always stored on the wrapper itself; any other
    name is best-effort forwarded to the remote actor's ``_set_attr``, falling
    back to a local set if that fails.
    """
    # Local attributes that should never be delegated to remote actor.
    # Includes nn.Module internals (_modules/_parameters/_buffers), the Ray
    # plumbing (_actor/_ray), and the key/device caches managed locally.
    local_attrs = {
        "_actor",
        "_ray",
        "_parent",
        "_container",
        "_missing_tolerance",
        "_in_keys",
        "_out_keys",
        "_in_keys_inv",
        "_out_keys_inv",
        # NOTE(review): for the bare property names below, super().__setattr__
        # presumably triggers the class property setters (data descriptors) —
        # confirm against nn.Module.__setattr__ semantics.
        "in_keys",
        "out_keys",
        "in_keys_inv",
        "out_keys_inv",
        "_modules",
        "_parameters",
        "_buffers",
        "_device",
    }

    if name in local_attrs:
        super().__setattr__(name, value)
    else:
        # Try to set on remote actor for other attributes
        try:
            if hasattr(self, "_actor") and self._actor is not None:
                self._ray.get(self._actor._set_attr.remote(name, value))
            else:
                # No actor yet (e.g. during __init__): set locally.
                super().__setattr__(name, value)
        except Exception:
            # Fall back to local setting for attributes that can't be set remotely
            super().__setattr__(name, value)
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
class _RayServiceMetaClass(type):
    """Metaclass that dispatches instantiation based on ``use_ray_service``.

    A class using this metaclass declares a ``_RayServiceClass`` attribute.
    Calling the class with ``use_ray_service=True`` constructs that Ray-backed
    alternative instead of the class itself; otherwise instantiation proceeds
    normally.

    Usage:
        >>> class MyRayClass():
        ...     def __init__(self, **kwargs):
        ...         ...
        ...
        >>> class MyClass(metaclass=_RayServiceMetaClass):
        ...     _RayServiceClass = MyRayClass
        ...
        ...     def __init__(self, use_ray_service=False, **kwargs):
        ...         # Regular implementation
        ...         pass
        ...
        >>> # Returns MyClass instance
        >>> obj1 = MyClass(use_ray_service=False)
        >>>
        >>> # Returns MyRayClass instance
        >>> obj2 = MyClass(use_ray_service=True)
    """

    def __call__(cls, *args, use_ray_service=False, **kwargs):
        # Default path: build the class itself.
        if not use_ray_service:
            return super().__call__(*args, **kwargs)
        # Ray path: the class must declare its Ray-backed counterpart.
        if not hasattr(cls, "_RayServiceClass"):
            raise ValueError(
                f"Class {cls.__name__} does not have a _RayServiceClass attribute"
            )
        return cls._RayServiceClass(*args, **kwargs)